├── .gitignore
├── README.md
├── setup.py
├── src
│   ├── conlleval.py
│   ├── dataset.py
│   ├── models.py
│   ├── run.py
│   └── task_utils.py
├── train.sh
└── transformers
    ├── __init__.py
    ├── commands
    │   ├── __init__.py
    │   └── user.py
    ├── configuration_bert.py
    ├── configuration_utils.py
    ├── data
    │   ├── __init__.py
    │   ├── metrics
    │   │   ├── __init__.py
    │   │   └── squad_metrics.py
    │   └── processors
    │       ├── __init__.py
    │       ├── glue.py
    │       ├── squad.py
    │       ├── utils.py
    │       └── xnli.py
    ├── file_utils.py
    ├── hf_api.py
    ├── modeling_bert.py
    ├── modeling_utils.py
    ├── optimization.py
    ├── tests
    │   ├── __init__.py
    │   ├── configuration_common_test.py
    │   ├── fixtures
    │   │   ├── input.txt
    │   │   ├── sample_text.txt
    │   │   ├── spiece.model
    │   │   └── test_sentencepiece.model
    │   ├── hf_api_test.py
    │   ├── modeling_albert_test.py
    │   ├── modeling_auto_test.py
    │   ├── modeling_bert_test.py
    │   ├── modeling_common_test.py
    │   ├── modeling_ctrl_test.py
    │   ├── modeling_distilbert_test.py
    │   ├── modeling_encoder_decoder_test.py
    │   ├── modeling_gpt2_test.py
    │   ├── modeling_openai_test.py
    │   ├── modeling_roberta_test.py
    │   ├── modeling_tf_albert_test.py
    │   ├── modeling_tf_auto_test.py
    │   ├── modeling_tf_bert_test.py
    │   ├── modeling_tf_common_test.py
    │   ├── modeling_tf_ctrl_test.py
    │   ├── modeling_tf_distilbert_test.py
    │   ├── modeling_tf_gpt2_test.py
    │   ├── modeling_tf_openai_gpt_test.py
    │   ├── modeling_tf_roberta_test.py
    │   ├── modeling_tf_transfo_xl_test.py
    │   ├── modeling_tf_xlm_test.py
    │   ├── modeling_tf_xlnet_test.py
    │   ├── modeling_transfo_xl_test.py
    │   ├── modeling_xlm_test.py
    │   ├── modeling_xlnet_test.py
    │   ├── optimization_test.py
    │   ├── optimization_tf_test.py
    │   ├── tokenization_albert_test.py
    │   ├── tokenization_auto_test.py
    │   ├── tokenization_bert_japanese_test.py
    │   ├── tokenization_bert_test.py
    │   ├── tokenization_ctrl_test.py
    │   ├── tokenization_distilbert_test.py
    │   ├── tokenization_gpt2_test.py
    │   ├── tokenization_openai_test.py
    │   ├── tokenization_roberta_test.py
    │   ├── tokenization_tests_commons.py
    │   ├── tokenization_transfo_xl_test.py
    │   ├── tokenization_utils_test.py
    │   ├── tokenization_xlm_test.py
    │   ├── tokenization_xlnet_test.py
    │   └── utils.py
    ├── tokenization_bert.py
    └── tokenization_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.zip
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Improving BERT with Syntax-aware Local Attention
2 |
3 | The implementation of the ACL 2021 Findings paper "Improving BERT with Syntax-aware Local Attention".
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
3 |
4 | To create the package for pypi.
5 |
6 | 1. Change the version in __init__.py, setup.py as well as docs/source/conf.py.
7 |
8 | 2. Commit these changes with the message: "Release: VERSION"
9 |
10 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' "
11 | Push the tag to git: git push --tags origin master
12 |
13 | 4. Build both the sources and the wheel. Do not change anything in setup.py between
14 | creating the wheel and the source distribution (obviously).
15 |
16 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
17 | (this will build a wheel for the python version you use to build it - make sure you use python 3.x).
18 |
19 | For the sources, run: "python setup.py sdist"
20 | You should now have a /dist directory with both .whl and .tar.gz source versions.
21 |
22 | 5. Check that everything looks correct by uploading the package to the pypi test server:
23 |
24 | twine upload dist/* -r pypitest
25 | (pypi suggest using twine as other methods upload files via plaintext.)
26 |
27 | Check that you can install it in a virtualenv by running:
28 | pip install -i https://testpypi.python.org/pypi transformers
29 |
30 | 6. Upload the final version to actual pypi:
31 | twine upload dist/* -r pypi
32 |
33 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
34 |
35 | """
36 | from io import open
37 | from setuptools import find_packages, setup
38 |
39 |
40 | extras = {
41 | 'serving': ['uvicorn', 'fastapi']
42 | }
43 | extras['all'] = [package for packages in extras.values() for package in packages]
44 |
45 | setup(
46 | name="transformers",
47 | version="2.2.2",
48 | author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
49 | author_email="thomas@huggingface.co",
50 | description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
51 | keywords='NLP deep learning transformer pytorch tensorflow BERT GPT GPT-2 google openai CMU',
52 | license='Apache',
53 | packages=find_packages(exclude=["*.tests", "*.tests.*",
54 | "tests.*", "tests"]),
55 | install_requires=['numpy',
56 | 'boto3',
57 | 'requests',
58 | 'tqdm',
59 | 'regex',
60 | 'sentencepiece',
61 | 'sacremoses'],
62 | extras_require=extras,
63 | # python_requires='>=3.5.0',
64 | classifiers=[
65 | 'Intended Audience :: Science/Research',
66 | 'License :: OSI Approved :: Apache Software License',
67 | 'Programming Language :: Python :: 3',
68 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
69 | ],
70 | )
71 |
--------------------------------------------------------------------------------
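A quick standalone sketch of what the `extras['all']` flattening above produces (assuming the two-level extras dict shown in setup.py):

    extras = {'serving': ['uvicorn', 'fastapi']}
    extras['all'] = [package for packages in extras.values() for package in packages]
    print(extras['all'])  # ['uvicorn', 'fastapi'] -- a flat list, usable as pip install "transformers[all]"
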
/src/conlleval.py:
--------------------------------------------------------------------------------
1 | """
2 | This script applies to IOB2 or IOBES tagging scheme.
3 | If you are using a different scheme, please convert to IOB2 or IOBES.
4 |
5 | IOB2:
6 | - B = begin,
7 | - I = inside but not the first,
8 | - O = outside
9 |
10 | e.g.
11 | John lives in New York City .
12 | B-PER O O B-LOC I-LOC I-LOC O
13 |
14 | IOBES:
15 | - B = begin,
16 | - E = end,
17 | - S = singleton,
18 | - I = inside but not the first or the last,
19 | - O = outside
20 |
21 | e.g.
22 | John lives in New York City .
23 | S-PER O O B-LOC I-LOC E-LOC O
24 |
25 | prefix: IOBES
26 | chunk_type: PER, LOC, etc.
27 | """
28 | from __future__ import division, print_function, unicode_literals
29 |
30 | import sys
31 | from collections import defaultdict
32 |
33 | def split_tag(chunk_tag):
34 | """
35 | split chunk tag into IOBES prefix and chunk_type
36 | e.g.
37 | B-PER -> (B, PER)
38 | O -> (O, None)
39 | """
40 | if chunk_tag == 'O':
41 | return ('O', None)
42 | return chunk_tag.split('-', maxsplit=1)
43 |
44 | def is_chunk_end(prev_tag, tag):
45 | """
46 | check if the previous chunk ended between the previous and current word
47 | e.g.
48 | (B-PER, I-PER) -> False
49 | (B-LOC, O) -> True
50 |
51 | Note: in case of contradicting tags, e.g. (B-PER, I-LOC)
52 | this is considered as (B-PER, B-LOC)
53 | """
54 | prefix1, chunk_type1 = split_tag(prev_tag)
55 | prefix2, chunk_type2 = split_tag(tag)
56 |
57 | if prefix1 == 'O':
58 | return False
59 | if prefix2 == 'O':
60 | return prefix1 != 'O'
61 |
62 | if chunk_type1 != chunk_type2:
63 | return True
64 |
65 | return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S']
66 |
67 | def is_chunk_start(prev_tag, tag):
68 | """
69 | check if a new chunk started between the previous and current word
70 | """
71 | prefix1, chunk_type1 = split_tag(prev_tag)
72 | prefix2, chunk_type2 = split_tag(tag)
73 |
74 | if prefix2 == 'O':
75 | return False
76 | if prefix1 == 'O':
77 | return prefix2 != 'O'
78 |
79 | if chunk_type1 != chunk_type2:
80 | return True
81 |
82 | return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S']
83 |
84 |
85 | def calc_metrics(tp, p, t, percent=True):
86 | """
87 | compute overall precision, recall and FB1 (default values are 0.0)
88 | if percent is True, return 100 * original decimal value
89 | """
90 | precision = tp / p if p else 0
91 | recall = tp / t if t else 0
92 | fb1 = 2 * precision * recall / (precision + recall) if precision + recall else 0
93 | if percent:
94 | return 100 * precision, 100 * recall, 100 * fb1
95 | else:
96 | return precision, recall, fb1
97 |
98 |
99 | def count_chunks(true_seqs, pred_seqs):
100 | """
101 | true_seqs: a list of true tags
102 | pred_seqs: a list of predicted tags
103 |
104 | return:
105 | correct_chunks: a dict (counter),
106 | key = chunk types,
107 | value = number of correctly identified chunks per type
108 | true_chunks: a dict, number of true chunks per type
109 | pred_chunks: a dict, number of identified chunks per type
110 |
111 | correct_counts, true_counts, pred_counts: similar to above, but for tags
112 | """
113 | correct_chunks = defaultdict(int)
114 | true_chunks = defaultdict(int)
115 | pred_chunks = defaultdict(int)
116 |
117 | correct_counts = defaultdict(int)
118 | true_counts = defaultdict(int)
119 | pred_counts = defaultdict(int)
120 |
121 | prev_true_tag, prev_pred_tag = 'O', 'O'
122 | correct_chunk = None
123 |
124 | for true_tag, pred_tag in zip(true_seqs, pred_seqs):
125 | if true_tag == pred_tag:
126 | correct_counts[true_tag] += 1
127 | true_counts[true_tag] += 1
128 | pred_counts[pred_tag] += 1
129 |
130 | _, true_type = split_tag(true_tag)
131 | _, pred_type = split_tag(pred_tag)
132 |
133 | if correct_chunk is not None:
134 | true_end = is_chunk_end(prev_true_tag, true_tag)
135 | pred_end = is_chunk_end(prev_pred_tag, pred_tag)
136 |
137 | if pred_end and true_end:
138 | correct_chunks[correct_chunk] += 1
139 | correct_chunk = None
140 | elif pred_end != true_end or true_type != pred_type:
141 | correct_chunk = None
142 |
143 | true_start = is_chunk_start(prev_true_tag, true_tag)
144 | pred_start = is_chunk_start(prev_pred_tag, pred_tag)
145 |
146 | if true_start and pred_start and true_type == pred_type:
147 | correct_chunk = true_type
148 | if true_start:
149 | true_chunks[true_type] += 1
150 | if pred_start:
151 | pred_chunks[pred_type] += 1
152 |
153 | prev_true_tag, prev_pred_tag = true_tag, pred_tag
154 | if correct_chunk is not None:
155 | correct_chunks[correct_chunk] += 1
156 |
157 | return (correct_chunks, true_chunks, pred_chunks,
158 | correct_counts, true_counts, pred_counts)
159 |
160 | def get_result(correct_chunks, true_chunks, pred_chunks,
161 | correct_counts, true_counts, pred_counts, verbose=True):
162 | """
163 | if verbose, print overall performance, as well as performance per chunk type;
164 | otherwise, simply return overall prec, rec, f1 scores
165 | """
166 | # sum counts
167 | sum_correct_chunks = sum(correct_chunks.values())
168 | sum_true_chunks = sum(true_chunks.values())
169 | sum_pred_chunks = sum(pred_chunks.values())
170 |
171 | sum_correct_counts = sum(correct_counts.values())
172 | sum_true_counts = sum(true_counts.values())
173 |
174 | nonO_correct_counts = sum(v for k, v in correct_counts.items() if k != 'O')
175 | nonO_true_counts = sum(v for k, v in true_counts.items() if k != 'O')
176 |
177 | chunk_types = sorted(list(set(list(true_chunks) + list(pred_chunks))))
178 |
179 | # compute overall precision, recall and FB1 (default values are 0.0)
180 | prec, rec, f1 = calc_metrics(sum_correct_chunks, sum_pred_chunks, sum_true_chunks)
181 | res = (prec, rec, f1)
182 | if not verbose:
183 | return res
184 |
185 | # print overall performance, and performance per chunk type
186 |
187 | print("processed %i tokens with %i phrases; " % (sum_true_counts, sum_true_chunks), end='')
188 | print("found: %i phrases; correct: %i.\n" % (sum_pred_chunks, sum_correct_chunks), end='')
189 |
190 | print("accuracy: %6.2f%%; (non-O)" % (100*nonO_correct_counts/nonO_true_counts))
191 | print("accuracy: %6.2f%%; " % (100*sum_correct_counts/sum_true_counts), end='')
192 | print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1))
193 |
194 | # for each chunk type, compute precision, recall and FB1 (default values are 0.0)
195 | for t in chunk_types:
196 | prec, rec, f1 = calc_metrics(correct_chunks[t], pred_chunks[t], true_chunks[t])
197 | print("%17s: " %t , end='')
198 | print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" %
199 | (prec, rec, f1), end='')
200 | print(" %d" % pred_chunks[t])
201 |
202 | return res
203 | # you can generate LaTeX output for tables like in
204 | # http://cnts.uia.ac.be/conll2003/ner/example.tex
205 | # but I'm not implementing this
206 |
207 | def evaluate(true_seqs, pred_seqs, verbose=True):
208 | (correct_chunks, true_chunks, pred_chunks,
209 | correct_counts, true_counts, pred_counts) = count_chunks(true_seqs, pred_seqs)
210 | result = get_result(correct_chunks, true_chunks, pred_chunks,
211 | correct_counts, true_counts, pred_counts, verbose=verbose)
212 | return result
213 |
214 | def evaluate_conll_file(fileIterator):
215 | true_seqs, pred_seqs = [], []
216 |
217 | for line in fileIterator:
218 | cols = line.strip().split()
219 | # each non-empty line must contain >= 3 columns
220 | if not cols:
221 | true_seqs.append('O')
222 | pred_seqs.append('O')
223 | elif len(cols) < 3:
224 | raise IOError("conlleval: too few columns in line %s\n" % line)
225 | else:
226 | # extract tags from last 2 columns
227 | true_seqs.append(cols[-2])
228 | pred_seqs.append(cols[-1])
229 | return evaluate(true_seqs, pred_seqs)
230 |
231 | if __name__ == '__main__':
232 | """
233 | usage: conlleval < file
234 | """
235 | evaluate_conll_file(sys.stdin)
--------------------------------------------------------------------------------
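A minimal usage sketch for the evaluator above (assuming src/ is on the import path); the tags follow the IOB2 scheme described in the module docstring:

    from conlleval import evaluate

    true_seqs = ['B-PER', 'O', 'O', 'B-LOC', 'I-LOC', 'O']
    pred_seqs = ['B-PER', 'O', 'O', 'B-LOC', 'O', 'O']
    prec, rec, f1 = evaluate(true_seqs, pred_seqs, verbose=False)
    # one of the two gold chunks (PER) is predicted exactly, so all three scores come out 50.0
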
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import sys, os, time
2 | import argparse
3 | import re, csv
4 | import pickle
5 | import numpy as np
6 | from transformers import BertTokenizer
7 | import spacy
8 | from spacy.tokens import Token
9 | from spacy.tokenizer import Tokenizer
10 | import task_utils
11 | from tqdm import tqdm
12 | import copy
13 |
14 | global spacy_parser, tokenizer
15 | Token.set_extension('tid', default=0)
16 | spacy_parser = spacy.load("en_core_web_sm", disable=['ner'])
17 | spacy_parser.tokenizer = Tokenizer(spacy_parser.vocab)
18 | tokenizer = {
19 | 'en-cased': BertTokenizer.from_pretrained('bert-base-cased'),
20 | 'en-uncased': BertTokenizer.from_pretrained('bert-base-uncased'),
21 | }
22 |
23 | class InputRawExample(object):
24 | def __init__(self, text, label):
25 | self.text = text
26 | self.label = label
27 |
28 | class Graph(object):
29 | """docstring for Graph"""
30 | def __init__(self, n):
31 | super(Graph, self).__init__()
32 | self.n = n
33 | self.link_list = []
34 | self.vis = [0] * self.n
35 | for i in range(self.n):
36 | self.link_list.append([])
37 |
38 | def add_edge(self, u, v):
39 | if u == v:
40 | return
41 | self.link_list[u].append(v)
42 | self.link_list[v].append(u)
43 |
44 | def bfs(self, start, dist):
45 | que = [start]
46 | self.vis[start] = 1
47 | for _ in range(dist):
48 | que2 = []
49 | for u in que:
50 | #self.vis[u] = 1
51 | for v in self.link_list[u]:
52 | if self.vis[v]:
53 | continue
54 | que2.append(v)
55 | self.vis[v] = 1
56 | que = copy.deepcopy(que2)
57 |
58 |
59 | def solve(self, start, dist):
60 | self.vis = [0] * self.n
61 | self.bfs(start, dist)
62 | self.vis[0] = 1
63 | return copy.deepcopy(self.vis)
64 |
65 | def process(args, text, label):
66 | # text: str
67 | # label: list[str] or int
68 | global spacy_parser, tokenizer
69 | if args.lang == 'en':
70 | local_tokenizer = tokenizer['en-uncased'] if args.do_lower_case else tokenizer['en-cased']
71 |
72 | while '  ' in text:
73 |     text = text.replace('  ', ' ')
74 | doc = spacy_parser(text)
75 |
76 | tokens = ['[CLS]']
77 | for token in doc:
78 | token._.tid = len(tokens)
79 | tokens.append(token.text)
80 |
81 | G = Graph(len(tokens))
82 | for token in doc:
83 | if token.dep_ == 'ROOT':
84 | continue
85 | G.add_edge(token._.tid, token.head._.tid)
86 |
87 |
88 | ntokens = []
89 | ws = []
90 |
91 | for i, token in enumerate(tokens):
92 | if token == '[CLS]':
93 | ntokens.append(token)
94 | ws.append(1)
95 | else:
96 | sub_tokens = local_tokenizer.tokenize(token)
97 | ws.append(len(sub_tokens))
98 | for j, st in enumerate(sub_tokens):
99 | ntokens.append(st)
100 |
101 | dep_att_mask = []
102 | for i, token in enumerate(tokens):
103 | vis = G.solve(i, args.dist)
104 |
105 | if i-1>=0:
106 | vis_tmp = G.solve(i-1, args.dist)
107 | for j in range(len(vis_tmp)):
108 | vis[j] |= vis_tmp[j]
109 |
110 | if i+1 < len(tokens):
111 |     vis_tmp = G.solve(i+1, args.dist)
112 |     for j in range(len(vis_tmp)):
113 |         vis[j] |= vis_tmp[j]
...
154 |     if token in '!"#$%&\'()*+,-./:;<=>?@[\\]^_':
155 | quote_idx.append(i)
156 |
157 | for i in range(len(ntokens)):
158 | for j in quote_idx:
159 | dep_att_mask[i][j] = 1
160 |
161 | input_ids = local_tokenizer.convert_tokens_to_ids(ntokens)
162 | else:
163 | raise KeyError(args.lang)
164 |
165 | example = {
166 | 'input_ids': input_ids,
167 | 'dep_att_mask': dep_att_mask,
168 | 'labels': labels,
169 | }
170 | if isinstance(label, list):
171 | example['loss_mask'] = loss_mask
172 | return example
173 |
174 | def _read_tsv(input_file, quotechar=None):
175 | """Reads a tab separated value file."""
176 | with open(input_file, "r", encoding="utf-8-sig") as f:
177 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
178 | lines = []
179 | for line in reader:
180 | if sys.version_info[0] == 2:
181 | line = list(unicode(cell, 'utf-8') for cell in line)
182 | lines.append(line)
183 | return lines
184 |
185 | def _read_conll_format_file(input_file):
186 | lines = []
187 | with open(input_file, 'r', encoding='utf-8') as f:
188 | words = []
189 | labels = []
190 | for line in f:
191 | contends = line.strip()
192 | feature_vector = contends.split(' ')
193 | word = feature_vector[0].strip()
194 | label = feature_vector[-1].strip()
195 | if len(contends) == 0:
196 | w = ' '.join(words)
197 | l = ' '.join(labels)
198 | lines.append((w, l))
199 | words = []
200 | labels = []
201 | continue
202 | #word = re.sub(u"([^\u4e00-\u9fa5\u0030-\u0039\u0041-\u005a\u0061-\u007a'!'#$%&\'()*+,-./:;<=>?@,。?★、…【】《》?“”‘'![\\]^_`{|}~])","",word)
203 | if len(word) == 0 or len(label) == 0:
204 | continue
205 | words.append(word)
206 | labels.append(label)
207 | return lines
208 |
209 | def main(args):
210 | raw_examples = []
211 | if args.task == 'cola':
212 | lines = _read_tsv(args.data_file)
213 | for line in lines:
214 | text = line[3].lower() if args.do_lower_case else line[3]
215 | raw_examples.append(InputRawExample(text, int(line[1])))
216 | elif args.task == 'sst-2':
217 | lines = _read_tsv(args.data_file)
218 | for line in lines[1:]:
219 | text = line[0].lower() if args.do_lower_case else line[0]
220 | raw_examples.append(InputRawExample(text, int(line[1])))
221 | elif args.task == 'fce':
222 | lines = _read_conll_format_file(args.data_file)
223 | for w, l in lines:
224 | text = w.lower() if args.do_lower_case else w
225 | raw_examples.append(InputRawExample(text, l.split(' ')))
226 | else:
227 | raise KeyError(args.task)
228 | examples = []
229 | for item in tqdm(raw_examples, desc='Convert'):
230 | examples.append(process(args, item.text, item.label))
231 |
232 | filename = args.data_file.split('/')[-1]
233 | with open('./%s.%s.d%d.pkl' % (args.task, filename, args.dist), 'wb') as f:
234 |     pickle.dump(examples, f)
235 |
236 | if __name__=='__main__':
237 | p = argparse.ArgumentParser()
238 | p.add_argument("--task", default="cola", type=str)
239 | p.add_argument("--data_file", default="cola.tsv", type=str)
240 | p.add_argument("--lang", default="en", type=str)
241 | p.add_argument("--do_lower_case", action="store_true")
242 | p.add_argument("--dist", default=3, type=int)
243 | args = p.parse_args()
244 | main(args)
245 |
--------------------------------------------------------------------------------
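The Graph class above is the core of the syntax-aware local mask: solve(i, dist) runs a BFS of depth dist over the dependency edges and returns a 0/1 visibility vector, always re-enabling position 0 ([CLS]). A toy illustration (the module loads spaCy and BERT tokenizers at import time, so this sketch assumes those are installed and src/ is importable):

    from dataset import Graph

    g = Graph(5)        # positions 0..4; position 0 is reserved for [CLS]
    g.add_edge(1, 2)    # dependency arcs between word positions
    g.add_edge(2, 3)
    g.add_edge(3, 4)
    print(g.solve(2, 1))  # [1, 1, 1, 1, 0]: everything within one hop of 2, plus the forced [CLS] bit
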
/src/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 | import logging
3 | import math
4 | import os
5 | import sys
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import CrossEntropyLoss, MSELoss
10 |
11 | from transformers.modeling_bert import *
12 |
13 | class TokenClsModel(BertPreTrainedModel):
14 | def __init__(self, config):
15 | super(TokenClsModel, self).__init__(config)
16 | self.num_labels = config.num_labels
17 |
18 | self.bert = BertModel(config)
19 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
20 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
21 |
22 | self.init_weights()
23 |
24 | @staticmethod
25 | def build_batch(batch, tokenizer):
26 | return batch
27 |
28 | def forward(self, batch):
29 | input_ids = batch['input_ids']
30 | attention_mask = batch['attention_mask']
31 | dep_att_mask = batch['dep_att_mask']
32 | labels = batch['labels']
33 | loss_mask = batch['loss_mask'] if 'loss_mask' in batch else None
34 |
35 | outputs = self.bert(input_ids,
36 | attention_mask=attention_mask,
37 | dep_att_mask=dep_att_mask)
38 |
39 | sequence_output = outputs[0]
40 |
41 | sequence_output = self.dropout(sequence_output)
42 | logits = self.classifier(sequence_output)
43 |
44 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
45 | if labels is not None:
46 | loss_fct = CrossEntropyLoss()
47 | # Only keep active parts of the loss
48 | if loss_mask is not None:
49 | active_loss = loss_mask.view(-1) == 1
50 | active_logits = logits.view(-1, self.num_labels)[active_loss]
51 | active_labels = labels.view(-1)[active_loss]
52 | loss = loss_fct(active_logits, active_labels)
53 | else:
54 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
55 | outputs = (loss,) + outputs
56 | return outputs # (loss), scores, (hidden_states), (attentions)
57 |
58 | class SentClsModel(BertPreTrainedModel):
59 | def __init__(self, config):
60 | super(SentClsModel, self).__init__(config)
61 | self.num_labels = config.num_labels
62 |
63 | self.bert = BertModel(config)
64 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
65 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
66 |
67 | self.init_weights()
68 |
69 | @staticmethod
70 | def build_batch(batch, tokenizer):
71 | return batch
72 |
73 | def forward(self, batch):
74 | input_ids = batch['input_ids']
75 | attention_mask = batch['attention_mask']
76 | dep_att_mask = batch['dep_att_mask']
77 | labels = batch['labels']
78 |
79 | outputs = self.bert(input_ids,
80 | attention_mask=attention_mask,
81 | dep_att_mask=dep_att_mask)
82 | pooled_output = outputs[1]
83 | pooled_output = self.dropout(pooled_output)
84 | logits = self.classifier(pooled_output)
85 |
86 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
87 |
88 | if labels is not None:
89 | loss_fct = CrossEntropyLoss()
90 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
91 | outputs = (loss,) + outputs
92 | return outputs # (loss), logits, (hidden_states), (attentions)
93 |
94 |
--------------------------------------------------------------------------------
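A hedged sketch of driving SentClsModel with a dummy batch; it assumes this repo's modified modeling_bert, whose BertModel.forward accepts the extra dep_att_mask (the (batch, seq, seq) mask shape below is an assumption for illustration, and src/ is assumed importable):

    import torch
    from transformers import BertConfig
    from models import SentClsModel

    config = BertConfig(num_hidden_layers=2, num_labels=2)  # tiny, randomly initialized
    model = SentClsModel(config)
    batch = {
        'input_ids': torch.randint(0, config.vocab_size, (2, 8)),
        'attention_mask': torch.ones(2, 8, dtype=torch.long),
        'dep_att_mask': torch.ones(2, 8, 8, dtype=torch.long),  # all-visible placeholder mask
        'labels': torch.tensor([0, 1]),
    }
    loss, logits = model(batch)[:2]
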
/src/task_utils.py:
--------------------------------------------------------------------------------
1 | import sys, os, time
2 | import logging
3 | import numpy as np
4 | logger = logging.getLogger(__name__)
5 |
6 | try:
7 | from scipy.stats import pearsonr, spearmanr
8 | from sklearn.metrics import matthews_corrcoef, f1_score
9 | _has_sklearn = True
10 | except (AttributeError, ImportError) as e:
11 | logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
12 | _has_sklearn = False
13 |
14 | from conlleval import evaluate as con_eval
15 |
16 | num_labels = {
17 | 'cola': 2,
18 | 'sst-2': 2,
19 | 'chns': 2,
20 | 'cged': 9,
21 | 'fce': 6,
22 | }
23 |
24 | class CgedProcessor(object):
25 | def get_labels(self):
26 | return ['O', 'B-R', 'I-R', 'B-M', 'I-M', 'B-S', 'I-S', 'B-W', 'I-W']
27 |
28 | class FceProcessor(object):
29 | def get_labels(self):
30 | return ['O', 'B-R', 'I-R', 'B-U', 'I-U', 'B-M']
31 |
32 | task_processors = {
33 | 'cged': CgedProcessor,
34 | 'fce' : FceProcessor,
35 | }
36 |
37 | def is_sklearn_available():
38 | return _has_sklearn
39 |
40 | if _has_sklearn:
41 |
42 | def simple_accuracy(preds, labels):
43 | return {'acc': (preds == labels).mean()}
44 |
45 |
46 | def acc_and_f1(preds, labels):
47 | acc = simple_accuracy(preds, labels)['acc']  # simple_accuracy returns a dict here
48 | f1 = f1_score(y_true=labels, y_pred=preds)
49 | return {
50 | "acc": acc,
51 | "f1": f1,
52 | "acc_and_f1": (acc + f1) / 2,
53 | }
54 |
55 |
56 | def pearson_and_spearman(preds, labels):
57 | pearson_corr = pearsonr(preds, labels)[0]
58 | spearman_corr = spearmanr(preds, labels)[0]
59 | return {
60 | "pearson": pearson_corr,
61 | "spearmanr": spearman_corr,
62 | "corr": (pearson_corr + spearman_corr) / 2,
63 | }
64 |
65 | def evaluate(task, eval_data):
66 | if task == 'cola':
67 | preds = np.array(eval_data['preds'])
68 | labels = np.array(eval_data['labels'])
69 | return {"mcc": matthews_corrcoef(labels, preds)}
70 | elif task == 'chns' or task == 'sst-2':
71 | preds = np.array(eval_data['preds'])
72 | labels = np.array(eval_data['labels'])
73 | return simple_accuracy(preds, labels)
74 | elif task == 'fce':
75 | preds = eval_data['preds']
76 | labels = eval_data['labels']
77 | loss_mask = eval_data['loss_mask']
78 | true_tags, pred_tags = [], []
79 | label_list = FceProcessor().get_labels()
80 | for i in range(len(preds)):
81 | for pid, tid, mask in zip(preds[i], labels[i], loss_mask[i]):
82 | if mask == 1:
83 | pred_tags.append(label_list[pid])
84 | true_tags.append(label_list[tid])
85 | pred_tags.append('O')
86 | true_tags.append('O')
87 | prec, rec, f1 = con_eval(true_tags, pred_tags, verbose=False)
88 | return {
89 | 'precision': prec,
90 | 'recall': rec,
91 | 'f1': f1,
92 | 'f0.5': 1.25 * prec * rec / (rec + 0.25 * prec) if (rec + 0.25 * prec) else 0
93 | }
94 | else:
95 | raise KeyError(task)
--------------------------------------------------------------------------------
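The 'f0.5' entry above is the standard F-beta score with beta = 0.5, which weights precision twice as heavily as recall: F_beta = (1 + beta^2) * P * R / (beta^2 * P + R), so F0.5 = 1.25 * P * R / (0.25 * P + R). For example, P = 60 and R = 30 give F0.5 = 1.25 * 60 * 30 / (0.25 * 60 + 30) = 2250 / 45 = 50.
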
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export INPUT_DIR=/path_to_load/
3 | export OUTPUT_DIR=/path_to_store/
4 |
5 | pip install --user --editable .
6 | python src/run.py \
7 | --task cola \
8 | --model_type sent-cls \
9 | --model_name_or_path bert-base-uncased \
10 | --data_dir $INPUT_DIR \
11 | --output_dir $OUTPUT_DIR \
12 | --train_file cola.train.tsv.d3.pkl \
13 | --dev_file cola.dev.tsv.d3.pkl \
14 | --do_train --do_eval --do_lower_case \
15 | --learning_rate 3e-5 \
16 | --num_train_epochs 2 \
17 | --order_metric mcc \
18 | --metric_reverse \
19 | --remove_unused_ckpts \
20 | --per_gpu_train_batch_size 16 \
21 | --eval_all_checkpoints \
22 | --overwrite_output_dir \
23 | --save_steps 100 \
24 | --seed 17
25 |
26 |
--------------------------------------------------------------------------------
/transformers/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "2.2.2"
2 |
3 | # Work around to update TensorFlow's absl.logging threshold which alters the
4 | # default Python logging output behavior when present.
5 | # see: https://github.com/abseil/abseil-py/issues/99
6 | # and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493
7 | try:
8 | import absl.logging
9 | absl.logging.set_verbosity('info')
10 | absl.logging.set_stderrthreshold('info')
11 | absl.logging._warn_preinit_stderr = False
12 | except:
13 | pass
14 |
15 | import logging
16 |
17 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
18 |
19 | # Files and general utilities
20 | from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
21 | cached_path, add_start_docstrings, add_end_docstrings,
22 | WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
23 | is_tf_available, is_torch_available)
24 |
25 | from .data import (is_sklearn_available,
26 | InputExample, InputFeatures, DataProcessor,
27 | glue_output_modes, glue_convert_examples_to_features,
28 | glue_processors, glue_tasks_num_labels,
29 | xnli_output_modes, xnli_processors, xnli_tasks_num_labels,
30 | squad_convert_examples_to_features, SquadFeatures,
31 | SquadExample, SquadV1Processor, SquadV2Processor)
32 |
33 | if is_sklearn_available():
34 | from .data import glue_compute_metrics, xnli_compute_metrics
35 |
36 | # Tokenizers
37 | from .tokenization_utils import (PreTrainedTokenizer)
38 | from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
39 |
40 | # Configurations
41 | from .configuration_utils import PretrainedConfig
42 | from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
43 |
44 | # Modeling
45 | if is_torch_available():
46 | from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
47 |
48 | from .modeling_bert import (BertModel,
49 | load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
50 |
51 | # Optimization
52 | from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup,
53 | get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup)
54 |
55 | if not is_tf_available() and not is_torch_available():
56 |     logger.warning("Neither PyTorch nor TensorFlow >= 2.0 have been found. "
57 |                    "Models won't be available and only tokenizers, configuration "
58 |                    "and file/data utilities can be used.")
59 |
--------------------------------------------------------------------------------
/transformers/commands/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from argparse import ArgumentParser
3 |
4 | class BaseTransformersCLICommand(ABC):
5 | @staticmethod
6 | @abstractmethod
7 | def register_subcommand(parser: ArgumentParser):
8 | raise NotImplementedError()
9 |
10 | @abstractmethod
11 | def run(self):
12 | raise NotImplementedError()
13 |
--------------------------------------------------------------------------------
/transformers/commands/user.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from getpass import getpass
3 | import os
4 | from typing import List, Union  # used by the type comments/annotations below
5 | from transformers.commands import BaseTransformersCLICommand
6 | from transformers.hf_api import HfApi, HfFolder, HTTPError
7 |
8 |
9 | class UserCommands(BaseTransformersCLICommand):
10 | @staticmethod
11 | def register_subcommand(parser: ArgumentParser):
12 | login_parser = parser.add_parser('login')
13 | login_parser.set_defaults(func=lambda args: LoginCommand(args))
14 | whoami_parser = parser.add_parser('whoami')
15 | whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
16 | logout_parser = parser.add_parser('logout')
17 | logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
18 | list_parser = parser.add_parser('ls')
19 | list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
20 | # upload
21 | upload_parser = parser.add_parser('upload')
22 | upload_parser.add_argument('path', type=str, help='Local path of the folder or individual file to upload.')
23 | upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override individual object filename on S3.')
24 | upload_parser.set_defaults(func=lambda args: UploadCommand(args))
25 |
26 |
27 |
28 | class ANSI:
29 | """
30 | Helper for en.wikipedia.org/wiki/ANSI_escape_code
31 | """
32 | _bold = u"\u001b[1m"
33 | _reset = u"\u001b[0m"
34 | @classmethod
35 | def bold(cls, s):
36 | return "{}{}{}".format(cls._bold, s, cls._reset)
37 |
38 |
39 | class BaseUserCommand:
40 | def __init__(self, args):
41 | self.args = args
42 | self._api = HfApi()
43 |
44 |
45 | class LoginCommand(BaseUserCommand):
46 | def run(self):
47 | print("""
48 | _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
49 | _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
50 | _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
51 | _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
52 | _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
53 |
54 | """)
55 | username = input("Username: ")
56 | password = getpass()
57 | try:
58 | token = self._api.login(username, password)
59 | except HTTPError as e:
60 | # probably invalid credentials, display error message.
61 | print(e)
62 | exit(1)
63 | HfFolder.save_token(token)
64 | print("Login successful")
65 | print("Your token:", token, "\n")
66 | print("Your token has been saved to", HfFolder.path_token)
67 |
68 |
69 | class WhoamiCommand(BaseUserCommand):
70 | def run(self):
71 | token = HfFolder.get_token()
72 | if token is None:
73 | print("Not logged in")
74 | exit()
75 | try:
76 | user = self._api.whoami(token)
77 | print(user)
78 | except HTTPError as e:
79 | print(e)
80 |
81 |
82 | class LogoutCommand(BaseUserCommand):
83 | def run(self):
84 | token = HfFolder.get_token()
85 | if token is None:
86 | print("Not logged in")
87 | exit()
88 | HfFolder.delete_token()
89 | self._api.logout(token)
90 | print("Successfully logged out.")
91 |
92 |
93 | class ListObjsCommand(BaseUserCommand):
94 | def tabulate(self, rows, headers):
95 | # type: (List[List[Union[str, int]]], List[str]) -> str
96 | """
97 | Inspired by:
98 | stackoverflow.com/a/8356620/593036
99 | stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
100 | """
101 | col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
102 | row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
103 | lines = []
104 | lines.append(
105 | row_format.format(*headers)
106 | )
107 | lines.append(
108 | row_format.format(*["-" * w for w in col_widths])
109 | )
110 | for row in rows:
111 | lines.append(
112 | row_format.format(*row)
113 | )
114 | return "\n".join(lines)
115 |
116 | def run(self):
117 | token = HfFolder.get_token()
118 | if token is None:
119 | print("Not logged in")
120 | exit(1)
121 | try:
122 | objs = self._api.list_objs(token)
123 | except HTTPError as e:
124 | print(e)
125 | exit(1)
126 | if len(objs) == 0:
127 | print("No shared file yet")
128 | exit()
129 | rows = [ [
130 | obj.filename,
131 | obj.LastModified,
132 | obj.ETag,
133 | obj.Size
134 | ] for obj in objs ]
135 | print(
136 | self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])
137 | )
138 |
139 |
140 | class UploadCommand(BaseUserCommand):
141 | def walk_dir(self, rel_path):
142 | """
143 | Recursively list all files in a folder.
144 | """
145 | entries: List[os.DirEntry] = list(os.scandir(rel_path))
146 | files = [
147 | (
148 | os.path.join(os.getcwd(), f.path), # filepath
149 | f.path # filename
150 | )
151 | for f in entries if f.is_file()
152 | ]
153 | for f in entries:
154 | if f.is_dir():
155 | files += self.walk_dir(f.path)
156 | return files
157 |
158 | def run(self):
159 | token = HfFolder.get_token()
160 | if token is None:
161 | print("Not logged in")
162 | exit(1)
163 | local_path = os.path.abspath(self.args.path)
164 | if os.path.isdir(local_path):
165 | if self.args.filename is not None:
166 | raise ValueError("Cannot specify a filename override when uploading a folder.")
167 | rel_path = os.path.basename(local_path)
168 | files = self.walk_dir(rel_path)
169 | elif os.path.isfile(local_path):
170 | filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path)
171 | files = [(local_path, filename)]
172 | else:
173 | raise ValueError("Not a valid file or directory: {}".format(local_path))
174 |
175 | for filepath, filename in files:
176 | print(
177 | "About to upload file {} to S3 under filename {}".format(
178 | ANSI.bold(filepath), ANSI.bold(filename)
179 | )
180 | )
181 |
182 | choice = input("Proceed? [Y/n] ").lower()
183 | if choice not in ("", "y", "yes"):
184 | print("Abort")
185 | exit()
186 | print(
187 | ANSI.bold("Uploading... This might take a while if files are large")
188 | )
189 | for filepath, filename in files:
190 | access_url = self._api.presign_and_upload(
191 | token=token, filename=filename, filepath=filepath
192 | )
193 | print("Your file now lives at:")
194 | print(access_url)
195 |
--------------------------------------------------------------------------------
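The tabulate helper above is self-contained, so its formatting is easy to check in isolation (hypothetical rows; the command object only needs a token for network calls, not for formatting):

    rows = [["config.json", "2019-12-01", "abc123", 512],
            ["pytorch_model.bin", "2019-12-02", "def456", 437983000]]
    print(ListObjsCommand(args=None).tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
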
/transformers/configuration_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT model configuration """
17 |
18 | from __future__ import absolute_import, division, print_function, unicode_literals
19 |
20 | import json
21 | import logging
22 | import sys
23 | from io import open
24 |
25 | from .configuration_utils import PretrainedConfig
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
37 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
38 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
39 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
40 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
41 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
42 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
43 | 'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
44 | 'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
45 | 'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
46 | 'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
47 | 'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
48 | 'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json"
49 | }
50 |
51 |
52 | class BertConfig(PretrainedConfig):
53 | r"""
54 | :class:`~transformers.BertConfig` is the configuration class to store the configuration of a
55 | `BertModel`.
56 |
57 |
58 | Arguments:
59 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
60 | hidden_size: Size of the encoder layers and the pooler layer.
61 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
62 | num_attention_heads: Number of attention heads for each attention layer in
63 | the Transformer encoder.
64 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
65 | layer in the Transformer encoder.
66 | hidden_act: The non-linear activation function (function or string) in the
67 | encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.
68 | hidden_dropout_prob: The dropout probability for all fully connected
69 | layers in the embeddings, encoder, and pooler.
70 | attention_probs_dropout_prob: The dropout ratio for the attention
71 | probabilities.
72 | max_position_embeddings: The maximum sequence length that this model might
73 | ever be used with. Typically set this to something large just in case
74 | (e.g., 512 or 1024 or 2048).
75 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
76 | `BertModel`.
77 | initializer_range: The stddev of the truncated_normal_initializer for
78 | initializing all weight matrices.
79 | layer_norm_eps: The epsilon used by LayerNorm.
80 | """
81 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
82 |
83 | def __init__(self,
84 | vocab_size_or_config_json_file=30522,
85 | hidden_size=768,
86 | num_hidden_layers=12,
87 | num_attention_heads=12,
88 | intermediate_size=3072,
89 | hidden_act="gelu",
90 | hidden_dropout_prob=0.1,
91 | attention_probs_dropout_prob=0.1,
92 | max_position_embeddings=512,
93 | type_vocab_size=2,
94 | initializer_range=0.02,
95 | layer_norm_eps=1e-12,
96 | **kwargs):
97 | super(BertConfig, self).__init__(**kwargs)
98 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
99 | and isinstance(vocab_size_or_config_json_file, unicode)):
100 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
101 | json_config = json.loads(reader.read())
102 | for key, value in json_config.items():
103 | self.__dict__[key] = value
104 | elif isinstance(vocab_size_or_config_json_file, int):
105 | self.vocab_size = vocab_size_or_config_json_file
106 | self.hidden_size = hidden_size
107 | self.num_hidden_layers = num_hidden_layers
108 | self.num_attention_heads = num_attention_heads
109 | self.hidden_act = hidden_act
110 | self.intermediate_size = intermediate_size
111 | self.hidden_dropout_prob = hidden_dropout_prob
112 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
113 | self.max_position_embeddings = max_position_embeddings
114 | self.type_vocab_size = type_vocab_size
115 | self.initializer_range = initializer_range
116 | self.layer_norm_eps = layer_norm_eps
117 | else:
118 | raise ValueError("First argument must be either a vocabulary size (int)"
119 | " or the path to a pretrained model config file (str)")
120 |
--------------------------------------------------------------------------------
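Constructing a config directly, per the docstring's first-argument convention (an int is taken as the vocabulary size, a str as a path to a JSON config file):

    from transformers import BertConfig

    config = BertConfig(vocab_size_or_config_json_file=30522, num_hidden_layers=6)
    print(config.hidden_size, config.num_hidden_layers)  # 768 6
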
/transformers/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures
2 | from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
3 | from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor
4 | from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
5 |
6 | from .metrics import is_sklearn_available
7 | if is_sklearn_available():
8 | from .metrics import glue_compute_metrics, xnli_compute_metrics
9 |
--------------------------------------------------------------------------------
/transformers/data/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import csv
18 | import sys
19 | import logging
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 | try:
24 | from scipy.stats import pearsonr, spearmanr
25 | from sklearn.metrics import matthews_corrcoef, f1_score
26 | _has_sklearn = True
27 | except (AttributeError, ImportError) as e:
28 | logger.warning("To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html")
29 | _has_sklearn = False
30 |
31 | def is_sklearn_available():
32 | return _has_sklearn
33 |
34 | if _has_sklearn:
35 |
36 | def simple_accuracy(preds, labels):
37 | return (preds == labels).mean()
38 |
39 |
40 | def acc_and_f1(preds, labels):
41 | acc = simple_accuracy(preds, labels)
42 | f1 = f1_score(y_true=labels, y_pred=preds)
43 | return {
44 | "acc": acc,
45 | "f1": f1,
46 | "acc_and_f1": (acc + f1) / 2,
47 | }
48 |
49 |
50 | def pearson_and_spearman(preds, labels):
51 | pearson_corr = pearsonr(preds, labels)[0]
52 | spearman_corr = spearmanr(preds, labels)[0]
53 | return {
54 | "pearson": pearson_corr,
55 | "spearmanr": spearman_corr,
56 | "corr": (pearson_corr + spearman_corr) / 2,
57 | }
58 |
59 |
60 | def glue_compute_metrics(task_name, preds, labels):
61 | assert len(preds) == len(labels)
62 | if task_name == "cola":
63 | return {"mcc": matthews_corrcoef(labels, preds)}
64 | elif task_name == "sst-2":
65 | return {"acc": simple_accuracy(preds, labels)}
66 | elif task_name == "mrpc":
67 | return acc_and_f1(preds, labels)
68 | elif task_name == "sts-b":
69 | return pearson_and_spearman(preds, labels)
70 | elif task_name == "qqp":
71 | return acc_and_f1(preds, labels)
72 | elif task_name == "mnli":
73 | return {"acc": simple_accuracy(preds, labels)}
74 | elif task_name == "mnli-mm":
75 | return {"acc": simple_accuracy(preds, labels)}
76 | elif task_name == "qnli":
77 | return {"acc": simple_accuracy(preds, labels)}
78 | elif task_name == "rte":
79 | return {"acc": simple_accuracy(preds, labels)}
80 | elif task_name == "wnli":
81 | return {"acc": simple_accuracy(preds, labels)}
82 | else:
83 | raise KeyError(task_name)
84 |
85 |
86 | def xnli_compute_metrics(task_name, preds, labels):
87 | assert len(preds) == len(labels)
88 | if task_name == "xnli":
89 | return {"acc": simple_accuracy(preds, labels)}
90 | else:
91 | raise KeyError(task_name)
92 |
--------------------------------------------------------------------------------
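Since simple_accuracy relies on (preds == labels).mean(), numpy arrays are the expected inputs; a small check (requires scikit-learn, per the import guard above):

    import numpy as np
    from transformers.data.metrics import glue_compute_metrics

    preds = np.array([1, 0, 1, 1])
    labels = np.array([1, 0, 0, 1])
    print(glue_compute_metrics("sst-2", preds, labels))  # {'acc': 0.75}
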
/transformers/data/processors/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import InputExample, InputFeatures, DataProcessor
2 | from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
3 | from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor
4 | from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
--------------------------------------------------------------------------------
/transformers/data/processors/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import csv
18 | import sys
19 | import copy
20 | import json
21 |
22 | class InputExample(object):
23 | """
24 | A single training/test example for simple sequence classification.
25 |
26 | Args:
27 | guid: Unique id for the example.
28 | text_a: string. The untokenized text of the first sequence. For single
29 | sequence tasks, only this sequence must be specified.
30 | text_b: (Optional) string. The untokenized text of the second sequence.
31 | Only must be specified for sequence pair tasks.
32 | label: (Optional) string. The label of the example. This should be
33 | specified for train and dev examples, but not for test examples.
34 | """
35 | def __init__(self, guid, text_a, text_b=None, label=None):
36 | self.guid = guid
37 | self.text_a = text_a
38 | self.text_b = text_b
39 | self.label = label
40 |
41 | def __repr__(self):
42 | return str(self.to_json_string())
43 |
44 | def to_dict(self):
45 | """Serializes this instance to a Python dictionary."""
46 | output = copy.deepcopy(self.__dict__)
47 | return output
48 |
49 | def to_json_string(self):
50 | """Serializes this instance to a JSON string."""
51 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
52 |
53 |
54 | class InputFeatures(object):
55 | """
56 | A single set of features of data.
57 |
58 | Args:
59 | input_ids: Indices of input sequence tokens in the vocabulary.
60 | attention_mask: Mask to avoid performing attention on padding token indices.
61 | Mask values selected in ``[0, 1]``:
62 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens.
63 | token_type_ids: Segment token indices to indicate first and second portions of the inputs.
64 | label: Label corresponding to the input
65 | """
66 |
67 | def __init__(self, input_ids, attention_mask, token_type_ids, label):
68 | self.input_ids = input_ids
69 | self.attention_mask = attention_mask
70 | self.token_type_ids = token_type_ids
71 | self.label = label
72 |
73 | def __repr__(self):
74 | return str(self.to_json_string())
75 |
76 | def to_dict(self):
77 | """Serializes this instance to a Python dictionary."""
78 | output = copy.deepcopy(self.__dict__)
79 | return output
80 |
81 | def to_json_string(self):
82 | """Serializes this instance to a JSON string."""
83 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
84 |
85 |
86 | class DataProcessor(object):
87 | """Base class for data converters for sequence classification data sets."""
88 |
89 | def get_example_from_tensor_dict(self, tensor_dict):
90 | """Gets an example from a dict with tensorflow tensors
91 |
92 | Args:
93 | tensor_dict: Keys and values should match the corresponding Glue
94 | tensorflow_dataset examples.
95 | """
96 | raise NotImplementedError()
97 |
98 | def get_train_examples(self, data_dir):
99 | """Gets a collection of `InputExample`s for the train set."""
100 | raise NotImplementedError()
101 |
102 | def get_dev_examples(self, data_dir):
103 | """Gets a collection of `InputExample`s for the dev set."""
104 | raise NotImplementedError()
105 |
106 | def get_labels(self):
107 | """Gets the list of labels for this data set."""
108 | raise NotImplementedError()
109 |
110 | def tfds_map(self, example):
111 | """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are.
112 | This method converts examples to the correct format."""
113 | if len(self.get_labels()) > 1:
114 | example.label = self.get_labels()[int(example.label)]
115 | return example
116 |
117 | @classmethod
118 | def _read_tsv(cls, input_file, quotechar=None):
119 | """Reads a tab separated value file."""
120 | with open(input_file, "r", encoding="utf-8-sig") as f:
121 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
122 | lines = []
123 | for line in reader:
124 | if sys.version_info[0] == 2:
125 | line = list(unicode(cell, 'utf-8') for cell in line)
126 | lines.append(line)
127 | return lines
128 |
--------------------------------------------------------------------------------
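InputExample and InputFeatures both serialize through to_json_string, which is handy for logging and cache files; for example:

    from transformers import InputExample

    ex = InputExample(guid="train-1", text_a="The cat sat.", label="1")
    print(ex.to_json_string())
    # keys come out sorted: guid, label, text_a, text_b (text_b prints as null)
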
/transformers/data/processors/xnli.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ XNLI utils (dataset loading and evaluation) """
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | import logging
21 | import os
22 |
23 | from .utils import DataProcessor, InputExample
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 | class XnliProcessor(DataProcessor):
28 | """Processor for the XNLI dataset.
29 | Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""
30 |
31 | def __init__(self, language, train_language=None):
32 | self.language = language
33 | self.train_language = train_language
34 |
35 | def get_train_examples(self, data_dir):
36 | """See base class."""
37 | lg = self.language if self.train_language is None else self.train_language
38 | lines = self._read_tsv(os.path.join(data_dir, "XNLI-MT-1.0/multinli/multinli.train.{}.tsv".format(lg)))
39 | examples = []
40 | for (i, line) in enumerate(lines):
41 | if i == 0:
42 | continue
43 | guid = "%s-%s" % ('train', i)
44 | text_a = line[0]
45 | text_b = line[1]
46 | label = "contradiction" if line[2] == "contradictory" else line[2]
47 | assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
48 | examples.append(
49 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
50 | return examples
51 |
52 | def get_test_examples(self, data_dir):
53 | """See base class."""
54 | lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
55 | examples = []
56 | for (i, line) in enumerate(lines):
57 | if i == 0:
58 | continue
59 | language = line[0]
60 | if language != self.language:
61 | continue
62 | guid = "%s-%s" % ('test', i)
63 | text_a = line[6]
64 | text_b = line[7]
65 | label = line[1]
66 | assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
67 | examples.append(
68 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
69 | return examples
70 |
71 | def get_labels(self):
72 | """See base class."""
73 | return ["contradiction", "entailment", "neutral"]
74 |
75 | xnli_processors = {
76 | "xnli": XnliProcessor,
77 | }
78 |
79 | xnli_output_modes = {
80 | "xnli": "classification",
81 | }
82 |
83 | xnli_tasks_num_labels = {
84 | "xnli": 3,
85 | }
86 |
--------------------------------------------------------------------------------
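
A short usage sketch of XnliProcessor as defined above, assuming data_dir contains the XNLI-MT-1.0 and XNLI-1.0 folders the processor expects; the path itself is a placeholder:

from transformers.data.processors.xnli import XnliProcessor

# Train on English MultiNLI data, evaluate on the German XNLI test split.
processor = XnliProcessor(language="de", train_language="en")
train_examples = processor.get_train_examples("/path/to/xnli_data")  # placeholder path
test_examples = processor.get_test_examples("/path/to/xnli_data")
print(processor.get_labels())  # ['contradiction', 'entailment', 'neutral']
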
/transformers/hf_api.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019-present, the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function
16 |
17 | import os
18 | from os.path import expanduser
19 |
20 | import requests
21 | import six
22 | from requests.exceptions import HTTPError
23 | from tqdm import tqdm
24 |
25 | ENDPOINT = "https://huggingface.co"
26 |
27 | class S3Obj:
28 | def __init__(
29 | self,
30 | filename, # type: str
31 | LastModified, # type: str
32 | ETag, # type: str
33 | Size, # type: int
34 | **kwargs
35 | ):
36 | self.filename = filename
37 | self.LastModified = LastModified
38 | self.ETag = ETag
39 | self.Size = Size
40 |
41 |
42 | class PresignedUrl:
43 | def __init__(
44 | self,
45 | write, # type: str
46 | access, # type: str
47 | type, # type: str
48 | **kwargs
49 | ):
50 | self.write = write
51 | self.access = access
52 | self.type = type # mime-type to send to S3.
53 |
54 |
55 | class HfApi:
56 | def __init__(self, endpoint=None):
57 | self.endpoint = endpoint if endpoint is not None else ENDPOINT
58 |
59 | def login(
60 | self,
61 | username, # type: str
62 | password, # type: str
63 | ):
64 | # type: (...) -> str
65 | """
66 | Call HF API to sign in a user and get a token if credentials are valid.
67 |
68 | Outputs:
69 | token if credentials are valid
70 |
71 | Throws:
72 | requests.exceptions.HTTPError if credentials are invalid
73 | """
74 | path = "{}/api/login".format(self.endpoint)
75 | r = requests.post(path, json={"username": username, "password": password})
76 | r.raise_for_status()
77 | d = r.json()
78 | return d["token"]
79 |
80 | def whoami(
81 | self,
82 | token, # type: str
83 | ):
84 | # type: (...) -> str
85 | """
86 | Call HF API to find out which user is logged in ("whoami").
87 | """
88 | path = "{}/api/whoami".format(self.endpoint)
89 | r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
90 | r.raise_for_status()
91 | d = r.json()
92 | return d["user"]
93 |
94 | def logout(self, token):
95 | # type: (...) -> None
96 | """
97 | Call HF API to log out.
98 | """
99 | path = "{}/api/logout".format(self.endpoint)
100 | r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
101 | r.raise_for_status()
102 |
103 | def presign(self, token, filename):
104 | # type: (...) -> PresignedUrl
105 | """
106 | Call HF API to get a presigned url to upload `filename` to S3.
107 | """
108 | path = "{}/api/presign".format(self.endpoint)
109 | r = requests.post(
110 | path,
111 | headers={"authorization": "Bearer {}".format(token)},
112 | json={"filename": filename},
113 | )
114 | r.raise_for_status()
115 | d = r.json()
116 | return PresignedUrl(**d)
117 |
118 | def presign_and_upload(self, token, filename, filepath):
119 | # type: (...) -> str
120 | """
121 | Get a presigned url, then upload file to S3.
122 |
123 | Outputs:
124 | url: Read-only url for the stored file on S3.
125 | """
126 | urls = self.presign(token, filename=filename)
127 | # streaming upload:
128 | # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
129 | #
130 | # Even though we presign with the correct content-type,
131 | # the client still has to specify it when uploading the file.
132 | with open(filepath, "rb") as f:
133 | pf = TqdmProgressFileReader(f)
134 |
135 | r = requests.put(urls.write, data=f, headers={
136 | "content-type": urls.type,
137 | })
138 | r.raise_for_status()
139 | pf.close()
140 | return urls.access
141 |
142 | def list_objs(self, token):
143 | # type: (...) -> List[S3Obj]
144 | """
145 | Call HF API to list all stored files for user.
146 | """
147 | path = "{}/api/listObjs".format(self.endpoint)
148 | r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
149 | r.raise_for_status()
150 | d = r.json()
151 | return [S3Obj(**x) for x in d]
152 |
153 |
154 |
155 | class TqdmProgressFileReader:
156 | """
157 | Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`)
158 | and override `f.read()` so as to display a tqdm progress bar.
159 |
160 | see github.com/huggingface/transformers/pull/2078#discussion_r354739608
161 | for implementation details.
162 | """
163 | def __init__(
164 | self,
165 | f # type: io.BufferedReader
166 | ):
167 | self.f = f
168 | self.total_size = os.fstat(f.fileno()).st_size # type: int
169 | self.pbar = tqdm(total=self.total_size, leave=False)
170 | if six.PY3:
171 | # does not work unless PY3
172 | # no big deal as the CLI does not currently support PY2 anyways.
173 | self.read = f.read
174 | f.read = self._read
175 |
176 | def _read(self, n=-1):
177 | self.pbar.update(n)
178 | return self.read(n)
179 |
180 | def close(self):
181 | self.pbar.close()
182 |
183 |
184 |
185 | class HfFolder:
186 | path_token = expanduser("~/.huggingface/token")
187 |
188 | @classmethod
189 | def save_token(cls, token):
190 | """
191 | Save token, creating folder as needed.
192 | """
193 | if six.PY3:
194 | os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
195 | else:
196 | # Python 2
197 | try:
198 | os.makedirs(os.path.dirname(cls.path_token))
199 | except OSError as e:
200 | if e.errno != os.errno.EEXIST:
201 | raise e
202 | pass
203 | with open(cls.path_token, 'w+') as f:
204 | f.write(token)
205 |
206 | @classmethod
207 | def get_token(cls):
208 | """
209 | Get token, or None if it does not exist.
210 | """
211 | try:
212 | with open(cls.path_token, 'r') as f:
213 | return f.read()
214 | except:
215 | # this is too wide. When Py2 is dead use:
216 | # `except FileNotFoundError:` instead
217 | return None
218 |
219 | @classmethod
220 | def delete_token(cls):
221 | """
222 | Delete token.
223 | Do not fail if token does not exist.
224 | """
225 | try:
226 | os.remove(cls.path_token)
227 | except:
228 | return
229 |
--------------------------------------------------------------------------------
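
A minimal end-to-end sketch of the API above: log in, cache the token, upload a file, list stored objects, log out. The credentials and file path are placeholders:

from transformers.hf_api import HfApi, HfFolder

api = HfApi()
token = api.login(username="my-username", password="my-password")  # placeholder credentials
HfFolder.save_token(token)  # cache the token under ~/.huggingface/token

# Upload a local file; presign_and_upload returns the read-only S3 url.
url = api.presign_and_upload(token, filename="weights.bin", filepath="/path/to/weights.bin")
print("stored at", url)

for obj in api.list_objs(token):
    print(obj.filename, obj.Size)

api.logout(token)
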
/transformers/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import logging
18 | import math
19 |
20 | import torch
21 | from torch.optim import Optimizer
22 | from torch.optim.lr_scheduler import LambdaLR
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def get_constant_schedule(optimizer, last_epoch=-1):
28 | """ Create a schedule with a constant learning rate.
29 | """
30 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
31 |
32 |
33 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
34 | """ Create a schedule with a constant learning rate preceded by a warmup
35 | period during which the learning rate increases linearly from 0 to the initial lr.
36 | """
37 | def lr_lambda(current_step):
38 | if current_step < num_warmup_steps:
39 | return float(current_step) / float(max(1.0, num_warmup_steps))
40 | return 1.
41 |
42 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
43 |
44 |
45 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
46 | """ Create a schedule with a learning rate that decreases linearly after
47 | linearly increasing during a warmup period.
48 | """
49 | def lr_lambda(current_step):
50 | if current_step < num_warmup_steps:
51 | return float(current_step) / float(max(1, num_warmup_steps))
52 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
53 |
54 | return LambdaLR(optimizer, lr_lambda, last_epoch)
55 |
56 |
57 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1):
58 | """ Create a schedule with a learning rate that decreases following the
59 | values of the cosine function between the initial lr and 0, after a warmup
60 | period during which it increases linearly from 0 to the initial lr.
61 | """
62 | def lr_lambda(current_step):
63 | if current_step < num_warmup_steps:
64 | return float(current_step) / float(max(1, num_warmup_steps))
65 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
66 | return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress)))
67 |
68 | return LambdaLR(optimizer, lr_lambda, last_epoch)
69 |
70 |
71 | def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1):
72 | """ Create a schedule with a learning rate that decreases following the
73 | values of the cosine function with several hard restarts, after a warmup
74 | period during which it increases linearly from 0 to the initial lr.
75 | """
76 | def lr_lambda(current_step):
77 | if current_step < num_warmup_steps:
78 | return float(current_step) / float(max(1, num_warmup_steps))
79 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
80 | if progress >= 1.:
81 | return 0.
82 | return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.))))
83 |
84 | return LambdaLR(optimizer, lr_lambda, last_epoch)
85 |
86 |
87 | class AdamW(Optimizer):
88 | """ Implements Adam algorithm with weight decay fix.
89 |
90 | Parameters:
91 | lr (float): learning rate. Default 1e-3.
92 | betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999)
93 | eps (float): Adam's epsilon. Default: 1e-6
94 | weight_decay (float): Weight decay. Default: 0.0
95 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. as in the BERT TF repository). Default: True.
96 | """
97 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
98 | if lr < 0.0:
99 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
100 | if not 0.0 <= betas[0] < 1.0:
101 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
102 | if not 0.0 <= betas[1] < 1.0:
103 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
104 | if not 0.0 <= eps:
105 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
106 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
107 | correct_bias=correct_bias)
108 | super(AdamW, self).__init__(params, defaults)
109 |
110 | def step(self, closure=None):
111 | """Performs a single optimization step.
112 |
113 | Arguments:
114 | closure (callable, optional): A closure that reevaluates the model
115 | and returns the loss.
116 | """
117 | loss = None
118 | if closure is not None:
119 | loss = closure()
120 |
121 | for group in self.param_groups:
122 | for p in group['params']:
123 | if p.grad is None:
124 | continue
125 | grad = p.grad.data
126 | if grad.is_sparse:
127 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
128 |
129 | state = self.state[p]
130 |
131 | # State initialization
132 | if len(state) == 0:
133 | state['step'] = 0
134 | # Exponential moving average of gradient values
135 | state['exp_avg'] = torch.zeros_like(p.data)
136 | # Exponential moving average of squared gradient values
137 | state['exp_avg_sq'] = torch.zeros_like(p.data)
138 |
139 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
140 | beta1, beta2 = group['betas']
141 |
142 | state['step'] += 1
143 |
144 | # Decay the first and second moment running average coefficient
145 | # In-place operations to update the averages at the same time
146 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
147 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
148 | denom = exp_avg_sq.sqrt().add_(group['eps'])
149 |
150 | step_size = group['lr']
151 | if group['correct_bias']: # No bias correction for Bert
152 | bias_correction1 = 1.0 - beta1 ** state['step']
153 | bias_correction2 = 1.0 - beta2 ** state['step']
154 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
155 |
156 | p.data.addcdiv_(-step_size, exp_avg, denom)
157 |
158 | # Just adding the square of the weights to the loss function is *not*
159 | # the correct way of using L2 regularization/weight decay with Adam,
160 | # since that will interact with the m and v parameters in strange ways.
161 | #
162 | # Instead we want to decay the weights in a manner that doesn't interact
163 | # with the m/v parameters. This is equivalent to adding the square
164 | # of the weights to the loss with plain (non-momentum) SGD.
165 | # Add weight decay at the end (fixed version)
166 | if group['weight_decay'] > 0.0:
167 | p.data.add_(-group['lr'] * group['weight_decay'], p.data)
168 |
169 | return loss
170 |
--------------------------------------------------------------------------------
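
A minimal training-loop sketch combining the AdamW optimizer and get_linear_schedule_with_warmup defined above. The model and data are stand-ins; note that the schedules are indexed by optimization step, so scheduler.step() is called once per batch, not once per epoch:

import torch

from transformers.optimization import AdamW, get_linear_schedule_with_warmup

model = torch.nn.Linear(10, 2)  # stand-in for a real model
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)

for step in range(1000):
    inputs = torch.randn(8, 10)
    loss = model(inputs).sum()  # stand-in loss
    loss.backward()
    optimizer.step()   # Adam update, then decoupled weight decay: p <- p - lr * weight_decay * p
    scheduler.step()   # advance the warmup/decay multiplier
    optimizer.zero_grad()
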
/transformers/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Neutralzz/syntax_aware_local_attention/e1b9397278fedb56b09320ad18e9cb9c543f5306/transformers/tests/__init__.py
--------------------------------------------------------------------------------
/transformers/tests/configuration_common_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import copy
20 | import os
21 | import shutil
22 | import json
23 | import random
24 | import uuid
25 |
26 | import unittest
27 | import logging
28 |
29 |
30 | class ConfigTester(object):
31 | def __init__(self, parent, config_class=None, **kwargs):
32 | self.parent = parent
33 | self.config_class = config_class
34 | self.inputs_dict = kwargs
35 |
36 | def create_and_test_config_common_properties(self):
37 | config = self.config_class(**self.inputs_dict)
38 | self.parent.assertTrue(hasattr(config, 'vocab_size'))
39 | self.parent.assertTrue(hasattr(config, 'hidden_size'))
40 | self.parent.assertTrue(hasattr(config, 'num_attention_heads'))
41 | self.parent.assertTrue(hasattr(config, 'num_hidden_layers'))
42 |
43 | def create_and_test_config_to_json_string(self):
44 | config = self.config_class(**self.inputs_dict)
45 | obj = json.loads(config.to_json_string())
46 | for key, value in self.inputs_dict.items():
47 | self.parent.assertEqual(obj[key], value)
48 |
49 | def create_and_test_config_to_json_file(self):
50 | config_first = self.config_class(**self.inputs_dict)
51 | json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json")
52 | config_first.to_json_file(json_file_path)
53 | config_second = self.config_class.from_json_file(json_file_path)
54 | os.remove(json_file_path)
55 | self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
56 |
57 | def run_common_tests(self):
58 | self.create_and_test_config_common_properties()
59 | self.create_and_test_config_to_json_string()
60 | self.create_and_test_config_to_json_file()
61 |
62 | if __name__ == "__main__":
63 | unittest.main()
--------------------------------------------------------------------------------
/transformers/tests/fixtures/input.txt:
--------------------------------------------------------------------------------
1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer
2 |
--------------------------------------------------------------------------------
/transformers/tests/fixtures/sample_text.txt:
--------------------------------------------------------------------------------
1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
2 | Text should be one-sentence-per-line, with empty lines between documents.
3 | This sample text is public domain and was randomly selected from Project Guttenberg.
4 |
5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity.
7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them.
8 | "Cass" Beard had risen early that morning, but not with a view to discovery.
9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets.
10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency.
11 | This was nearly opposite.
12 | Mr. Cassius crossed the highway, and stopped suddenly.
13 | Something glittered in the nearest red pool before him.
14 | Gold, surely!
15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring.
16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass."
17 | Like most of his fellow gold-seekers, Cass was superstitious.
18 |
19 | The fountain of classic wisdom, Hypatia herself.
20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge.
21 | From my youth I felt in me a soul above the matter-entangled herd.
22 | She revealed to me the glorious fact, that I am a spark of Divinity itself.
23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's.
24 | There is a philosophic pleasure in opening one's treasures to the modest young.
25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street.
26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide;
27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind.
28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now.
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert;
30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts.
31 | At last they reached the quay at the opposite end of the street;
32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers.
33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him.
34 |
--------------------------------------------------------------------------------
/transformers/tests/fixtures/spiece.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Neutralzz/syntax_aware_local_attention/e1b9397278fedb56b09320ad18e9cb9c543f5306/transformers/tests/fixtures/spiece.model
--------------------------------------------------------------------------------
/transformers/tests/fixtures/test_sentencepiece.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Neutralzz/syntax_aware_local_attention/e1b9397278fedb56b09320ad18e9cb9c543f5306/transformers/tests/fixtures/test_sentencepiece.model
--------------------------------------------------------------------------------
/transformers/tests/hf_api_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019-present, the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function
16 |
17 | import os
18 | import six
19 | import time
20 | import unittest
21 |
22 | from transformers.hf_api import HfApi, S3Obj, PresignedUrl, HfFolder, HTTPError
23 |
24 | USER = "__DUMMY_TRANSFORMERS_USER__"
25 | PASS = "__DUMMY_TRANSFORMERS_PASS__"
26 | FILE_KEY = "Test-{}.txt".format(int(time.time()))
27 | FILE_PATH = os.path.join(
28 | os.path.dirname(os.path.abspath(__file__)), "fixtures/input.txt"
29 | )
30 |
31 |
32 |
33 | class HfApiCommonTest(unittest.TestCase):
34 | _api = HfApi(endpoint="https://moon-staging.huggingface.co")
35 |
36 |
37 | class HfApiLoginTest(HfApiCommonTest):
38 | def test_login_invalid(self):
39 | with self.assertRaises(HTTPError):
40 | self._api.login(username=USER, password="fake")
41 |
42 | def test_login_valid(self):
43 | token = self._api.login(username=USER, password=PASS)
44 | self.assertIsInstance(token, six.string_types)
45 |
46 |
47 | class HfApiEndpointsTest(HfApiCommonTest):
48 | @classmethod
49 | def setUpClass(cls):
50 | """
51 | Share this valid token in all tests below.
52 | """
53 | cls._token = cls._api.login(username=USER, password=PASS)
54 |
55 | def test_whoami(self):
56 | user = self._api.whoami(token=self._token)
57 | self.assertEqual(user, USER)
58 |
59 | def test_presign(self):
60 | urls = self._api.presign(token=self._token, filename=FILE_KEY)
61 | self.assertIsInstance(urls, PresignedUrl)
62 | self.assertEqual(urls.type, "text/plain")
63 |
64 | def test_presign_and_upload(self):
65 | access_url = self._api.presign_and_upload(
66 | token=self._token, filename=FILE_KEY, filepath=FILE_PATH
67 | )
68 | self.assertIsInstance(access_url, six.string_types)
69 |
70 | def test_list_objs(self):
71 | objs = self._api.list_objs(token=self._token)
72 | self.assertIsInstance(objs, list)
73 | if len(objs) > 0:
74 | o = objs[-1]
75 | self.assertIsInstance(o, S3Obj)
76 |
77 |
78 |
79 | class HfFolderTest(unittest.TestCase):
80 | def test_token_workflow(self):
81 | """
82 | Test the whole token save/get/delete workflow,
83 | with the desired behavior with respect to non-existent tokens.
84 | """
85 | token = "token-{}".format(int(time.time()))
86 | HfFolder.save_token(token)
87 | self.assertEqual(
88 | HfFolder.get_token(),
89 | token
90 | )
91 | HfFolder.delete_token()
92 | HfFolder.delete_token()
93 | # ^^ not an error, we test that the
94 | # second call does not fail.
95 | self.assertEqual(
96 | HfFolder.get_token(),
97 | None
98 | )
99 |
100 |
101 | if __name__ == "__main__":
102 | unittest.main()
103 |
--------------------------------------------------------------------------------
/transformers/tests/modeling_auto_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import shutil
21 | import logging
22 |
23 | from transformers import is_torch_available
24 |
25 | from .utils import require_torch, slow, SMALL_MODEL_IDENTIFIER
26 |
27 | if is_torch_available():
28 | from transformers import (AutoConfig, BertConfig,
29 | AutoModel, BertModel,
30 | AutoModelWithLMHead, BertForMaskedLM,
31 | AutoModelForSequenceClassification, BertForSequenceClassification,
32 | AutoModelForQuestionAnswering, BertForQuestionAnswering)
33 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
34 |
35 | from .modeling_common_test import (CommonTestCases, ids_tensor)
36 | from .configuration_common_test import ConfigTester
37 |
38 |
39 | @require_torch
40 | class AutoModelTest(unittest.TestCase):
41 | @slow
42 | def test_model_from_pretrained(self):
43 | logging.basicConfig(level=logging.INFO)
44 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
45 | config = AutoConfig.from_pretrained(model_name)
46 | self.assertIsNotNone(config)
47 | self.assertIsInstance(config, BertConfig)
48 |
49 | model = AutoModel.from_pretrained(model_name)
50 | model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True)
51 | self.assertIsNotNone(model)
52 | self.assertIsInstance(model, BertModel)
53 | for value in loading_info.values():
54 | self.assertEqual(len(value), 0)
55 |
56 | @slow
57 | def test_lmhead_model_from_pretrained(self):
58 | logging.basicConfig(level=logging.INFO)
59 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
60 | config = AutoConfig.from_pretrained(model_name)
61 | self.assertIsNotNone(config)
62 | self.assertIsInstance(config, BertConfig)
63 |
64 | model = AutoModelWithLMHead.from_pretrained(model_name)
65 | model, loading_info = AutoModelWithLMHead.from_pretrained(model_name, output_loading_info=True)
66 | self.assertIsNotNone(model)
67 | self.assertIsInstance(model, BertForMaskedLM)
68 |
69 | @slow
70 | def test_sequence_classification_model_from_pretrained(self):
71 | logging.basicConfig(level=logging.INFO)
72 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
73 | config = AutoConfig.from_pretrained(model_name)
74 | self.assertIsNotNone(config)
75 | self.assertIsInstance(config, BertConfig)
76 |
77 | model = AutoModelForSequenceClassification.from_pretrained(model_name)
78 | model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True)
79 | self.assertIsNotNone(model)
80 | self.assertIsInstance(model, BertForSequenceClassification)
81 |
82 | @slow
83 | def test_question_answering_model_from_pretrained(self):
84 | logging.basicConfig(level=logging.INFO)
85 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
86 | config = AutoConfig.from_pretrained(model_name)
87 | self.assertIsNotNone(config)
88 | self.assertIsInstance(config, BertConfig)
89 |
90 | model = AutoModelForQuestionAnswering.from_pretrained(model_name)
91 | model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True)
92 | self.assertIsNotNone(model)
93 | self.assertIsInstance(model, BertForQuestionAnswering)
94 |
95 | def test_from_pretrained_identifier(self):
96 | logging.basicConfig(level=logging.INFO)
97 | model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
98 | self.assertIsInstance(model, BertForMaskedLM)
99 |
100 |
101 | if __name__ == "__main__":
102 | unittest.main()
103 |
--------------------------------------------------------------------------------
/transformers/tests/modeling_ctrl_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Salesforce and HuggingFace Inc. team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import absolute_import
15 | from __future__ import division
16 | from __future__ import print_function
17 |
18 | import unittest
19 | import shutil
20 |
21 |
22 | from transformers import is_torch_available
23 |
24 | if is_torch_available():
25 | from transformers import (CTRLConfig, CTRLModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
26 | CTRLLMHeadModel)
27 |
28 | from .modeling_common_test import (CommonTestCases, ids_tensor)
29 | from .configuration_common_test import ConfigTester
30 | from .utils import require_torch, slow, torch_device
31 |
32 |
33 | @require_torch
34 | class CTRLModelTest(CommonTestCases.CommonModelTester):
35 |
36 | all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
37 | test_pruning = False
38 | test_torchscript = False
39 | test_resize_embeddings = False
40 | test_head_masking = False
41 |
42 | class CTRLModelTester(object):
43 |
44 | def __init__(self,
45 | parent,
46 | batch_size=13,
47 | seq_length=7,
48 | is_training=True,
49 | use_token_type_ids=True,
50 | use_input_mask=True,
51 | use_labels=True,
52 | use_mc_token_ids=True,
53 | vocab_size=99,
54 | hidden_size=32,
55 | num_hidden_layers=5,
56 | num_attention_heads=4,
57 | intermediate_size=37,
58 | hidden_act="gelu",
59 | hidden_dropout_prob=0.1,
60 | attention_probs_dropout_prob=0.1,
61 | max_position_embeddings=512,
62 | type_vocab_size=16,
63 | type_sequence_label_size=2,
64 | initializer_range=0.02,
65 | num_labels=3,
66 | num_choices=4,
67 | scope=None,
68 | ):
69 | self.parent = parent
70 | self.batch_size = batch_size
71 | self.seq_length = seq_length
72 | self.is_training = is_training
73 | self.use_token_type_ids = use_token_type_ids
74 | self.use_input_mask = use_input_mask
75 | self.use_labels = use_labels
76 | self.use_mc_token_ids = use_mc_token_ids
77 | self.vocab_size = vocab_size
78 | self.hidden_size = hidden_size
79 | self.num_hidden_layers = num_hidden_layers
80 | self.num_attention_heads = num_attention_heads
81 | self.intermediate_size = intermediate_size
82 | self.hidden_act = hidden_act
83 | self.hidden_dropout_prob = hidden_dropout_prob
84 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
85 | self.max_position_embeddings = max_position_embeddings
86 | self.type_vocab_size = type_vocab_size
87 | self.type_sequence_label_size = type_sequence_label_size
88 | self.initializer_range = initializer_range
89 | self.num_labels = num_labels
90 | self.num_choices = num_choices
91 | self.scope = scope
92 |
93 | def prepare_config_and_inputs(self):
94 | input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
95 |
96 | input_mask = None
97 | if self.use_input_mask:
98 | input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
99 |
100 | token_type_ids = None
101 | if self.use_token_type_ids:
102 | token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
103 |
104 | mc_token_ids = None
105 | if self.use_mc_token_ids:
106 | mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
107 |
108 | sequence_labels = None
109 | token_labels = None
110 | choice_labels = None
111 | if self.use_labels:
112 | sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
113 | token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
114 | choice_labels = ids_tensor([self.batch_size], self.num_choices)
115 |
116 | config = CTRLConfig(
117 | vocab_size_or_config_json_file=self.vocab_size,
118 | n_embd=self.hidden_size,
119 | n_layer=self.num_hidden_layers,
120 | n_head=self.num_attention_heads,
121 | # intermediate_size=self.intermediate_size,
122 | # hidden_act=self.hidden_act,
123 | # hidden_dropout_prob=self.hidden_dropout_prob,
124 | # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
125 | n_positions=self.max_position_embeddings,
126 | n_ctx=self.max_position_embeddings
127 | # type_vocab_size=self.type_vocab_size,
128 | # initializer_range=self.initializer_range
129 | )
130 |
131 | head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
132 |
133 | return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
134 |
135 | def check_loss_output(self, result):
136 | self.parent.assertListEqual(
137 | list(result["loss"].size()),
138 | [])
139 |
140 | def create_and_check_ctrl_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
141 | model = CTRLModel(config=config)
142 | model.to(torch_device)
143 | model.eval()
144 |
145 | model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
146 | model(input_ids, token_type_ids=token_type_ids)
147 | sequence_output, presents = model(input_ids)
148 |
149 | result = {
150 | "sequence_output": sequence_output,
151 | "presents": presents,
152 | }
153 | self.parent.assertListEqual(
154 | list(result["sequence_output"].size()),
155 | [self.batch_size, self.seq_length, self.hidden_size])
156 | self.parent.assertEqual(len(result["presents"]), config.n_layer)
157 |
158 | def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
159 | model = CTRLLMHeadModel(config)
160 | model.to(torch_device)
161 | model.eval()
162 |
163 | loss, lm_logits, _ = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
164 |
165 | result = {
166 | "loss": loss,
167 | "lm_logits": lm_logits
168 | }
169 | self.parent.assertListEqual(
170 | list(result["loss"].size()),
171 | [])
172 | self.parent.assertListEqual(
173 | list(result["lm_logits"].size()),
174 | [self.batch_size, self.seq_length, self.vocab_size])
175 |
176 |
177 | def prepare_config_and_inputs_for_common(self):
178 | config_and_inputs = self.prepare_config_and_inputs()
179 |
180 | (config, input_ids, input_mask, head_mask, token_type_ids,
181 | mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
182 |
183 | inputs_dict = {
184 | 'input_ids': input_ids,
185 | 'token_type_ids': token_type_ids,
186 | 'head_mask': head_mask
187 | }
188 |
189 | return config, inputs_dict
190 |
191 | def setUp(self):
192 | self.model_tester = CTRLModelTest.CTRLModelTester(self)
193 | self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37)
194 |
195 | def test_config(self):
196 | self.config_tester.run_common_tests()
197 |
198 | def test_ctrl_model(self):
199 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
200 | self.model_tester.create_and_check_ctrl_model(*config_and_inputs)
201 |
202 | def test_ctrl_lm_head_model(self):
203 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
204 | self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
205 |
206 | @slow
207 | def test_model_from_pretrained(self):
208 | cache_dir = "/tmp/transformers_test/"
209 | for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
210 | model = CTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
211 | shutil.rmtree(cache_dir)
212 | self.assertIsNotNone(model)
213 |
214 |
215 | if __name__ == "__main__":
216 | unittest.main()
217 |
--------------------------------------------------------------------------------
/transformers/tests/modeling_encoder_decoder_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Hugging Face Inc. Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import logging
17 | import unittest
18 |
19 | from transformers import is_torch_available
20 | from .utils import require_torch, slow
21 |
22 | if is_torch_available():
23 | from transformers import BertModel, BertForMaskedLM, Model2Model
24 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
25 |
26 |
27 | @require_torch
28 | class EncoderDecoderModelTest(unittest.TestCase):
29 | @slow
30 | def test_model2model_from_pretrained(self):
31 | logging.basicConfig(level=logging.INFO)
32 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
33 | model = Model2Model.from_pretrained(model_name)
34 | self.assertIsInstance(model.encoder, BertModel)
35 | self.assertIsInstance(model.decoder, BertForMaskedLM)
36 | self.assertEqual(model.decoder.config.is_decoder, True)
37 | self.assertEqual(model.encoder.config.is_decoder, False)
38 |
39 | def test_model2model_from_pretrained_not_bert(self):
40 | logging.basicConfig(level=logging.INFO)
41 | with self.assertRaises(ValueError):
42 | _ = Model2Model.from_pretrained('roberta')
43 |
44 | with self.assertRaises(ValueError):
45 | _ = Model2Model.from_pretrained('distilbert')
46 |
47 | with self.assertRaises(ValueError):
48 | _ = Model2Model.from_pretrained('does-not-exist')
49 |
50 |
51 | if __name__ == "__main__":
52 | unittest.main()
53 |
--------------------------------------------------------------------------------
/transformers/tests/modeling_openai_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import shutil
21 |
22 | from transformers import is_torch_available
23 |
24 | if is_torch_available():
25 | from transformers import (OpenAIGPTConfig, OpenAIGPTModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
26 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
27 |
28 | from .modeling_common_test import (CommonTestCases, ids_tensor)
29 | from .configuration_common_test import ConfigTester
30 | from .utils import require_torch, slow, torch_device
31 |
32 |
33 | @require_torch
34 | class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
35 |
36 | all_model_classes = (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else ()
37 |
38 | class OpenAIGPTModelTester(object):
39 |
40 | def __init__(self,
41 | parent,
42 | batch_size=13,
43 | seq_length=7,
44 | is_training=True,
45 | use_token_type_ids=True,
46 | use_labels=True,
47 | vocab_size=99,
48 | hidden_size=32,
49 | num_hidden_layers=5,
50 | num_attention_heads=4,
51 | intermediate_size=37,
52 | hidden_act="gelu",
53 | hidden_dropout_prob=0.1,
54 | attention_probs_dropout_prob=0.1,
55 | max_position_embeddings=512,
56 | type_vocab_size=16,
57 | type_sequence_label_size=2,
58 | initializer_range=0.02,
59 | num_labels=3,
60 | num_choices=4,
61 | scope=None,
62 | ):
63 | self.parent = parent
64 | self.batch_size = batch_size
65 | self.seq_length = seq_length
66 | self.is_training = is_training
67 | self.use_token_type_ids = use_token_type_ids
68 | self.use_labels = use_labels
69 | self.vocab_size = vocab_size
70 | self.hidden_size = hidden_size
71 | self.num_hidden_layers = num_hidden_layers
72 | self.num_attention_heads = num_attention_heads
73 | self.intermediate_size = intermediate_size
74 | self.hidden_act = hidden_act
75 | self.hidden_dropout_prob = hidden_dropout_prob
76 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
77 | self.max_position_embeddings = max_position_embeddings
78 | self.type_vocab_size = type_vocab_size
79 | self.type_sequence_label_size = type_sequence_label_size
80 | self.initializer_range = initializer_range
81 | self.num_labels = num_labels
82 | self.num_choices = num_choices
83 | self.scope = scope
84 |
85 | def prepare_config_and_inputs(self):
86 | input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
87 |
88 | token_type_ids = None
89 | if self.use_token_type_ids:
90 | token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
91 |
92 | sequence_labels = None
93 | token_labels = None
94 | choice_labels = None
95 | if self.use_labels:
96 | sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
97 | token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
98 | choice_labels = ids_tensor([self.batch_size], self.num_choices)
99 |
100 | config = OpenAIGPTConfig(
101 | vocab_size_or_config_json_file=self.vocab_size,
102 | n_embd=self.hidden_size,
103 | n_layer=self.num_hidden_layers,
104 | n_head=self.num_attention_heads,
105 | # intermediate_size=self.intermediate_size,
106 | # hidden_act=self.hidden_act,
107 | # hidden_dropout_prob=self.hidden_dropout_prob,
108 | # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
109 | n_positions=self.max_position_embeddings,
110 | n_ctx=self.max_position_embeddings
111 | # type_vocab_size=self.type_vocab_size,
112 | # initializer_range=self.initializer_range
113 | )
114 |
115 | head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
116 |
117 | return config, input_ids, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
118 |
119 | def check_loss_output(self, result):
120 | self.parent.assertListEqual(
121 | list(result["loss"].size()),
122 | [])
123 |
124 | def create_and_check_openai_gpt_model(self, config, input_ids, head_mask, token_type_ids, *args):
125 | model = OpenAIGPTModel(config=config)
126 | model.to(torch_device)
127 | model.eval()
128 |
129 | model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
130 | model(input_ids, token_type_ids=token_type_ids)
131 | (sequence_output,) = model(input_ids)
132 |
133 | result = {
134 | "sequence_output": sequence_output
135 | }
136 | self.parent.assertListEqual(
137 | list(result["sequence_output"].size()),
138 | [self.batch_size, self.seq_length, self.hidden_size])
139 |
140 | def create_and_check_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
141 | model = OpenAIGPTLMHeadModel(config)
142 | model.to(torch_device)
143 | model.eval()
144 |
145 | loss, lm_logits = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
146 |
147 | result = {
148 | "loss": loss,
149 | "lm_logits": lm_logits
150 | }
151 |
152 | self.parent.assertListEqual(
153 | list(result["loss"].size()),
154 | [])
155 | self.parent.assertListEqual(
156 | list(result["lm_logits"].size()),
157 | [self.batch_size, self.seq_length, self.vocab_size])
158 |
159 | def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
160 | model = OpenAIGPTDoubleHeadsModel(config)
161 | model.to(torch_device)
162 | model.eval()
163 |
164 | loss, lm_logits, mc_logits = model(input_ids, token_type_ids=token_type_ids, lm_labels=input_ids)
165 |
166 | result = {
167 | "loss": loss,
168 | "lm_logits": lm_logits
169 | }
170 |
171 | self.parent.assertListEqual(
172 | list(result["loss"].size()),
173 | [])
174 | self.parent.assertListEqual(
175 | list(result["lm_logits"].size()),
176 | [self.batch_size, self.seq_length, self.vocab_size])
177 |
178 | def prepare_config_and_inputs_for_common(self):
179 | config_and_inputs = self.prepare_config_and_inputs()
180 | (config, input_ids, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
181 | inputs_dict = {
182 | 'input_ids': input_ids,
183 | 'token_type_ids': token_type_ids,
184 | 'head_mask': head_mask
185 | }
186 |
187 | return config, inputs_dict
188 |
189 | def setUp(self):
190 | self.model_tester = OpenAIGPTModelTest.OpenAIGPTModelTester(self)
191 | self.config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37)
192 |
193 | def test_config(self):
194 | self.config_tester.run_common_tests()
195 |
196 | def test_openai_gpt_model(self):
197 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
198 | self.model_tester.create_and_check_openai_gpt_model(*config_and_inputs)
199 |
200 | def test_openai_gpt_lm_head_model(self):
201 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
202 | self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
203 |
204 | def test_openai_gpt_double_lm_head_model(self):
205 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
206 | self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)
207 |
208 | @slow
209 | def test_model_from_pretrained(self):
210 | cache_dir = "/tmp/transformers_test/"
211 | for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
212 | model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
213 | shutil.rmtree(cache_dir)
214 | self.assertIsNotNone(model)
215 |
216 |
217 | if __name__ == "__main__":
218 | unittest.main()
219 |
--------------------------------------------------------------------------------
/transformers/tests/modeling_tf_auto_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import shutil
21 | import logging
22 |
23 | from transformers import is_tf_available
24 |
25 | from .utils import require_tf, slow, SMALL_MODEL_IDENTIFIER
26 |
27 | if is_tf_available():
28 | from transformers import (AutoConfig, BertConfig,
29 | TFAutoModel, TFBertModel,
30 | TFAutoModelWithLMHead, TFBertForMaskedLM,
31 | TFAutoModelForSequenceClassification, TFBertForSequenceClassification,
32 | TFAutoModelForQuestionAnswering, TFBertForQuestionAnswering)
33 | from transformers.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
34 |
35 | from .modeling_common_test import (CommonTestCases, ids_tensor)
36 | from .configuration_common_test import ConfigTester
37 |
38 |
39 | @require_tf
40 | class TFAutoModelTest(unittest.TestCase):
41 | @slow
42 | def test_model_from_pretrained(self):
43 | import h5py
44 | self.assertTrue(h5py.version.hdf5_version.startswith("1.10"))
45 |
46 | logging.basicConfig(level=logging.INFO)
47 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
48 | for model_name in ['bert-base-uncased']:
49 | config = AutoConfig.from_pretrained(model_name, force_download=True)
50 | self.assertIsNotNone(config)
51 | self.assertIsInstance(config, BertConfig)
52 |
53 | model = TFAutoModel.from_pretrained(model_name, force_download=True)
54 | self.assertIsNotNone(model)
55 | self.assertIsInstance(model, TFBertModel)
56 |
57 | @slow
58 | def test_lmhead_model_from_pretrained(self):
59 | logging.basicConfig(level=logging.INFO)
60 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
61 | for model_name in ['bert-base-uncased']:
62 | config = AutoConfig.from_pretrained(model_name, force_download=True)
63 | self.assertIsNotNone(config)
64 | self.assertIsInstance(config, BertConfig)
65 |
66 | model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True)
67 | self.assertIsNotNone(model)
68 | self.assertIsInstance(model, TFBertForMaskedLM)
69 |
70 | @slow
71 | def test_sequence_classification_model_from_pretrained(self):
72 | logging.basicConfig(level=logging.INFO)
73 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
74 | for model_name in ['bert-base-uncased']:
75 | config = AutoConfig.from_pretrained(model_name, force_download=True)
76 | self.assertIsNotNone(config)
77 | self.assertIsInstance(config, BertConfig)
78 |
79 | model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True)
80 | self.assertIsNotNone(model)
81 | self.assertIsInstance(model, TFBertForSequenceClassification)
82 |
83 | @slow
84 | def test_question_answering_model_from_pretrained(self):
85 | logging.basicConfig(level=logging.INFO)
86 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
87 | for model_name in ['bert-base-uncased']:
88 | config = AutoConfig.from_pretrained(model_name, force_download=True)
89 | self.assertIsNotNone(config)
90 | self.assertIsInstance(config, BertConfig)
91 |
92 | model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True)
93 | self.assertIsNotNone(model)
94 | self.assertIsInstance(model, TFBertForQuestionAnswering)
95 |
96 | def test_from_pretrained_identifier(self):
97 | logging.basicConfig(level=logging.INFO)
98 | model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True)
99 | self.assertIsInstance(model, TFBertForMaskedLM)
100 |
101 |
102 | if __name__ == "__main__":
103 | unittest.main()
104 |
--------------------------------------------------------------------------------
/transformers/tests/modeling_tf_ctrl_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import shutil
21 | import sys
22 |
23 | from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
24 | from .configuration_common_test import ConfigTester
25 | from .utils import require_tf, slow
26 |
27 | from transformers import CTRLConfig, is_tf_available
28 |
29 | if is_tf_available():
30 | import tensorflow as tf
31 | from transformers.modeling_tf_ctrl import (TFCTRLModel, TFCTRLLMHeadModel,
32 | TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
33 |
34 |
35 | @require_tf
36 | class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
37 |
38 | all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()
39 |
40 | class TFCTRLModelTester(object):
41 |
42 | def __init__(self,
43 | parent,
44 | batch_size=13,
45 | seq_length=7,
46 | is_training=True,
47 | use_token_type_ids=True,
48 | use_input_mask=True,
49 | use_labels=True,
50 | use_mc_token_ids=True,
51 | vocab_size=99,
52 | hidden_size=32,
53 | num_hidden_layers=5,
54 | num_attention_heads=4,
55 | intermediate_size=37,
56 | hidden_act="gelu",
57 | hidden_dropout_prob=0.1,
58 | attention_probs_dropout_prob=0.1,
59 | max_position_embeddings=512,
60 | type_vocab_size=16,
61 | type_sequence_label_size=2,
62 | initializer_range=0.02,
63 | num_labels=3,
64 | num_choices=4,
65 | scope=None,
66 | ):
67 | self.parent = parent
68 | self.batch_size = batch_size
69 | self.seq_length = seq_length
70 | self.is_training = is_training
71 | self.use_token_type_ids = use_token_type_ids
72 | self.use_input_mask = use_input_mask
73 | self.use_labels = use_labels
74 | self.use_mc_token_ids = use_mc_token_ids
75 | self.vocab_size = vocab_size
76 | self.hidden_size = hidden_size
77 | self.num_hidden_layers = num_hidden_layers
78 | self.num_attention_heads = num_attention_heads
79 | self.intermediate_size = intermediate_size
80 | self.hidden_act = hidden_act
81 | self.hidden_dropout_prob = hidden_dropout_prob
82 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
83 | self.max_position_embeddings = max_position_embeddings
84 | self.type_vocab_size = type_vocab_size
85 | self.type_sequence_label_size = type_sequence_label_size
86 | self.initializer_range = initializer_range
87 | self.num_labels = num_labels
88 | self.num_choices = num_choices
89 | self.scope = scope
90 |
91 | def prepare_config_and_inputs(self):
92 | input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
93 |
94 | input_mask = None
95 | if self.use_input_mask:
96 | input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
97 |
98 | token_type_ids = None
99 | if self.use_token_type_ids:
100 | token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
101 |
102 | mc_token_ids = None
103 | if self.use_mc_token_ids:
104 | mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
105 |
106 | sequence_labels = None
107 | token_labels = None
108 | choice_labels = None
109 | if self.use_labels:
110 | sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
111 | token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
112 | choice_labels = ids_tensor([self.batch_size], self.num_choices)
113 |
114 | config = CTRLConfig(
115 | vocab_size_or_config_json_file=self.vocab_size,
116 | n_embd=self.hidden_size,
117 | n_layer=self.num_hidden_layers,
118 | n_head=self.num_attention_heads,
119 | # intermediate_size=self.intermediate_size,
120 | # hidden_act=self.hidden_act,
121 | # hidden_dropout_prob=self.hidden_dropout_prob,
122 | # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
123 | n_positions=self.max_position_embeddings,
124 | n_ctx=self.max_position_embeddings
125 | # type_vocab_size=self.type_vocab_size,
126 | # initializer_range=self.initializer_range
127 | )
128 |
129 | head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
130 |
131 | return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
132 |
133 | def create_and_check_ctrl_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
134 | model = TFCTRLModel(config=config)
135 | inputs = {'input_ids': input_ids,
136 | 'attention_mask': input_mask,
137 | 'token_type_ids': token_type_ids}
138 | sequence_output = model(inputs)[0]
139 |
140 | inputs = [input_ids, None, input_mask] # None is the input for 'past'
141 | sequence_output = model(inputs)[0]
142 |
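   | # a bare input_ids tensor is also accepted as the sole input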
143 | sequence_output = model(input_ids)[0]
144 |
145 | result = {
146 | "sequence_output": sequence_output.numpy(),
147 | }
148 | self.parent.assertListEqual(
149 | list(result["sequence_output"].shape),
150 | [self.batch_size, self.seq_length, self.hidden_size])
151 |
152 |
153 | def create_and_check_ctrl_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
154 | model = TFCTRLLMHeadModel(config=config)
155 | inputs = {'input_ids': input_ids,
156 | 'attention_mask': input_mask,
157 | 'token_type_ids': token_type_ids}
158 | prediction_scores = model(inputs)[0]
159 | result = {
160 | "prediction_scores": prediction_scores.numpy(),
161 | }
162 | self.parent.assertListEqual(
163 | list(result["prediction_scores"].shape),
164 | [self.batch_size, self.seq_length, self.vocab_size])
165 |
166 | def prepare_config_and_inputs_for_common(self):
167 | config_and_inputs = self.prepare_config_and_inputs()
168 |
169 | (config, input_ids, input_mask, head_mask, token_type_ids,
170 | mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
171 |
172 | inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
173 | return config, inputs_dict
174 |
175 | def setUp(self):
176 | self.model_tester = TFCTRLModelTest.TFCTRLModelTester(self)
177 | self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37)
178 |
179 | def test_config(self):
180 | self.config_tester.run_common_tests()
181 |
182 | def test_ctrl_model(self):
183 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
184 | self.model_tester.create_and_check_ctrl_model(*config_and_inputs)
185 |
186 | def test_ctrl_lm_head(self):
187 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
188 | self.model_tester.create_and_check_ctrl_lm_head(*config_and_inputs)
189 |
190 | @slow
191 | def test_model_from_pretrained(self):
192 | cache_dir = "/tmp/transformers_test/"
193 | for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
194 | model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
195 | shutil.rmtree(cache_dir)
196 | self.assertIsNotNone(model)
197 |
198 | if __name__ == "__main__":
199 | unittest.main()
200 |
201 |
--------------------------------------------------------------------------------
/transformers/tests/modeling_tf_distilbert_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 |
21 | from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
22 | from .configuration_common_test import ConfigTester
23 | from .utils import require_tf, slow
24 |
25 | from transformers import DistilBertConfig, is_tf_available
26 |
27 | if is_tf_available():
28 | import tensorflow as tf
29 | from transformers.modeling_tf_distilbert import (TFDistilBertModel,
30 | TFDistilBertForMaskedLM,
31 | TFDistilBertForQuestionAnswering,
32 | TFDistilBertForSequenceClassification)
33 |
34 |
35 | @require_tf
36 | class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
37 |
38 | all_model_classes = (TFDistilBertModel, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering,
39 |                         TFDistilBertForSequenceClassification) if is_tf_available() else ()
40 | test_pruning = True
41 | test_torchscript = True
42 | test_resize_embeddings = True
43 | test_head_masking = True
44 |
45 | class TFDistilBertModelTester(object):
46 |
47 | def __init__(self,
48 | parent,
49 | batch_size=13,
50 | seq_length=7,
51 | is_training=True,
52 | use_input_mask=True,
53 | use_token_type_ids=False,
54 | use_labels=True,
55 | vocab_size=99,
56 | hidden_size=32,
57 | num_hidden_layers=5,
58 | num_attention_heads=4,
59 | intermediate_size=37,
60 | hidden_act="gelu",
61 | hidden_dropout_prob=0.1,
62 | attention_probs_dropout_prob=0.1,
63 | max_position_embeddings=512,
64 | type_vocab_size=16,
65 | type_sequence_label_size=2,
66 | initializer_range=0.02,
67 | num_labels=3,
68 | num_choices=4,
69 | scope=None,
70 | ):
71 | self.parent = parent
72 | self.batch_size = batch_size
73 | self.seq_length = seq_length
74 | self.is_training = is_training
75 | self.use_input_mask = use_input_mask
76 | self.use_token_type_ids = use_token_type_ids
77 | self.use_labels = use_labels
78 | self.vocab_size = vocab_size
79 | self.hidden_size = hidden_size
80 | self.num_hidden_layers = num_hidden_layers
81 | self.num_attention_heads = num_attention_heads
82 | self.intermediate_size = intermediate_size
83 | self.hidden_act = hidden_act
84 | self.hidden_dropout_prob = hidden_dropout_prob
85 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
86 | self.max_position_embeddings = max_position_embeddings
87 | self.type_vocab_size = type_vocab_size
88 | self.type_sequence_label_size = type_sequence_label_size
89 | self.initializer_range = initializer_range
90 | self.num_labels = num_labels
91 | self.num_choices = num_choices
92 | self.scope = scope
93 |
94 | def prepare_config_and_inputs(self):
95 | input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
96 |
97 | input_mask = None
98 | if self.use_input_mask:
99 | input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
100 |
101 | sequence_labels = None
102 | token_labels = None
103 | choice_labels = None
104 | if self.use_labels:
105 | sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
106 | token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
107 | choice_labels = ids_tensor([self.batch_size], self.num_choices)
108 |
109 | config = DistilBertConfig(
110 | vocab_size_or_config_json_file=self.vocab_size,
111 | dim=self.hidden_size,
112 | n_layers=self.num_hidden_layers,
113 | n_heads=self.num_attention_heads,
114 | hidden_dim=self.intermediate_size,
115 | hidden_act=self.hidden_act,
116 | dropout=self.hidden_dropout_prob,
117 | attention_dropout=self.attention_probs_dropout_prob,
118 | max_position_embeddings=self.max_position_embeddings,
119 | initializer_range=self.initializer_range)
120 |
121 | return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
122 |
123 | def create_and_check_distilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
124 | model = TFDistilBertModel(config=config)
125 | inputs = {'input_ids': input_ids,
126 | 'attention_mask': input_mask}
127 |
128 | outputs = model(inputs)
129 | sequence_output = outputs[0]
130 |
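   | # the TF/Keras models accept a positional list of inputs as well as a dict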
131 | inputs = [input_ids, input_mask]
132 |
133 | (sequence_output,) = model(inputs)
134 |
135 | result = {
136 | "sequence_output": sequence_output.numpy(),
137 | }
138 | self.parent.assertListEqual(
139 | list(result["sequence_output"].shape),
140 | [self.batch_size, self.seq_length, self.hidden_size])
141 |
142 | def create_and_check_distilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
143 | model = TFDistilBertForMaskedLM(config=config)
144 | inputs = {'input_ids': input_ids,
145 | 'attention_mask': input_mask}
146 | (prediction_scores,) = model(inputs)
147 | result = {
148 | "prediction_scores": prediction_scores.numpy(),
149 | }
150 | self.parent.assertListEqual(
151 | list(result["prediction_scores"].shape),
152 | [self.batch_size, self.seq_length, self.vocab_size])
153 |
154 | def create_and_check_distilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
155 | model = TFDistilBertForQuestionAnswering(config=config)
156 | inputs = {'input_ids': input_ids,
157 | 'attention_mask': input_mask}
158 | start_logits, end_logits = model(inputs)
159 | result = {
160 | "start_logits": start_logits.numpy(),
161 | "end_logits": end_logits.numpy(),
162 | }
163 | self.parent.assertListEqual(
164 | list(result["start_logits"].shape),
165 | [self.batch_size, self.seq_length])
166 | self.parent.assertListEqual(
167 | list(result["end_logits"].shape),
168 | [self.batch_size, self.seq_length])
169 |
170 | def create_and_check_distilbert_for_sequence_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
171 | config.num_labels = self.num_labels
172 | model = TFDistilBertForSequenceClassification(config)
173 | inputs = {'input_ids': input_ids,
174 | 'attention_mask': input_mask}
175 | (logits,) = model(inputs)
176 | result = {
177 | "logits": logits.numpy(),
178 | }
179 | self.parent.assertListEqual(
180 | list(result["logits"].shape),
181 | [self.batch_size, self.num_labels])
182 |
183 | def prepare_config_and_inputs_for_common(self):
184 | config_and_inputs = self.prepare_config_and_inputs()
185 | (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
186 | inputs_dict = {'input_ids': input_ids, 'attention_mask': input_mask}
187 | return config, inputs_dict
188 |
189 | def setUp(self):
190 | self.model_tester = TFDistilBertModelTest.TFDistilBertModelTester(self)
191 | self.config_tester = ConfigTester(self, config_class=DistilBertConfig, dim=37)
192 |
193 | def test_config(self):
194 | self.config_tester.run_common_tests()
195 |
196 | def test_distilbert_model(self):
197 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
198 | self.model_tester.create_and_check_distilbert_model(*config_and_inputs)
199 |
200 | def test_for_masked_lm(self):
201 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
202 | self.model_tester.create_and_check_distilbert_for_masked_lm(*config_and_inputs)
203 |
204 | def test_for_question_answering(self):
205 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
206 | self.model_tester.create_and_check_distilbert_for_question_answering(*config_and_inputs)
207 |
208 | def test_for_sequence_classification(self):
209 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
210 | self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
211 |
212 | # @slow
213 | # def test_model_from_pretrained(self):
214 | # cache_dir = "/tmp/transformers_test/"
215 | # for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
216 | # model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
217 | # shutil.rmtree(cache_dir)
218 | # self.assertIsNotNone(model)
219 |
220 | if __name__ == "__main__":
221 | unittest.main()
222 |
--------------------------------------------------------------------------------
/transformers/tests/modeling_tf_transfo_xl_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import random
21 | import shutil
22 |
23 | from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
24 | from .configuration_common_test import ConfigTester
25 | from .utils import require_tf, slow
26 |
27 | from transformers import TransfoXLConfig, is_tf_available
28 |
29 | if is_tf_available():
30 | import tensorflow as tf
31 | from transformers.modeling_tf_transfo_xl import (TFTransfoXLModel,
32 | TFTransfoXLLMHeadModel,
33 | TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
34 |
35 |
36 | @require_tf
37 | class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
38 |
39 | all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
40 | test_pruning = False
41 | test_torchscript = False
42 | test_resize_embeddings = False
43 |
44 | class TFTransfoXLModelTester(object):
45 |
46 | def __init__(self,
47 | parent,
48 | batch_size=13,
49 | seq_length=7,
50 | mem_len=30,
51 | clamp_len=15,
52 | is_training=True,
53 | use_labels=True,
54 | vocab_size=99,
55 | cutoffs=[10, 50, 80],
56 | hidden_size=32,
57 | d_embed=32,
58 | num_attention_heads=4,
59 | d_head=8,
60 | d_inner=128,
61 | div_val=2,
62 | num_hidden_layers=5,
63 | scope=None,
64 | seed=1,
65 | ):
66 | self.parent = parent
67 | self.batch_size = batch_size
68 | self.seq_length = seq_length
69 | self.mem_len = mem_len
70 | self.key_len = seq_length + mem_len
71 | self.clamp_len = clamp_len
72 | self.is_training = is_training
73 | self.use_labels = use_labels
74 | self.vocab_size = vocab_size
75 | self.cutoffs = cutoffs
76 | self.hidden_size = hidden_size
77 | self.d_embed = d_embed
78 | self.num_attention_heads = num_attention_heads
79 | self.d_head = d_head
80 | self.d_inner = d_inner
81 | self.div_val = div_val
82 | self.num_hidden_layers = num_hidden_layers
83 | self.scope = scope
84 | self.seed = seed
85 |
86 | def prepare_config_and_inputs(self):
87 | input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
88 | input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
89 |
90 | lm_labels = None
91 | if self.use_labels:
92 | lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
93 |
94 | config = TransfoXLConfig(
95 | vocab_size_or_config_json_file=self.vocab_size,
96 | mem_len=self.mem_len,
97 | clamp_len=self.clamp_len,
98 | cutoffs=self.cutoffs,
99 | d_model=self.hidden_size,
100 | d_embed=self.d_embed,
101 | n_head=self.num_attention_heads,
102 | d_head=self.d_head,
103 | d_inner=self.d_inner,
104 | div_val=self.div_val,
105 | n_layer=self.num_hidden_layers)
106 |
107 | return (config, input_ids_1, input_ids_2, lm_labels)
108 |
109 | def set_seed(self):
110 | random.seed(self.seed)
111 | tf.random.set_seed(self.seed)
112 |
113 | def create_and_check_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels):
114 | model = TFTransfoXLModel(config)
115 |
116 | hidden_states_1, mems_1 = model(input_ids_1)
117 |
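   | # feed the returned memory states back in so the second segment attends over the first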
118 | inputs = {'input_ids': input_ids_2,
119 | 'mems': mems_1}
120 |
121 | hidden_states_2, mems_2 = model(inputs)
122 |
123 | result = {
124 | "hidden_states_1": hidden_states_1.numpy(),
125 | "mems_1": [mem.numpy() for mem in mems_1],
126 | "hidden_states_2": hidden_states_2.numpy(),
127 | "mems_2": [mem.numpy() for mem in mems_2],
128 | }
129 |
130 | self.parent.assertListEqual(
131 | list(result["hidden_states_1"].shape),
132 | [self.batch_size, self.seq_length, self.hidden_size])
133 | self.parent.assertListEqual(
134 | list(result["hidden_states_2"].shape),
135 | [self.batch_size, self.seq_length, self.hidden_size])
136 | self.parent.assertListEqual(
137 | list(list(mem.shape) for mem in result["mems_1"]),
138 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
139 | self.parent.assertListEqual(
140 | list(list(mem.shape) for mem in result["mems_2"]),
141 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
142 |
143 |
144 | def create_and_check_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
145 | model = TFTransfoXLLMHeadModel(config)
146 |
147 | lm_logits_1, mems_1 = model(input_ids_1)
148 |
149 | inputs = {'input_ids': input_ids_1,
150 | 'labels': lm_labels}
151 | _, mems_1 = model(inputs)
152 |
153 | lm_logits_2, mems_2 = model([input_ids_2, mems_1])
154 |
155 | inputs = {'input_ids': input_ids_1,
156 | 'mems': mems_1,
157 | 'labels': lm_labels}
158 |
159 | _, mems_2 = model(inputs)
160 |
161 | result = {
162 | "mems_1": [mem.numpy() for mem in mems_1],
163 | "lm_logits_1": lm_logits_1.numpy(),
164 | "mems_2": [mem.numpy() for mem in mems_2],
165 | "lm_logits_2": lm_logits_2.numpy(),
166 | }
167 |
168 | self.parent.assertListEqual(
169 | list(result["lm_logits_1"].shape),
170 | [self.batch_size, self.seq_length, self.vocab_size])
171 | self.parent.assertListEqual(
172 | list(list(mem.shape) for mem in result["mems_1"]),
173 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
174 |
175 | self.parent.assertListEqual(
176 | list(result["lm_logits_2"].shape),
177 | [self.batch_size, self.seq_length, self.vocab_size])
178 | self.parent.assertListEqual(
179 | list(list(mem.shape) for mem in result["mems_2"]),
180 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
181 |
182 | def prepare_config_and_inputs_for_common(self):
183 | config_and_inputs = self.prepare_config_and_inputs()
184 | (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
185 | inputs_dict = {'input_ids': input_ids_1}
186 | return config, inputs_dict
187 |
188 |
189 | def setUp(self):
190 | self.model_tester = TFTransfoXLModelTest.TFTransfoXLModelTester(self)
191 | self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
192 |
193 | def test_config(self):
194 | self.config_tester.run_common_tests()
195 |
196 | def test_transfo_xl_model(self):
197 | self.model_tester.set_seed()
198 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
199 | self.model_tester.create_and_check_transfo_xl_model(*config_and_inputs)
200 |
201 | def test_transfo_xl_lm_head(self):
202 | self.model_tester.set_seed()
203 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
204 | self.model_tester.create_and_check_transfo_xl_lm_head(*config_and_inputs)
205 |
206 | @slow
207 | def test_model_from_pretrained(self):
208 | cache_dir = "/tmp/transformers_test/"
209 | for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
210 | model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
211 | shutil.rmtree(cache_dir)
212 | self.assertIsNotNone(model)
213 |
214 |
215 | if __name__ == "__main__":
216 | unittest.main()
217 |
--------------------------------------------------------------------------------
/transformers/tests/modeling_transfo_xl_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import random
21 | import shutil
22 |
23 | from transformers import is_torch_available
24 |
25 | if is_torch_available():
26 | import torch
27 | from transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
28 | from transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
29 |
30 | from .modeling_common_test import (CommonTestCases, ids_tensor)
31 | from .configuration_common_test import ConfigTester
32 | from .utils import require_torch, slow, torch_device
33 |
34 |
35 | @require_torch
36 | class TransfoXLModelTest(CommonTestCases.CommonModelTester):
37 |
38 | all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
39 | test_pruning = False
40 | test_torchscript = False
41 | test_resize_embeddings = False
42 |
43 | class TransfoXLModelTester(object):
44 |
45 | def __init__(self,
46 | parent,
47 | batch_size=13,
48 | seq_length=7,
49 | mem_len=30,
50 | clamp_len=15,
51 | is_training=True,
52 | use_labels=True,
53 | vocab_size=99,
54 | cutoffs=[10, 50, 80],
55 | hidden_size=32,
56 | d_embed=32,
57 | num_attention_heads=4,
58 | d_head=8,
59 | d_inner=128,
60 | div_val=2,
61 | num_hidden_layers=5,
62 | scope=None,
63 | seed=1,
64 | ):
65 | self.parent = parent
66 | self.batch_size = batch_size
67 | self.seq_length = seq_length
68 | self.mem_len = mem_len
69 | self.key_len = seq_length + mem_len
70 | self.clamp_len = clamp_len
71 | self.is_training = is_training
72 | self.use_labels = use_labels
73 | self.vocab_size = vocab_size
74 | self.cutoffs = cutoffs
75 | self.hidden_size = hidden_size
76 | self.d_embed = d_embed
77 | self.num_attention_heads = num_attention_heads
78 | self.d_head = d_head
79 | self.d_inner = d_inner
80 | self.div_val = div_val
81 | self.num_hidden_layers = num_hidden_layers
82 | self.scope = scope
83 | self.seed = seed
84 |
85 | def prepare_config_and_inputs(self):
86 | input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
87 | input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
88 |
89 | lm_labels = None
90 | if self.use_labels:
91 | lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
92 |
93 | config = TransfoXLConfig(
94 | vocab_size_or_config_json_file=self.vocab_size,
95 | mem_len=self.mem_len,
96 | clamp_len=self.clamp_len,
97 | cutoffs=self.cutoffs,
98 | d_model=self.hidden_size,
99 | d_embed=self.d_embed,
100 | n_head=self.num_attention_heads,
101 | d_head=self.d_head,
102 | d_inner=self.d_inner,
103 | div_val=self.div_val,
104 | n_layer=self.num_hidden_layers)
105 |
106 | return (config, input_ids_1, input_ids_2, lm_labels)
107 |
108 | def set_seed(self):
109 | random.seed(self.seed)
110 | torch.manual_seed(self.seed)
111 |
112 | def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels):
113 | model = TransfoXLModel(config)
114 | model.to(torch_device)
115 | model.eval()
116 |
117 | hidden_states_1, mems_1 = model(input_ids_1)
118 | hidden_states_2, mems_2 = model(input_ids_2, mems_1)
119 | outputs = {
120 | "hidden_states_1": hidden_states_1,
121 | "mems_1": mems_1,
122 | "hidden_states_2": hidden_states_2,
123 | "mems_2": mems_2,
124 | }
125 | return outputs
126 |
127 | def check_transfo_xl_model_output(self, result):
128 | self.parent.assertListEqual(
129 | list(result["hidden_states_1"].size()),
130 | [self.batch_size, self.seq_length, self.hidden_size])
131 | self.parent.assertListEqual(
132 | list(result["hidden_states_2"].size()),
133 | [self.batch_size, self.seq_length, self.hidden_size])
134 | self.parent.assertListEqual(
135 | list(list(mem.size()) for mem in result["mems_1"]),
136 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
137 | self.parent.assertListEqual(
138 | list(list(mem.size()) for mem in result["mems_2"]),
139 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
140 |
141 |
142 | def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
143 | model = TransfoXLLMHeadModel(config)
144 | model.to(torch_device)
145 | model.eval()
146 |
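   | # without labels the model returns (lm_logits, mems); with labels it returns
   | # (loss, lm_logits, mems), where the loss is per-token, shaped [batch, seq_len]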
147 | lm_logits_1, mems_1 = model(input_ids_1)
148 | loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)
149 | lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
150 | loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)
151 |
152 | outputs = {
153 | "loss_1": loss_1,
154 | "mems_1": mems_1,
155 | "lm_logits_1": lm_logits_1,
156 | "loss_2": loss_2,
157 | "mems_2": mems_2,
158 | "lm_logits_2": lm_logits_2,
159 | }
160 | return outputs
161 |
162 | def check_transfo_xl_lm_head_output(self, result):
163 | self.parent.assertListEqual(
164 | list(result["loss_1"].size()),
165 | [self.batch_size, self.seq_length])
166 | self.parent.assertListEqual(
167 | list(result["lm_logits_1"].size()),
168 | [self.batch_size, self.seq_length, self.vocab_size])
169 | self.parent.assertListEqual(
170 | list(list(mem.size()) for mem in result["mems_1"]),
171 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
172 |
173 | self.parent.assertListEqual(
174 | list(result["loss_2"].size()),
175 | [self.batch_size, self.seq_length])
176 | self.parent.assertListEqual(
177 | list(result["lm_logits_2"].size()),
178 | [self.batch_size, self.seq_length, self.vocab_size])
179 | self.parent.assertListEqual(
180 | list(list(mem.size()) for mem in result["mems_2"]),
181 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
182 |
183 | def prepare_config_and_inputs_for_common(self):
184 | config_and_inputs = self.prepare_config_and_inputs()
185 | (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
186 | inputs_dict = {'input_ids': input_ids_1}
187 | return config, inputs_dict
188 |
189 |
190 | def setUp(self):
191 | self.model_tester = TransfoXLModelTest.TransfoXLModelTester(self)
192 | self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
193 |
194 | def test_config(self):
195 | self.config_tester.run_common_tests()
196 |
197 | def test_transfo_xl_model(self):
198 | self.model_tester.set_seed()
199 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
200 | output_result = self.model_tester.create_transfo_xl_model(*config_and_inputs)
201 | self.model_tester.check_transfo_xl_model_output(output_result)
202 |
203 | def test_transfo_xl_lm_head(self):
204 | self.model_tester.set_seed()
205 | config_and_inputs = self.model_tester.prepare_config_and_inputs()
206 | output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
207 | self.model_tester.check_transfo_xl_lm_head_output(output_result)
208 |
209 | @slow
210 | def test_model_from_pretrained(self):
211 | cache_dir = "/tmp/transformers_test/"
212 | for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
213 | model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
214 | shutil.rmtree(cache_dir)
215 | self.assertIsNotNone(model)
216 |
217 |
218 | if __name__ == "__main__":
219 | unittest.main()
220 |
--------------------------------------------------------------------------------
/transformers/tests/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import os
21 |
22 | from transformers import is_torch_available
23 |
24 | if is_torch_available():
25 | import torch
26 |
27 | from transformers import (AdamW,
28 | get_constant_schedule,
29 | get_constant_schedule_with_warmup,
30 | get_cosine_schedule_with_warmup,
31 | get_cosine_with_hard_restarts_schedule_with_warmup,
32 | get_linear_schedule_with_warmup)
33 |
34 | from .tokenization_tests_commons import TemporaryDirectory
35 | from .utils import require_torch
36 |
37 |
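   | # step the scheduler num_steps times and record get_lr() after every step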
38 | def unwrap_schedule(scheduler, num_steps=10):
39 | lrs = []
40 | for _ in range(num_steps):
41 | scheduler.step()
42 | lrs.append(scheduler.get_lr())
43 | return lrs
44 |
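   | # same as unwrap_schedule, but round-trips the scheduler through
   | # state_dict()/load_state_dict() at the halfway point, so the recorded rates
   | # also verify that checkpointing does not perturb the schedule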
45 | def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
46 | lrs = []
47 | for step in range(num_steps):
48 | scheduler.step()
49 | lrs.append(scheduler.get_lr())
50 | if step == num_steps // 2:
51 | with TemporaryDirectory() as tmpdirname:
52 | file_name = os.path.join(tmpdirname, 'schedule.bin')
53 | torch.save(scheduler.state_dict(), file_name)
54 |
55 | state_dict = torch.load(file_name)
56 | scheduler.load_state_dict(state_dict)
57 | return lrs
58 |
59 | @require_torch
60 | class OptimizationTest(unittest.TestCase):
61 |
62 | def assertListAlmostEqual(self, list1, list2, tol):
63 | self.assertEqual(len(list1), len(list2))
64 | for a, b in zip(list1, list2):
65 | self.assertAlmostEqual(a, b, delta=tol)
66 |
67 | def test_adam_w(self):
68 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
69 | target = torch.tensor([0.4, 0.2, -0.5])
70 | criterion = torch.nn.MSELoss()
71 | # No warmup, constant schedule, no gradient clipping
72 | optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
73 | for _ in range(100):
74 | loss = criterion(w, target)
75 | loss.backward()
76 | optimizer.step()
77 | w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves.
78 | w.grad.zero_()
79 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
80 |
81 |
82 | @require_torch
83 | class ScheduleInitTest(unittest.TestCase):
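   | # the module/optimizer pair is built once at class level and shared by every
   | # test; lr=10. makes the scheduled rates easy to read off in the expected lists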
84 | m = torch.nn.Linear(50, 50) if is_torch_available() else None
85 | optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None
86 | num_steps = 10
87 |
88 | def assertListAlmostEqual(self, list1, list2, tol):
89 | self.assertEqual(len(list1), len(list2))
90 | for a, b in zip(list1, list2):
91 | self.assertAlmostEqual(a, b, delta=tol)
92 |
93 | def test_constant_scheduler(self):
94 | scheduler = get_constant_schedule(self.optimizer)
95 | lrs = unwrap_schedule(scheduler, self.num_steps)
96 | expected_learning_rates = [10.] * self.num_steps
97 | self.assertEqual(len(lrs[0]), 1)
98 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
99 |
100 | scheduler = get_constant_schedule(self.optimizer)
101 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
102 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
103 |
104 | def test_warmup_constant_scheduler(self):
105 | scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
106 | lrs = unwrap_schedule(scheduler, self.num_steps)
107 | expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
108 | self.assertEqual(len(lrs[0]), 1)
109 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
110 |
111 | scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
112 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
113 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
114 |
115 | def test_warmup_linear_scheduler(self):
116 | scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
117 | lrs = unwrap_schedule(scheduler, self.num_steps)
118 | expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
119 | self.assertEqual(len(lrs[0]), 1)
120 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
121 |
122 | scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
123 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
124 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
125 |
126 | def test_warmup_cosine_scheduler(self):
127 | scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
128 | lrs = unwrap_schedule(scheduler, self.num_steps)
129 | expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
130 | self.assertEqual(len(lrs[0]), 1)
131 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
132 |
133 | scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
134 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
135 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
136 |
137 | def test_warmup_cosine_hard_restart_scheduler(self):
138 | scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10)
139 | lrs = unwrap_schedule(scheduler, self.num_steps)
140 | expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
141 | self.assertEqual(len(lrs[0]), 1)
142 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
143 |
144 | scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10)
145 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
146 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
147 |
148 |
149 | if __name__ == "__main__":
150 | unittest.main()
151 |
--------------------------------------------------------------------------------
/transformers/tests/optimization_tf_test.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import unittest
6 |
7 | from transformers import is_tf_available
8 |
9 | from .utils import require_tf
10 |
11 | if is_tf_available():
12 | import tensorflow as tf
13 | from tensorflow.python.eager import context
14 | from tensorflow.python.framework import ops
15 | from transformers import (create_optimizer, GradientAccumulator)
16 |
17 |
18 | @require_tf
19 | class OptimizationTFTest(unittest.TestCase):
20 | def assertListAlmostEqual(self, list1, list2, tol):
21 | self.assertEqual(len(list1), len(list2))
22 | for a, b in zip(list1, list2):
23 | self.assertAlmostEqual(a, b, delta=tol)
24 |
25 | def testGradientAccumulator(self):
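   | # gradients accumulate element-wise across calls: [1,2] + [-2,1] + [-1,2] = [-2,5];
   | # a call with a mismatched number of gradients raises and is not counted as a step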
26 | accumulator = GradientAccumulator()
27 | accumulator([tf.constant([1.0, 2.0])])
28 | accumulator([tf.constant([-2.0, 1.0])])
29 | accumulator([tf.constant([-1.0, 2.0])])
30 | with self.assertRaises(ValueError):
31 | accumulator([tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])])
32 | self.assertEqual(accumulator.step, 3)
33 | self.assertEqual(len(accumulator.gradients), 1)
34 | self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [-2.0, 5.0], tol=1e-2)
35 | accumulator.reset()
36 | self.assertEqual(accumulator.step, 0)
37 | self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [0.0, 0.0], tol=1e-2)
38 |
39 | def testGradientAccumulatorDistributionStrategy(self):
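   | # virtual devices can only be configured before the TF runtime initializes,
   | # so reset the eager context first (via TF-internal APIs) and carve two
   | # logical CPUs out of the physical one for MirroredStrategy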
40 | context._context = None
41 | ops.enable_eager_execution_internal()
42 | physical_devices = tf.config.experimental.list_physical_devices("CPU")
43 | tf.config.experimental.set_virtual_device_configuration(
44 | physical_devices[0],
45 | [tf.config.experimental.VirtualDeviceConfiguration(),
46 | tf.config.experimental.VirtualDeviceConfiguration()])
47 |
48 | devices = tf.config.experimental.list_logical_devices(device_type="CPU")
49 | strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices])
50 |
51 | with strategy.scope():
52 | accumulator = GradientAccumulator()
53 | variable = tf.Variable([4.0, 3.0])
54 | optimizer = create_optimizer(5e-5, 10, 5)
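   | # args: initial learning rate, total training steps, warmup steps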
55 | gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False)
56 |
57 | def accumulate_on_replica(gradient):
58 | accumulator([gradient])
59 |
60 | def apply_on_replica():
61 | optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])), 1.0)
62 |
63 | @tf.function
64 | def accumulate(grad1, grad2):
65 | with strategy.scope():
66 | gradient_placeholder.values[0].assign(grad1)
67 | gradient_placeholder.values[1].assign(grad2)
68 | strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,))
69 |
70 | @tf.function
71 | def apply_grad():
72 | with strategy.scope():
73 | strategy.experimental_run_v2(apply_on_replica)
74 |
75 | accumulate([1.0, 2.0], [-1.0, 1.0])
76 | accumulate([3.0, -1.0], [-1.0, -1.0])
77 | accumulate([-2.0, 2.0], [3.0, -2.0])
78 | self.assertEqual(accumulator.step, 3)
79 | self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [2.0, 3.0], tol=1e-2)
80 | self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [1.0, -2.0], tol=1e-2)
81 | apply_grad()
82 | self.assertListAlmostEqual(variable.value().numpy().tolist(), [4.0, 3.0], tol=1e-2)
83 | accumulator.reset()
84 | self.assertEqual(accumulator.step, 0)
85 | self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
86 | self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
87 |
88 |
89 | if __name__ == "__main__":
90 | unittest.main()
--------------------------------------------------------------------------------
/transformers/tests/tokenization_albert_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Hugging Face inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 |
20 | from transformers.tokenization_albert import (AlbertTokenizer, SPIECE_UNDERLINE)
21 |
22 | from .tokenization_tests_commons import CommonTestCases
23 |
24 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
25 | 'fixtures/spiece.model')
26 |
27 | class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester):
28 |
29 | tokenizer_class = AlbertTokenizer
30 |
31 | def setUp(self):
32 | super(AlbertTokenizationTest, self).setUp()
33 |
34 | # We have a SentencePiece fixture for testing
35 | tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
36 | tokenizer.save_pretrained(self.tmpdirname)
37 |
38 | def get_tokenizer(self, **kwargs):
39 | return AlbertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
40 |
41 | def get_input_output_texts(self):
42 | input_text = u"this is a test"
43 | output_text = u"this is a test"
44 | return input_text, output_text
45 |
46 |
47 | def test_full_tokenizer(self):
48 | tokenizer = AlbertTokenizer(SAMPLE_VOCAB, keep_accents=True)
49 |
50 | tokens = tokenizer.tokenize(u'This is a test')
51 | self.assertListEqual(tokens, [u'▁this', u'▁is', u'▁a', u'▁test'])
52 |
53 | self.assertListEqual(
54 | tokenizer.convert_tokens_to_ids(tokens), [48, 25, 21, 1289])
55 |
56 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
57 | self.assertListEqual(tokens, [u'▁i', u'▁was', u'▁born', u'▁in', u'▁9', u'2000', u',', u'▁and', u'▁this', u'▁is', u'▁fal', u's', u'é', u'.'])
58 | ids = tokenizer.convert_tokens_to_ids(tokens)
59 | self.assertListEqual(ids, [31, 23, 386, 19, 561, 3050, 15, 17, 48, 25, 8256, 18, 1, 9])
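   | # id 1 is the SentencePiece unknown token: 'é' is outside the fixture vocab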
60 |
61 | back_tokens = tokenizer.convert_ids_to_tokens(ids)
62 |         self.assertListEqual(back_tokens, ['▁i', '▁was', '▁born', '▁in', '▁9', '2000', ',', '▁and', '▁this', '▁is', '▁fal', 's', '<unk>', '.'])
63 |
64 | def test_sequence_builders(self):
65 | tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
66 |
67 | text = tokenizer.encode("sequence builders")
68 | text_2 = tokenizer.encode("multi-sequence build")
69 |
70 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
71 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
72 |
73 | assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id]
74 | assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + text_2 + [tokenizer.sep_token_id]
75 |
76 |
77 | if __name__ == '__main__':
78 | unittest.main()
79 |
--------------------------------------------------------------------------------
/transformers/tests/tokenization_auto_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import shutil
21 | import logging
22 |
23 | from transformers import AutoTokenizer, BertTokenizer, GPT2Tokenizer
24 | from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
25 |
26 | from .utils import slow, SMALL_MODEL_IDENTIFIER
27 |
28 |
29 | class AutoTokenizerTest(unittest.TestCase):
30 | @slow
31 | def test_tokenizer_from_pretrained(self):
32 | logging.basicConfig(level=logging.INFO)
33 | for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
34 | tokenizer = AutoTokenizer.from_pretrained(model_name)
35 | self.assertIsNotNone(tokenizer)
36 | self.assertIsInstance(tokenizer, BertTokenizer)
37 | self.assertGreater(len(tokenizer), 0)
38 |
39 | for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
40 | tokenizer = AutoTokenizer.from_pretrained(model_name)
41 | self.assertIsNotNone(tokenizer)
42 | self.assertIsInstance(tokenizer, GPT2Tokenizer)
43 | self.assertGreater(len(tokenizer), 0)
44 |
45 | def test_tokenizer_from_pretrained_identifier(self):
46 | logging.basicConfig(level=logging.INFO)
47 | tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
48 | self.assertIsInstance(tokenizer, BertTokenizer)
49 | self.assertEqual(len(tokenizer), 12)
50 |
51 | if __name__ == "__main__":
52 | unittest.main()
53 |
--------------------------------------------------------------------------------
/transformers/tests/tokenization_bert_japanese_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | from io import open
20 |
21 | from transformers.tokenization_bert import WordpieceTokenizer
22 | from transformers.tokenization_bert_japanese import (BertJapaneseTokenizer,
23 | MecabTokenizer, CharacterTokenizer,
24 | VOCAB_FILES_NAMES)
25 |
26 | from .tokenization_tests_commons import CommonTestCases
27 | from .utils import slow, custom_tokenizers
28 |
29 |
30 | @custom_tokenizers
31 | class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester):
32 |
33 | tokenizer_class = BertJapaneseTokenizer
34 |
35 | def setUp(self):
36 | super(BertJapaneseTokenizationTest, self).setUp()
37 |
38 | vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]",
39 | u"こんにちは", u"こん", u"にちは", u"ばんは", u"##こん", u"##にちは", u"##ばんは",
40 | u"世界", u"##世界", u"、", u"##、", u"。", u"##。"]
41 |
42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
43 | with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
44 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
45 |
46 | def get_tokenizer(self, **kwargs):
47 | return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs)
48 |
49 | def get_input_output_texts(self):
50 | input_text = u"こんにちは、世界。 \nこんばんは、世界。"
51 | output_text = u"こんにちは 、 世界 。 こんばんは 、 世界 。"
52 | return input_text, output_text
53 |
54 | def test_full_tokenizer(self):
55 | tokenizer = self.tokenizer_class(self.vocab_file)
56 |
57 | tokens = tokenizer.tokenize(u"こんにちは、世界。\nこんばんは、世界。")
58 | self.assertListEqual(tokens,
59 | [u"こんにちは", u"、", u"世界", u"。",
60 | u"こん", u"##ばんは", u"、", u"世界", "。"])
61 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens),
62 | [3, 12, 10, 14, 4, 9, 12, 10, 14])
63 |
64 | def test_mecab_tokenizer(self):
65 | tokenizer = MecabTokenizer()
66 |
67 | self.assertListEqual(
68 | tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "),
69 | [u"アップルストア", u"で", u"iPhone", u"8", u"が",
70 | u"発売", u"さ", u"れ", u"た", u"。"])
71 |
72 | def test_mecab_tokenizer_lower(self):
73 | tokenizer = MecabTokenizer(do_lower_case=True)
74 |
75 | self.assertListEqual(
76 | tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "),
77 | [u"アップルストア", u"で", u"iphone", u"8", u"が",
78 | u"発売", u"さ", u"れ", u"た", u"。"])
79 |
80 | def test_mecab_tokenizer_no_normalize(self):
81 | tokenizer = MecabTokenizer(normalize_text=False)
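   | # without normalization the full-width space in the raw text survives as its own token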
82 |
83 | self.assertListEqual(
84 | tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "),
85 | [u"アップルストア", u"で", u"iPhone", u"8", u"が",
86 | u"発売", u"さ", u"れ", u"た", u" ", u"。"])
87 |
88 | def test_wordpiece_tokenizer(self):
89 | vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]",
90 | u"こんにちは", u"こん", u"にちは" u"ばんは", u"##こん", u"##にちは", u"##ばんは"]
91 |
92 | vocab = {}
93 | for (i, token) in enumerate(vocab_tokens):
94 | vocab[token] = i
95 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token=u"[UNK]")
96 |
97 | self.assertListEqual(tokenizer.tokenize(u""), [])
98 |
99 | self.assertListEqual(tokenizer.tokenize(u"こんにちは"),
100 | [u"こんにちは"])
101 |
102 | self.assertListEqual(tokenizer.tokenize(u"こんばんは"),
103 | [u"こん", u"##ばんは"])
104 |
105 | self.assertListEqual(tokenizer.tokenize(u"こんばんは こんばんにちは こんにちは"),
106 | [u"こん", u"##ばんは", u"[UNK]", u"こんにちは"])
107 |
108 | @slow
109 | def test_sequence_builders(self):
110 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-japanese")
111 |
112 | text = tokenizer.encode(u"ありがとう。", add_special_tokens=False)
113 | text_2 = tokenizer.encode(u"どういたしまして。", add_special_tokens=False)
114 |
115 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
116 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
117 |
118 | # 2 is for "[CLS]", 3 is for "[SEP]"
119 | assert encoded_sentence == [2] + text + [3]
120 | assert encoded_pair == [2] + text + [3] + text_2 + [3]
121 |
122 |
123 | class BertJapaneseCharacterTokenizationTest(CommonTestCases.CommonTokenizerTester):
124 |
125 | tokenizer_class = BertJapaneseTokenizer
126 |
127 | def setUp(self):
128 | super(BertJapaneseCharacterTokenizationTest, self).setUp()
129 |
130 | vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]",
131 | u"こ", u"ん", u"に", u"ち", u"は", u"ば", u"世", u"界", u"、", u"。"]
132 |
133 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
134 | with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
135 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
136 |
137 | def get_tokenizer(self, **kwargs):
138 | return BertJapaneseTokenizer.from_pretrained(self.tmpdirname,
139 | subword_tokenizer_type="character",
140 | **kwargs)
141 |
142 | def get_input_output_texts(self):
143 | input_text = u"こんにちは、世界。 \nこんばんは、世界。"
144 | output_text = u"こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。"
145 | return input_text, output_text
146 |
147 | def test_full_tokenizer(self):
148 | tokenizer = self.tokenizer_class(self.vocab_file,
149 | subword_tokenizer_type="character")
150 |
151 | tokens = tokenizer.tokenize(u"こんにちは、世界。 \nこんばんは、世界。")
152 | self.assertListEqual(tokens,
153 | [u"こ", u"ん", u"に", u"ち", u"は", u"、", u"世", u"界", u"。",
154 | u"こ", u"ん", u"ば", u"ん", u"は", u"、", u"世", u"界", u"。"])
155 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens),
156 | [3, 4, 5, 6, 7, 11, 9, 10, 12,
157 | 3, 4, 8, 4, 7, 11, 9, 10, 12])
158 |
159 | def test_character_tokenizer(self):
160 | vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]",
161 | u"こ", u"ん", u"に", u"ち", u"は", u"ば", u"世", u"界"u"、", u"。"]
162 |
163 | vocab = {}
164 | for (i, token) in enumerate(vocab_tokens):
165 | vocab[token] = i
166 | tokenizer = CharacterTokenizer(vocab=vocab, unk_token=u"[UNK]")
167 |
168 | self.assertListEqual(tokenizer.tokenize(u""), [])
169 |
170 | self.assertListEqual(tokenizer.tokenize(u"こんにちは"),
171 | [u"こ", u"ん", u"に", u"ち", u"は"])
172 |
173 | self.assertListEqual(tokenizer.tokenize(u"こんにちほ"),
174 | [u"こ", u"ん", u"に", u"ち", u"[UNK]"])
175 |
176 | @slow
177 | def test_sequence_builders(self):
178 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-japanese-char")
179 |
180 | text = tokenizer.encode(u"ありがとう。", add_special_tokens=False)
181 | text_2 = tokenizer.encode(u"どういたしまして。", add_special_tokens=False)
182 |
183 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
184 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
185 |
186 | # 2 is for "[CLS]", 3 is for "[SEP]"
187 | assert encoded_sentence == [2] + text + [3]
188 | assert encoded_pair == [2] + text + [3] + text_2 + [3]
189 |
190 |
191 | if __name__ == "__main__":
192 |     unittest.main()
--------------------------------------------------------------------------------
/transformers/tests/tokenization_bert_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | from io import open
20 |
21 | from transformers.tokenization_bert import (BasicTokenizer,
22 | BertTokenizer,
23 | WordpieceTokenizer,
24 | _is_control, _is_punctuation,
25 | _is_whitespace, VOCAB_FILES_NAMES)
26 |
27 | from .tokenization_tests_commons import CommonTestCases
28 | from .utils import slow
29 |
30 | class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
31 |
32 | tokenizer_class = BertTokenizer
33 |
34 | def setUp(self):
35 | super(BertTokenizationTest, self).setUp()
36 |
37 | vocab_tokens = [
38 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
39 | "##ing", ",", "low", "lowest",
40 | ]
41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
42 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
43 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
44 |
45 | def get_tokenizer(self, **kwargs):
46 | return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
47 |
48 | def get_input_output_texts(self):
49 | input_text = u"UNwant\u00E9d,running"
50 | output_text = u"unwanted, running"
51 | return input_text, output_text
52 |
53 | def test_full_tokenizer(self):
54 | tokenizer = self.tokenizer_class(self.vocab_file)
55 |
56 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
57 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
58 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
59 |
60 | def test_chinese(self):
61 | tokenizer = BasicTokenizer()
62 |
63 | self.assertListEqual(
64 | tokenizer.tokenize(u"ah\u535A\u63A8zz"),
65 | [u"ah", u"\u535A", u"\u63A8", u"zz"])
66 |
67 | def test_basic_tokenizer_lower(self):
68 | tokenizer = BasicTokenizer(do_lower_case=True)
69 |
70 | self.assertListEqual(
71 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
72 | ["hello", "!", "how", "are", "you", "?"])
73 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
74 |
75 | def test_basic_tokenizer_no_lower(self):
76 | tokenizer = BasicTokenizer(do_lower_case=False)
77 |
78 | self.assertListEqual(
79 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
80 | ["HeLLo", "!", "how", "Are", "yoU", "?"])
81 |
82 | def test_wordpiece_tokenizer(self):
83 | vocab_tokens = [
84 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
85 | "##ing"
86 | ]
87 |
88 | vocab = {}
89 | for (i, token) in enumerate(vocab_tokens):
90 | vocab[token] = i
91 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
92 |
93 | self.assertListEqual(tokenizer.tokenize(""), [])
94 |
95 | self.assertListEqual(
96 | tokenizer.tokenize("unwanted running"),
97 | ["un", "##want", "##ed", "runn", "##ing"])
98 |
99 | self.assertListEqual(
100 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
101 |
102 | def test_is_whitespace(self):
103 | self.assertTrue(_is_whitespace(u" "))
104 | self.assertTrue(_is_whitespace(u"\t"))
105 | self.assertTrue(_is_whitespace(u"\r"))
106 | self.assertTrue(_is_whitespace(u"\n"))
107 | self.assertTrue(_is_whitespace(u"\u00A0"))
108 |
109 | self.assertFalse(_is_whitespace(u"A"))
110 | self.assertFalse(_is_whitespace(u"-"))
111 |
112 | def test_is_control(self):
113 | self.assertTrue(_is_control(u"\u0005"))
114 |
115 | self.assertFalse(_is_control(u"A"))
116 | self.assertFalse(_is_control(u" "))
117 | self.assertFalse(_is_control(u"\t"))
118 | self.assertFalse(_is_control(u"\r"))
119 |
120 | def test_is_punctuation(self):
121 | self.assertTrue(_is_punctuation(u"-"))
122 | self.assertTrue(_is_punctuation(u"$"))
123 | self.assertTrue(_is_punctuation(u"`"))
124 | self.assertTrue(_is_punctuation(u"."))
125 |
126 | self.assertFalse(_is_punctuation(u"A"))
127 | self.assertFalse(_is_punctuation(u" "))
128 |
129 | @slow
130 | def test_sequence_builders(self):
131 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
132 |
133 | text = tokenizer.encode("sequence builders", add_special_tokens=False)
134 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
135 |
136 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
137 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
138 |
139 | assert encoded_sentence == [101] + text + [102]
140 | assert encoded_pair == [101] + text + [102] + text_2 + [102]
141 |
142 |
143 | if __name__ == '__main__':
144 | unittest.main()
145 |
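146 | # A minimal sketch of the greedy longest-match-first algorithm the
147 | # WordpieceTokenizer tests above exercise (sketch only, not the library
148 | # code): for each word, take the longest vocab entry that prefixes the
149 | # remainder, using a "##" continuation prefix after the first piece; if no
150 | # prefix matches at some position, the whole word becomes the unknown token.
151 | def _wordpiece_sketch(word, vocab, unk_token="[UNK]"):
152 |     pieces, start = [], 0
153 |     while start < len(word):
154 |         end = len(word)
155 |         while end > start:
156 |             piece = ("##" if start > 0 else "") + word[start:end]
157 |             if piece in vocab:
158 |                 break
159 |             end -= 1
160 |         if end == start:
161 |             return [unk_token]  # no vocab entry matched, e.g. "unwantedX"
162 |         pieces.append(piece)
163 |         start = end
164 |     return pieces
165 | # _wordpiece_sketch("unwanted", vocab) -> ["un", "##want", "##ed"], as asserted above.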
--------------------------------------------------------------------------------
/transformers/tests/tokenization_ctrl_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Salesforce and HuggingFace Inc. team.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import absolute_import, division, print_function, unicode_literals
15 |
16 | import os
17 | import unittest
18 | import json
19 | from io import open
20 |
21 | from transformers.tokenization_ctrl import CTRLTokenizer, VOCAB_FILES_NAMES
22 |
23 | from .tokenization_tests_commons import CommonTestCases
24 |
25 | class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
26 |
27 | tokenizer_class = CTRLTokenizer
28 |
29 | def setUp(self):
30 | super(CTRLTokenizationTest, self).setUp()
31 |
32 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
33 | vocab = ['adapt', 're@@', 'a@@', 'apt', 'c@@', 't', '<unk>']
34 | vocab_tokens = dict(zip(vocab, range(len(vocab))))
35 | merges = ["#version: 0.2", 'a p', 'ap t', 'r e', 'a d', 'ad apt', '']
36 | self.special_tokens_map = {"unk_token": "<unk>"}
37 |
38 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
39 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
40 | with open(self.vocab_file, "w", encoding="utf-8") as fp:
41 | fp.write(json.dumps(vocab_tokens) + "\n")
42 | with open(self.merges_file, "w", encoding="utf-8") as fp:
43 | fp.write("\n".join(merges))
44 |
45 | def get_tokenizer(self, **kwargs):
46 | kwargs.update(self.special_tokens_map)
47 | return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
48 |
49 | def get_input_output_texts(self):
50 | input_text = u"adapt react readapt apt"
51 | output_text = u"adapt react readapt apt"
52 | return input_text, output_text
53 |
54 | def test_full_tokenizer(self):
55 | tokenizer = CTRLTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
56 | text = "adapt react readapt apt"
57 | bpe_tokens = 'adapt re@@ a@@ c@@ t re@@ adapt apt'.split()
58 | tokens = tokenizer.tokenize(text)
59 | self.assertListEqual(tokens, bpe_tokens)
60 |
61 | input_tokens = tokens + [tokenizer.unk_token]
62 |
63 | input_bpe_tokens = [0, 1, 2, 4, 5, 1, 0, 3, 6]
64 | self.assertListEqual(
65 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
66 |
67 |
68 | if __name__ == '__main__':
69 | unittest.main()
70 |
--------------------------------------------------------------------------------
/transformers/tests/tokenization_distilbert_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | from io import open
20 |
21 | from transformers.tokenization_distilbert import (DistilBertTokenizer)
22 |
23 | from .tokenization_tests_commons import CommonTestCases
24 | from .tokenization_bert_test import BertTokenizationTest
25 | from .utils import slow
26 |
27 | class DistilBertTokenizationTest(BertTokenizationTest):
28 |
29 | tokenizer_class = DistilBertTokenizer
30 |
31 | def get_tokenizer(self, **kwargs):
32 | return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
33 |
34 | @slow
35 | def test_sequence_builders(self):
36 | tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
37 |
38 | text = tokenizer.encode("sequence builders", add_special_tokens=False)
39 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
40 |
41 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
42 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
43 |
44 | assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id]
45 | assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + \
46 | text_2 + [tokenizer.sep_token_id]
47 |
48 |
49 | if __name__ == '__main__':
50 | unittest.main()
51 |
--------------------------------------------------------------------------------
/transformers/tests/tokenization_gpt2_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | import json
20 | from io import open
21 |
22 | from transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
23 |
24 | from .tokenization_tests_commons import CommonTestCases
25 |
26 | class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
27 |
28 | tokenizer_class = GPT2Tokenizer
29 |
30 | def setUp(self):
31 | super(GPT2TokenizationTest, self).setUp()
32 |
33 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 | "\u0120", "\u0120l", "\u0120n",
36 | "\u0120lo", "\u0120low", "er",
37 | "\u0120lowest", "\u0120newer", "\u0120wider", ""]
38 | vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
40 | self.special_tokens_map = {"unk_token": "<unk>"}
41 |
42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
43 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
44 | with open(self.vocab_file, "w", encoding="utf-8") as fp:
45 | fp.write(json.dumps(vocab_tokens) + "\n")
46 | with open(self.merges_file, "w", encoding="utf-8") as fp:
47 | fp.write("\n".join(merges))
48 |
49 | def get_tokenizer(self, **kwargs):
50 | kwargs.update(self.special_tokens_map)
51 | return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
52 |
53 | def get_input_output_texts(self):
54 | input_text = u"lower newer"
55 | output_text = u"lower newer"
56 | return input_text, output_text
57 |
58 | def test_full_tokenizer(self):
59 | tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
60 | text = "lower newer"
61 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
62 | tokens = tokenizer.tokenize(text, add_prefix_space=True)
63 | self.assertListEqual(tokens, bpe_tokens)
64 |
65 | input_tokens = tokens + [tokenizer.unk_token]
66 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
67 | self.assertListEqual(
68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 |
70 | if __name__ == '__main__':
71 | unittest.main()
72 |
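73 | # A minimal sketch of the byte-pair-encoding loop the test above exercises
74 | # (sketch only, not the library code): start from single symbols and keep
75 | # merging the adjacent pair with the lowest rank in the merges table.
76 | # "\u0120" is GPT-2's printable marker for a leading space.
77 | def _bpe_sketch(word, merge_ranks):
78 |     symbols = list(word)
79 |     while len(symbols) > 1:
80 |         # Rank every adjacent pair that appears in the merges table.
81 |         ranked = [(merge_ranks[(a, b)], i)
82 |                   for i, (a, b) in enumerate(zip(symbols, symbols[1:]))
83 |                   if (a, b) in merge_ranks]
84 |         if not ranked:
85 |             break
86 |         _, i = min(ranked)
87 |         symbols[i:i + 2] = ["".join(symbols[i:i + 2])]
88 |     return symbols
89 | # With ranks {("\u0120", "l"): 0, ("\u0120l", "o"): 1, ("\u0120lo", "w"): 2,
90 | # ("e", "r"): 3}, _bpe_sketch("\u0120lower", ranks) -> ["\u0120low", "er"].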
--------------------------------------------------------------------------------
/transformers/tests/tokenization_openai_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | import json
20 |
21 | from transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
22 |
23 | from .tokenization_tests_commons import CommonTestCases
24 |
25 |
26 | class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
27 |
28 | tokenizer_class = OpenAIGPTTokenizer
29 |
30 | def setUp(self):
31 | super(OpenAIGPTTokenizationTest, self).setUp()
32 |
33 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 | "w", "r", "t",
36 | "lo", "low", "er",
37 | "low", "lowest", "newer", "wider", ""]
38 | vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 | merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
40 |
41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
43 | with open(self.vocab_file, "w") as fp:
44 | fp.write(json.dumps(vocab_tokens))
45 | with open(self.merges_file, "w") as fp:
46 | fp.write("\n".join(merges))
47 |
48 | def get_tokenizer(self, **kwargs):
49 | return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
50 |
51 | def get_input_output_texts(self):
52 | input_text = u"lower newer"
53 | output_text = u"lower newer"
54 | return input_text, output_text
55 |
56 |
57 | def test_full_tokenizer(self):
58 | tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
59 |
60 | text = "lower"
61 | bpe_tokens = ["low", "er"]
62 | tokens = tokenizer.tokenize(text)
63 | self.assertListEqual(tokens, bpe_tokens)
64 |
65 | input_tokens = tokens + ["<unk>"]
66 | input_bpe_tokens = [14, 15, 20]
67 | self.assertListEqual(
68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 |
70 |
71 | if __name__ == '__main__':
72 | unittest.main()
73 |
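74 | # Note on the fixture above: this BPE variant marks the end of a word with a
75 | # "</w>" suffix instead of marking leading spaces, so "lower" is seen as
76 | # ("l", "o", "w", "e", "r</w>") before merging and the final pieces are
77 | # ["low", "er</w>"]; decoding strips the marker back into a space.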
--------------------------------------------------------------------------------
/transformers/tests/tokenization_roberta_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import json
19 | import unittest
20 | from io import open
21 |
22 | from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
23 | from .tokenization_tests_commons import CommonTestCases
24 | from .utils import slow
25 |
26 |
27 | class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
28 | tokenizer_class = RobertaTokenizer
29 |
30 | def setUp(self):
31 | super(RobertaTokenizationTest, self).setUp()
32 |
33 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 | "\u0120", "\u0120l", "\u0120n",
36 | "\u0120lo", "\u0120low", "er",
37 | "\u0120lowest", "\u0120newer", "\u0120wider", ""]
38 | vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
40 | self.special_tokens_map = {"unk_token": "<unk>"}
41 |
42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
43 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
44 | with open(self.vocab_file, "w", encoding="utf-8") as fp:
45 | fp.write(json.dumps(vocab_tokens) + "\n")
46 | with open(self.merges_file, "w", encoding="utf-8") as fp:
47 | fp.write("\n".join(merges))
48 |
49 | def get_tokenizer(self, **kwargs):
50 | kwargs.update(self.special_tokens_map)
51 | return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
52 |
53 | def get_input_output_texts(self):
54 | input_text = u"lower newer"
55 | output_text = u"lower newer"
56 | return input_text, output_text
57 |
58 | def test_full_tokenizer(self):
59 | tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
60 | text = "lower newer"
61 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
62 | tokens = tokenizer.tokenize(text, add_prefix_space=True)
63 | self.assertListEqual(tokens, bpe_tokens)
64 |
65 | input_tokens = tokens + [tokenizer.unk_token]
66 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
67 | self.assertListEqual(
68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 |
70 | def roberta_dict_integration_testing(self):
71 | tokenizer = self.get_tokenizer()
72 |
73 | self.assertListEqual(
74 | tokenizer.encode('Hello world!', add_special_tokens=False),
75 | [0, 31414, 232, 328, 2]
76 | )
77 | self.assertListEqual(
78 | tokenizer.encode('Hello world! cécé herlolip 418', add_special_tokens=False),
79 | [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
80 | )
81 |
82 | @slow
83 | def test_sequence_builders(self):
84 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
85 |
86 | text = tokenizer.encode("sequence builders", add_special_tokens=False)
87 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
88 |
89 | encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
90 | encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
91 |
92 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
93 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
94 |
95 | assert encoded_sentence == encoded_text_from_decode
96 | assert encoded_pair == encoded_pair_from_decode
97 |
98 |
99 | if __name__ == '__main__':
100 | unittest.main()
101 |
--------------------------------------------------------------------------------
/transformers/tests/tokenization_transfo_xl_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | from io import open
20 |
21 | from transformers import is_torch_available
22 |
23 | if is_torch_available():
24 | import torch
25 | from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
26 |
27 | from .tokenization_tests_commons import CommonTestCases
28 | from .utils import require_torch
29 |
30 |
31 | @require_torch
32 | class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
33 |
34 | tokenizer_class = TransfoXLTokenizer if is_torch_available() else None
35 |
36 | def setUp(self):
37 | super(TransfoXLTokenizationTest, self).setUp()
38 |
39 | vocab_tokens = [
40 | "", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
41 | "running", ",", "low", "l",
42 | ]
43 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
44 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
45 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
46 |
47 | def get_tokenizer(self, **kwargs):
48 | kwargs['lower_case'] = True
49 | return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
50 |
51 | def get_input_output_texts(self):
52 | input_text = u" UNwanted , running"
53 | output_text = u" unwanted, running"
54 | return input_text, output_text
55 |
56 | def test_full_tokenizer(self):
57 | tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True)
58 |
59 | tokens = tokenizer.tokenize(u" UNwanted , running")
60 | self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
61 |
62 | self.assertListEqual(
63 | tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
64 |
65 | def test_full_tokenizer_lower(self):
66 | tokenizer = TransfoXLTokenizer(lower_case=True)
67 |
68 | self.assertListEqual(
69 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
70 | ["hello", "!", "how", "are", "you", "?"])
71 |
72 | def test_full_tokenizer_no_lower(self):
73 | tokenizer = TransfoXLTokenizer(lower_case=False)
74 |
75 | self.assertListEqual(
76 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
77 | ["HeLLo", "!", "how", "Are", "yoU", "?"])
78 |
79 |
80 | if __name__ == '__main__':
81 | unittest.main()
82 |
--------------------------------------------------------------------------------
/transformers/tests/tokenization_utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import unittest
20 | import six
21 |
22 | from transformers import PreTrainedTokenizer
23 | from transformers.tokenization_gpt2 import GPT2Tokenizer
24 |
25 | from .utils import slow
26 |
27 | class TokenizerUtilsTest(unittest.TestCase):
28 |
29 | def check_tokenizer_from_pretrained(self, tokenizer_class):
30 | s3_models = list(tokenizer_class.max_model_input_sizes.keys())
31 | for model_name in s3_models[:1]:
32 | tokenizer = tokenizer_class.from_pretrained(model_name)
33 | self.assertIsNotNone(tokenizer)
34 | self.assertIsInstance(tokenizer, tokenizer_class)
35 | self.assertIsInstance(tokenizer, PreTrainedTokenizer)
36 |
37 | for special_tok in tokenizer.all_special_tokens:
38 | if six.PY2:
39 | self.assertIsInstance(special_tok, unicode)
40 | else:
41 | self.assertIsInstance(special_tok, str)
42 | special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
43 | self.assertIsInstance(special_tok_id, int)
44 |
45 | @slow
46 | def test_pretrained_tokenizers(self):
47 | self.check_tokenizer_from_pretrained(GPT2Tokenizer)
48 |
49 | if __name__ == "__main__":
50 | unittest.main()
51 |
--------------------------------------------------------------------------------
/transformers/tests/tokenization_xlm_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 | import json
20 |
21 | from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
22 |
23 | from .tokenization_tests_commons import CommonTestCases
24 | from .utils import slow
25 |
26 | class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
27 |
28 | tokenizer_class = XLMTokenizer
29 |
30 | def setUp(self):
31 | super(XLMTokenizationTest, self).setUp()
32 |
33 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 | "w", "r", "t",
36 | "lo", "low", "er",
37 | "low", "lowest", "newer", "wider", ""]
38 | vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 | merges = ["l o 123", "lo w 1456", "e r 1789", ""]
40 |
41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
43 | with open(self.vocab_file, "w") as fp:
44 | fp.write(json.dumps(vocab_tokens))
45 | with open(self.merges_file, "w") as fp:
46 | fp.write("\n".join(merges))
47 |
48 | def get_tokenizer(self, **kwargs):
49 | return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
50 |
51 | def get_input_output_texts(self):
52 | input_text = u"lower newer"
53 | output_text = u"lower newer"
54 | return input_text, output_text
55 |
56 | def test_full_tokenizer(self):
57 | """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
58 | tokenizer = XLMTokenizer(self.vocab_file, self.merges_file)
59 |
60 | text = "lower"
61 | bpe_tokens = ["low", "er"]
62 | tokens = tokenizer.tokenize(text)
63 | self.assertListEqual(tokens, bpe_tokens)
64 |
65 | input_tokens = tokens + ["<unk>"]
66 | input_bpe_tokens = [14, 15, 20]
67 | self.assertListEqual(
68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 |
70 | @slow
71 | def test_sequence_builders(self):
72 | tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
73 |
74 | text = tokenizer.encode("sequence builders", add_special_tokens=False)
75 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
76 |
77 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
78 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
79 |
80 | assert encoded_sentence == [1] + text + [1]
81 | assert encoded_pair == [1] + text + [1] + text_2 + [1]
82 |
83 | if __name__ == '__main__':
84 | unittest.main()
85 |
--------------------------------------------------------------------------------
/transformers/tests/tokenization_xlnet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 |
17 | import os
18 | import unittest
19 |
20 | from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
21 |
22 | from .tokenization_tests_commons import CommonTestCases
23 | from .utils import slow
24 |
25 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
26 | 'fixtures/test_sentencepiece.model')
27 |
28 | class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
29 |
30 | tokenizer_class = XLNetTokenizer
31 |
32 | def setUp(self):
33 | super(XLNetTokenizationTest, self).setUp()
34 |
35 | # We have a SentencePiece fixture for testing
36 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
37 | tokenizer.save_pretrained(self.tmpdirname)
38 |
39 | def get_tokenizer(self, **kwargs):
40 | return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
41 |
42 | def get_input_output_texts(self):
43 | input_text = u"This is a test"
44 | output_text = u"This is a test"
45 | return input_text, output_text
46 |
47 |
48 | def test_full_tokenizer(self):
49 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
50 |
51 | tokens = tokenizer.tokenize(u'This is a test')
52 | self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
53 |
54 | self.assertListEqual(
55 | tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
56 |
57 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
58 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
59 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
60 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
61 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
62 | ids = tokenizer.convert_tokens_to_ids(tokens)
63 | self.assertListEqual(
64 | ids, [8, 21, 84, 55, 24, 19, 7, 0,
65 | 602, 347, 347, 347, 3, 12, 66,
66 | 46, 72, 80, 6, 0, 4])
67 |
68 | back_tokens = tokenizer.convert_ids_to_tokens(ids)
69 | self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
70 | u'or', u'n', SPIECE_UNDERLINE + u'in',
71 | SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
72 | SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
73 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
74 | u'<unk>', u'.'])
75 |
76 | def test_tokenizer_lower(self):
77 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
78 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
79 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
80 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
81 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
82 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
83 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"])
84 |
85 | def test_tokenizer_no_lower(self):
86 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
87 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
88 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or',
89 | u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
90 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
91 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
92 |
93 | @slow
94 | def test_sequence_builders(self):
95 | tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
96 |
97 | text = tokenizer.encode("sequence builders", add_special_tokens=False)
98 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
99 |
100 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
101 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
102 |
103 | assert encoded_sentence == text + [4, 3]
104 | assert encoded_pair == text + [4] + text_2 + [4, 3]
105 |
106 |
107 | if __name__ == '__main__':
108 | unittest.main()
109 |
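110 | # A minimal sketch of the SentencePiece decoding convention visible above
111 | # (sketch only, not the library code): pieces starting with SPIECE_UNDERLINE
112 | # (u"\u2581") begin a new word, so decoding is a join plus a replace.
113 | def _decode_sketch(pieces):
114 |     # Concatenate pieces, then turn each word-boundary marker into a space.
115 |     return u"".join(pieces).replace(u"\u2581", u" ").strip()
116 | # _decode_sketch([u"\u2581This", u"\u2581is", u"\u2581a", u"\u2581t", u"est"]) == u"This is a test"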
--------------------------------------------------------------------------------
/transformers/tests/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 |
4 | from distutils.util import strtobool
5 |
6 | from transformers.file_utils import _tf_available, _torch_available
7 |
8 |
9 | SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
10 |
11 |
12 | def parse_flag_from_env(key, default=False):
13 | try:
14 | value = os.environ[key]
15 | except KeyError:
16 | # KEY isn't set, default to `default`.
17 | _value = default
18 | else:
19 | # KEY is set, convert it to True or False.
20 | try:
21 | _value = strtobool(value)
22 | except ValueError:
23 | # More values are supported, but let's keep the message simple.
24 | raise ValueError("If set, {} must be yes or no.".format(key))
25 | return _value
26 |
27 | _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
28 | _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
29 |
30 |
31 | def slow(test_case):
32 | """
33 | Decorator marking a test as slow.
34 |
35 | Slow tests are skipped by default. Set the RUN_SLOW environment variable
36 | to a truthy value to run them.
37 |
38 | """
39 | if not _run_slow_tests:
40 | test_case = unittest.skip("test is slow")(test_case)
41 | return test_case
42 |
43 |
44 | def custom_tokenizers(test_case):
45 | """
46 | Decorator marking a test for a custom tokenizer.
47 |
48 | Custom tokenizers require additional dependencies, and are skipped
49 | by default. Set the RUN_CUSTOM_TOKENIZERS environment variable
50 | to a truthy value to run them.
51 | """
52 | if not _run_custom_tokenizers:
53 | test_case = unittest.skip("test of custom tokenizers")(test_case)
54 | return test_case
55 |
56 |
57 | def require_torch(test_case):
58 | """
59 | Decorator marking a test that requires PyTorch.
60 |
61 | These tests are skipped when PyTorch isn't installed.
62 |
63 | """
64 | if not _torch_available:
65 | test_case = unittest.skip("test requires PyTorch")(test_case)
66 | return test_case
67 |
68 |
69 | def require_tf(test_case):
70 | """
71 | Decorator marking a test that requires TensorFlow.
72 |
73 | These tests are skipped when TensorFlow isn't installed.
74 |
75 | """
76 | if not _tf_available:
77 | test_case = unittest.skip("test requires TensorFlow")(test_case)
78 | return test_case
79 |
80 |
81 | if _torch_available:
82 | # Set the USE_CUDA environment variable to select a GPU.
83 | torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
84 | else:
85 | torch_device = None
86 |
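87 | # Usage (assumed invocation; adjust the path to your checkout): the decorators
88 | # above skip tests unless the matching environment variable is truthy, e.g.
89 | #
90 | #   RUN_SLOW=yes python -m pytest -sv ./transformers/tests/
91 | #   RUN_CUSTOM_TOKENIZERS=yes python -m pytest -sv ./transformers/tests/
92 | #   USE_CUDA=yes RUN_SLOW=yes python -m pytest -sv ./transformers/tests/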
--------------------------------------------------------------------------------