├── LICENSE.txt
├── README.md
├── dat
│   ├── friendsqa_dev.json
│   ├── friendsqa_trn.json
│   └── friendsqa_tst.json
└── python
    ├── README.md
    └── src
        ├── examples
        │   ├── run_friends_split_context.py
        │   ├── run_friends_whole_context.py
        │   ├── run_language_modeling.py
        │   ├── run_umlm.py
        │   └── run_uop.py
        ├── transformers
        │   ├── __init__.py
        │   ├── activations.py
        │   ├── benchmark_utils.py
        │   ├── commands
        │   │   ├── __init__.py
        │   │   ├── convert.py
        │   │   ├── download.py
        │   │   ├── env.py
        │   │   ├── run.py
        │   │   ├── serving.py
        │   │   ├── train.py
        │   │   └── user.py
        │   ├── configuration_albert.py
        │   ├── configuration_auto.py
        │   ├── configuration_bart.py
        │   ├── configuration_bert.py
        │   ├── configuration_camembert.py
        │   ├── configuration_ctrl.py
        │   ├── configuration_distilbert.py
        │   ├── configuration_flaubert.py
        │   ├── configuration_gpt2.py
        │   ├── configuration_mmbt.py
        │   ├── configuration_openai.py
        │   ├── configuration_roberta.py
        │   ├── configuration_t5.py
        │   ├── configuration_transfo_xl.py
        │   ├── configuration_utils.py
        │   ├── configuration_xlm.py
        │   ├── configuration_xlm_roberta.py
        │   ├── configuration_xlnet.py
        │   ├── convert_albert_original_tf_checkpoint_to_pytorch.py
        │   ├── convert_bart_original_pytorch_checkpoint_to_pytorch.py
        │   ├── convert_bert_original_tf_checkpoint_to_pytorch.py
        │   ├── convert_bert_pytorch_checkpoint_to_original_tf.py
        │   ├── convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py
        │   ├── convert_gpt2_original_tf_checkpoint_to_pytorch.py
        │   ├── convert_openai_original_tf_checkpoint_to_pytorch.py
        │   ├── convert_pytorch_checkpoint_to_tf2.py
        │   ├── convert_roberta_original_pytorch_checkpoint_to_pytorch.py
        │   ├── convert_t5_original_tf_checkpoint_to_pytorch.py
        │   ├── convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
        │   ├── convert_xlm_original_pytorch_checkpoint_to_pytorch.py
        │   ├── convert_xlnet_original_tf_checkpoint_to_pytorch.py
        │   ├── data
        │   │   ├── __init__.py
        │   │   ├── metrics
        │   │   │   ├── __init__.py
        │   │   │   └── squad_metrics.py
        │   │   └── processors
        │   │       ├── __init__.py
        │   │       ├── friendsqa.py
        │   │       ├── glue.py
        │   │       ├── squad.py
        │   │       ├── umlm.py
        │   │       ├── uop.py
        │   │       ├── utils.py
        │   │       └── xnli.py
        │   ├── file_utils.py
        │   ├── hf_api.py
        │   ├── modelcard.py
        │   ├── modeling_albert.py
        │   ├── modeling_auto.py
        │   ├── modeling_bart.py
        │   ├── modeling_bert.py
        │   ├── modeling_camembert.py
        │   ├── modeling_ctrl.py
        │   ├── modeling_distilbert.py
        │   ├── modeling_encoder_decoder.py
        │   ├── modeling_flaubert.py
        │   ├── modeling_gpt2.py
        │   ├── modeling_mmbt.py
        │   ├── modeling_openai.py
        │   ├── modeling_roberta.py
        │   ├── modeling_t5.py
        │   ├── modeling_tf_albert.py
        │   ├── modeling_tf_auto.py
        │   ├── modeling_tf_bert.py
        │   ├── modeling_tf_camembert.py
        │   ├── modeling_tf_ctrl.py
        │   ├── modeling_tf_distilbert.py
        │   ├── modeling_tf_flaubert.py
        │   ├── modeling_tf_gpt2.py
        │   ├── modeling_tf_openai.py
        │   ├── modeling_tf_pytorch_utils.py
        │   ├── modeling_tf_roberta.py
        │   ├── modeling_tf_t5.py
        │   ├── modeling_tf_transfo_xl.py
        │   ├── modeling_tf_transfo_xl_utilities.py
        │   ├── modeling_tf_utils.py
        │   ├── modeling_tf_xlm.py
        │   ├── modeling_tf_xlm_roberta.py
        │   ├── modeling_tf_xlnet.py
        │   ├── modeling_transfo_xl.py
        │   ├── modeling_transfo_xl_utilities.py
        │   ├── modeling_utils.py
        │   ├── modeling_xlm.py
        │   ├── modeling_xlm_roberta.py
        │   ├── modeling_xlnet.py
        │   ├── optimization.py
        │   ├── optimization_tf.py
        │   ├── pipelines.py
        │   ├── tokenization_albert.py
        │   ├── tokenization_auto.py
        │   ├── tokenization_bart.py
        │   ├── tokenization_bert.py
        │   ├── tokenization_bert_japanese.py
        │   ├── tokenization_camembert.py
        │   ├── tokenization_ctrl.py
        │   ├── tokenization_distilbert.py
        │   ├── tokenization_flaubert.py
        │   ├── tokenization_gpt2.py
        │   ├── tokenization_openai.py
        │   ├── tokenization_roberta.py
        │   ├── tokenization_t5.py
        │   ├── tokenization_transfo_xl.py
        │   ├── tokenization_utils.py
        │   ├── tokenization_xlm.py
        │   ├── tokenization_xlm_roberta.py
        │   ├── tokenization_xlnet.py
        │   └── utils_encoder_decoder.py
        └── utils
            ├── analysis.py
            ├── categorizing.py
            ├── evaluate_split_context.py
            ├── evaluate_whole_context.py
            └── test_model.py
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright 2019, 2020 Emory University
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Question Answering on Multiparty Dialogue
2 |
3 | Question Answering challenges the machine's ability to answer queries posed in different forms.
4 | This is a part of the [Character Mining](../../../character-mining) project led by the [Emory NLP](http://nlp.mathcs.emory.edu) research group.
5 | The following shows a multiparty dialogue between Joey and Chandler with six questions regarding the contents of the dialogue. Questions can take six forms (_what, who, when, where, why, how_).
6 |
7 | | Speaker | Utterance |
8 | |:-------:|-----------|
9 | | `U01` | [Scene: Central Perk, Joey is getting a phone number from a woman (Casey) as Chandler watches from the doorway.] |
10 | | `U02` | Casey: Here you go. |
11 | | `U03` | Joey: Great! All right, so I’ll call you later.|
12 | | `U04` | Casey: Great!|
13 | | `U05` | Chandler: Hey-Hey-Hey! Who was that? |
14 | | `U06` | Joey: That would be Casey. We’re going out tonight. |
15 | | `U07` | Chandler: Goin’ out, huh? Wow! Wow! So things didn’t work out with Kathy, huh? Bummer.|
16 | | `U08` | Joey: No, things are fine with Kathy. I’m having a late dinner with her tonight, right after my early dinner with Casey. |
17 | | `U09` | Chandler: What? |
18 | | `U10` | Joey: Yeah-yeah. And the craziest thing is that I just ate a whole pizza by myself! |
19 | | `U11` | Chandler: Wait! You’re going out with Kathy! |
20 | | `U12` | Joey: Yeah. Why are you getting so upset? |
21 | | `U13` | Chandler: Well, I’m upset for you. I mean, dating an endless line of beautiful women must be very unfulfilling for you. |
22 |
23 | * Q1: What is Joey going to do with Casey tonight?
24 | * Q2: Who is Joey getting a phone number from?
25 | * Q3: When will Joey have dinner with Kathy?
26 | * Q4: Where are Joey and Chandler?
27 | * Q5: Why is Chandler upset?
28 | * Q6: How are things between Joey and Kathy?
29 |
30 | Your task is to answer these open-domain questions using contiguous spans from the dialogues.
31 | This task is challenging because questions could be in any form and might not contain the exact words from the document.
32 |
33 | ## Dataset
34 |
35 | To generate the FriendsQA dataset, 1,222 scenes from the first four seasons of the Character Mining dataset are selected. Scenes with fewer than five utterances (83 of them) are discarded, and each remaining scene is treated as an independent dialogue. FriendsQA can be viewed as answer span selection: questions are asked about the contents of a dialogue, and the model is expected to find the spans in the dialogue that contain the answers. The dialogue nature of this dataset, however, makes it more challenging than other datasets comprising passages in formal language. Details can be found in the paper.
36 |
37 | * Latest release: [v2.0](https://github.com/emorynlp/reading-comprehension/archive/reading-comprehension-2.0.tar.gz)
38 |
39 | ## Statistics
40 |
41 | The data split is based on the chronological order of the episodes and is consistent with the other Character Mining projects:
42 |
43 | | Dataset | Dialogues | Questions | Answers | Episodes |
44 | | :-----: | --------: | --------: | ------: | -------: |
45 | | TRN | 973 | 9,791 | 16,352 | 1 - 20 |
46 | | DEV | 113 | 1,189 | 2,065 | 21 - 22 |
47 | | TST | 136 | 1,172 | 1,920 | 23 - * |
48 |
49 |
50 |
51 |
52 | ## Annotation
53 | The data format splits each context into utterances and stores, for every utterance, the speakers and the utterance text separately.
54 |
55 | ```json
56 | "utterances:": [
57 | {
58 | "uid": 0,
59 | "speakers": [
60 | "Ross Geller"
61 | ],
62 | "utterance": "Breathe ."
63 | },
64 | {
65 | "uid": 1,
66 | "speakers": [
67 | "Susan Bunch"
68 | ],
69 | "utterance": "Breathe ."
70 | },
71 | {
72 | "uid": 2,
73 | "speakers": [
74 | "Carol Willick"
75 | ],
76 | "utterance": "You 're gon na kill me !"
77 | }
78 | ]
79 | ```
80 |
81 | The "qas" field includes questions and the answers. An answer has five keys. The first is "answer_text" which denotes the original text. The second is "utterance_id" which denotes the answer appearing in which utterance. The "inner_start" and "inner_end" denote the answer start and end token position in the corresponding utterance and if answer is the speaker, their values are -1 and the "is_speaker" is set as true .
82 |
83 | ```json
84 | "qas": [
85 | {
86 | "id": "s01_e23_c06_What",
87 | "question": "What does Ross want to name his son ?",
88 | "answers": [
89 | {
90 | "answer_text": "Jamie",
91 | "utterance_id": 12,
92 | "inner_start": 24,
93 | "inner_end": 24,
94 | "is_speaker": false
95 | },
96 | {
97 | "answer_text": "Jordie .",
98 | "utterance_id": 9,
99 | "inner_start": 21,
100 | "inner_end": 22,
101 | "is_speaker": false
102 | }
103 | ]
104 | },
105 | {
106 | "id": "s01_e23_c06_Who_Paraphrased",
107 | "question": "By whom was Ross told to count faster ?",
108 | "answers": [
109 | {
110 | "answer_text": "Carol Willick",
111 | "utterance_id": 6,
112 | "inner_start": -1,
113 | "inner_end": -1,
114 | "is_speaker": true
115 | }
116 | ]
117 | }
118 | ]
119 | ```
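
To make these fields concrete, here is a minimal sketch (not part of the repository) of how an answer string can be reconstructed from a dialogue shaped like the fragments above; the function name, the lookup by "uid", and the whitespace tokenization are assumptions made for illustration.

```python
def recover_answer_text(dialogue, answer):
    """Reconstruct the answer string implied by one FriendsQA answer annotation."""
    # Find the utterance whose "uid" matches the answer's "utterance_id".
    utterance = next(u for u in dialogue["utterances:"] if u["uid"] == answer["utterance_id"])
    if answer["is_speaker"]:
        # The answer is the speaker of the utterance, not a span inside its text.
        return " ".join(utterance["speakers"])
    # Utterance texts are pre-tokenized with spaces (e.g. "You 're gon na kill me !"),
    # and inner_start/inner_end are inclusive token positions within the utterance.
    tokens = utterance["utterance"].split()
    return " ".join(tokens[answer["inner_start"]: answer["inner_end"] + 1])
```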
120 |
121 | ## Citation
122 |
123 | * [Transformers to Learn Hierarchical Contexts in Multiparty Dialogue for Span-based Question Answering](). Changmao Li and Jinho D. Choi. In Proceedings of the Conference of the Association for Computational Linguistics, ACL'20, 2020.
124 | * [FriendsQA: Open-Domain Question Answering on TV Show Transcripts](https://www.aclweb.org/anthology/W19-5923). Zhengzhe Yang and Jinho D. Choi. In Proceedings of the Annual Conference of the ACL Special Interest Group on Discourse and Dialogue, SIGDIAL'19, 2019.
125 |
126 | ## Contact
127 |
128 | * [Jinho D. Choi](http://www.mathcs.emory.edu/~choi).
--------------------------------------------------------------------------------
/python/README.md:
--------------------------------------------------------------------------------
1 | # Hierarchical Transformer for Span-based QA
2 |
3 | ## Citation
4 |
5 | * [Transformers to Learn Hierarchical Contexts in Multiparty Dialogue for Span-based Question Answering](). Changmao Li and Jinho D. Choi. In Proceedings of the Conference of the Association for Computational Linguistics, ACL'20, 2020.
6 | * Note that some of the source code is based on [`huggingface/transformers`](https://github.com/huggingface/transformers).
7 |
8 |
9 | ## Source files include:
10 |
11 | ### Data processing files:
12 |
13 | ```
14 | /src/transformers/data/processors/umlm.py
15 | ```
16 |
17 | * Read generated utterance-level masked language modeling data into examples and create features from the examples
18 |
19 | ```
20 | /src/transformers/data/processors/uop.py
21 | ```
22 | * Read generated utterance order prediction (UOP) data into examples and create features from the examples
23 |
24 | ```
25 | /src/transformers/data/processors/friendsqa.py
26 | ```
27 |
28 | * Read FriendsQA data into examples and create features from the examples
29 |
30 |
31 | ### Model files:
32 |
33 | ```
34 | /src/transformers/modeling_bert.py:
35 | ```
36 |
37 | * Token-level masked language modeling BERT model
38 | * Utterance-level masked language modeling BERT model
39 | * Utterance order prediction BERT model
40 | * QA whole context fine-tuning BERT model
41 | * QA split context fine-tuning BERT model
42 |
43 | ```
44 | /src/transformers/modeling_roberta.py:
45 | ```
46 |
47 | * Token-level masked language modeling RoBERTa model
48 | * Utterance-level masked language modeling RoBERTa model
49 | * Utterance order prediction RoBERTa model
50 | * QA whole context fine-tuning RoBERTa model
51 | * QA split context fine-tuning RoBERTa model
52 |
53 |
54 | ### Executable files:
55 |
56 | ```
57 | /src/examples/run_language_modeling.py
58 | ```
59 |
60 | * Run BERT or RoBERTa token-level masked language modeling (TMLM)
61 |
62 | ```
63 | /src/examples/run_umlm.py
64 | ```
65 |
66 | * Run BERT or RoBERTa utterance-level masked language modeling (UMLM)
67 |
68 | ```
69 | /src/examples/run_uop.py
70 | ```
71 |
72 | * Run BERT or RoBERTa utterance order prediction (UOP)
73 |
74 | ```
75 | /src/examples/run_friends_whole_context.py
76 | ```
77 |
78 | * Run BERT or RoBERTa fine-tuning on FriendsQA in the whole-context format
79 |
80 | ```
81 | /src/examples/run_friends_split_context.py
82 | ```
83 |
84 | * Run BERT or RoBERTa fine-tuning on FriendsQA in the split-context format, where the context is split into utterances
85 |
86 |
87 | ### Other utility files:
88 |
89 | ```
90 | /src/utils/test_models.py
91 | ```
92 |
93 | * Test whether all neural models work correctly
94 |
95 | ```
96 | /src/utils/categorizing.py
97 | ```
98 |
99 | * Categorize results by question types
100 |
101 | ```
102 | /src/utils/analysis.py
103 | ```
104 |
105 | * Analyze categorized results
106 |
107 | ```
108 | /src/utils/evaluate_whole_context.py
109 | ```
110 |
111 | * Evaluate the whole context QA fine-tuning results
112 |
113 | ```
114 | /src/utils/evaluate_split_context.py
115 | ```
116 |
117 | * Evaluate the split context QA fine-tuning results
118 |
119 |
120 | ## Other Notices
121 |
122 | You need to generate your own language model pre-training data to fit your own corpus.
123 | The expected formats for the language model pre-training data are described below.
124 |
125 | ### Pre-training data format
126 |
127 | #### TMLM
128 |
129 | * The TMLM data format is a text file.
130 | * Each line is one utterance, and dialogues are separated by an empty line (see the toy example below).
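
A toy illustration of this layout (utterances borrowed from the dialogue in the top-level README; your own corpus will differ):

```
Breathe .
Breathe .
You 're gon na kill me !

Hey-Hey-Hey ! Who was that ?
That would be Casey . We 're going out tonight .
```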
131 |
132 | #### UMLM
133 |
134 | * The UMLM data format is a csv file.
135 | * Each line contains a tokenized utterance with one token masked, followed by the masked token; the two fields are separated by a tab (\t), as sketched below.
136 | * You can adapt /src/transformers/data/processors/umlm.py to read your own format if you want.
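
A minimal parsing sketch under these assumptions (the exact mask symbol and the reader in umlm.py may differ):

```python
def parse_umlm_line(line):
    """Split one UMLM line into the masked utterance and the masked-out token."""
    masked_utterance, masked_token = line.rstrip("\n").split("\t")
    return masked_utterance, masked_token

# e.g. parse_umlm_line("You 're gon na [MASK] me !\tkill")
# -> ("You 're gon na [MASK] me !", "kill")
```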
137 |
138 | #### UOP
139 |
140 | * The UOP data format is a json file.
141 | * The json format is
142 |
143 | ```python
144 | [
145 |     {
146 |         "utterances": ["u1", "u2", .....],   # the list of utterances
147 |         "is_correct_order": "Yes" or "No"    # whether the utterances are in the correct order
148 |     },
149 |     ....
150 | ]
151 | ```
152 |
153 | * You can adapt /src/transformers/data/processors/uop.py to read your own format if you want; a minimal loading sketch follows.
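
A minimal loading sketch for the format above (the mapping of "Yes"/"No" to 1/0 is an assumption made here for illustration; uop.py may build its examples differently):

```python
import json

def load_uop_examples(path):
    """Load UOP data and pair each utterance list with a binary order label."""
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return [
        (entry["utterances"], 1 if entry["is_correct_order"] == "Yes" else 0)
        for entry in data
    ]
```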
154 |
155 |
--------------------------------------------------------------------------------
/python/src/transformers/activations.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def swish(x):
8 | return x * torch.sigmoid(x)
9 |
10 |
11 | def _gelu_python(x):
12 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
13 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
14 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
15 | This is now written in C in torch.nn.functional
16 | Also see https://arxiv.org/abs/1606.08415
17 | """
18 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
19 |
20 |
21 | if torch.__version__ < "1.4.0":
22 | gelu = _gelu_python
23 | else:
24 | gelu = F.gelu
25 |
26 |
27 | def gelu_new(x):
28 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
29 | Also see https://arxiv.org/abs/1606.08415
30 | """
31 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
32 |
33 |
34 | ACT2FN = {
35 | "relu": F.relu,
36 | "swish": swish,
37 | "gelu": gelu,
38 | "tanh": F.tanh,
39 | "gelu_new": gelu_new,
40 | }
41 |
42 |
43 | def get_activation(activation_string):
44 | if activation_string in ACT2FN:
45 | return ACT2FN[activation_string]
46 | else:
47 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
48 |
--------------------------------------------------------------------------------
/python/src/transformers/commands/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from argparse import ArgumentParser
3 |
4 |
5 | class BaseTransformersCLICommand(ABC):
6 | @staticmethod
7 | @abstractmethod
8 | def register_subcommand(parser: ArgumentParser):
9 | raise NotImplementedError()
10 |
11 | @abstractmethod
12 | def run(self):
13 | raise NotImplementedError()
14 |
--------------------------------------------------------------------------------
/python/src/transformers/commands/convert.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser, Namespace
2 | from logging import getLogger
3 |
4 | from transformers.commands import BaseTransformersCLICommand
5 |
6 |
7 | def convert_command_factory(args: Namespace):
8 | """
9 | Factory function used to convert a TF 1.0 model checkpoint into a PyTorch checkpoint.
10 | :return: ConvertCommand
11 | """
12 | return ConvertCommand(
13 | args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name
14 | )
15 |
16 |
17 | class ConvertCommand(BaseTransformersCLICommand):
18 | @staticmethod
19 | def register_subcommand(parser: ArgumentParser):
20 | """
21 | Register this command to argparse so it's available for the transformer-cli
22 | :param parser: Root parser to register command-specific arguments
23 | :return:
24 | """
25 | train_parser = parser.add_parser(
26 | "convert",
27 | help="CLI tool to run convert model from original "
28 | "author checkpoints to Transformers PyTorch checkpoints.",
29 | )
30 | train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
31 | train_parser.add_argument(
32 | "--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder."
33 | )
34 | train_parser.add_argument(
35 | "--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch savd model output."
36 | )
37 | train_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.")
38 | train_parser.add_argument(
39 | "--finetuning_task_name",
40 | type=str,
41 | default=None,
42 | help="Optional fine-tuning task name if the TF model was a finetuned model.",
43 | )
44 | train_parser.set_defaults(func=convert_command_factory)
45 |
46 | def __init__(
47 | self,
48 | model_type: str,
49 | tf_checkpoint: str,
50 | pytorch_dump_output: str,
51 | config: str,
52 | finetuning_task_name: str,
53 | *args
54 | ):
55 | self._logger = getLogger("transformers-cli/converting")
56 |
57 | self._logger.info("Loading model {}".format(model_type))
58 | self._model_type = model_type
59 | self._tf_checkpoint = tf_checkpoint
60 | self._pytorch_dump_output = pytorch_dump_output
61 | self._config = config
62 | self._finetuning_task_name = finetuning_task_name
63 |
64 | def run(self):
65 | if self._model_type == "bert":
66 | try:
67 | from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
68 | convert_tf_checkpoint_to_pytorch,
69 | )
70 | except ImportError:
71 | msg = (
72 | "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
73 | "In that case, it requires TensorFlow to be installed. Please see "
74 | "https://www.tensorflow.org/install/ for installation instructions."
75 | )
76 | raise ImportError(msg)
77 |
78 | convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
79 | elif self._model_type == "gpt":
80 | from transformers.convert_openai_original_tf_checkpoint_to_pytorch import (
81 | convert_openai_checkpoint_to_pytorch,
82 | )
83 |
84 | convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
85 | elif self._model_type == "transfo_xl":
86 | try:
87 | from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
88 | convert_transfo_xl_checkpoint_to_pytorch,
89 | )
90 | except ImportError:
91 | msg = (
92 | "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
93 | "In that case, it requires TensorFlow to be installed. Please see "
94 | "https://www.tensorflow.org/install/ for installation instructions."
95 | )
96 | raise ImportError(msg)
97 |
98 | if "ckpt" in self._tf_checkpoint.lower():
99 | TF_CHECKPOINT = self._tf_checkpoint
100 | TF_DATASET_FILE = ""
101 | else:
102 | TF_DATASET_FILE = self._tf_checkpoint
103 | TF_CHECKPOINT = ""
104 | convert_transfo_xl_checkpoint_to_pytorch(
105 | TF_CHECKPOINT, self._config, self._pytorch_dump_output, TF_DATASET_FILE
106 | )
107 | elif self._model_type == "gpt2":
108 | try:
109 | from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import (
110 | convert_gpt2_checkpoint_to_pytorch,
111 | )
112 | except ImportError:
113 | msg = (
114 | "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
115 | "In that case, it requires TensorFlow to be installed. Please see "
116 | "https://www.tensorflow.org/install/ for installation instructions."
117 | )
118 | raise ImportError(msg)
119 |
120 | convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
121 | elif self._model_type == "xlnet":
122 | try:
123 | from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import (
124 | convert_xlnet_checkpoint_to_pytorch,
125 | )
126 | except ImportError:
127 | msg = (
128 | "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
129 | "In that case, it requires TensorFlow to be installed. Please see "
130 | "https://www.tensorflow.org/install/ for installation instructions."
131 | )
132 | raise ImportError(msg)
133 |
134 | convert_xlnet_checkpoint_to_pytorch(
135 | self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
136 | )
137 | elif self._model_type == "xlm":
138 | from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
139 | convert_xlm_checkpoint_to_pytorch,
140 | )
141 |
142 | convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
143 | else:
144 | raise ValueError("--model_type should be selected in the list [bert, gpt, gpt2, transfo_xl, xlnet, xlm]")
145 |
--------------------------------------------------------------------------------
/python/src/transformers/commands/download.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from transformers.commands import BaseTransformersCLICommand
4 |
5 |
6 | def download_command_factory(args):
7 | return DownloadCommand(args.model, args.cache_dir, args.force)
8 |
9 |
10 | class DownloadCommand(BaseTransformersCLICommand):
11 | @staticmethod
12 | def register_subcommand(parser: ArgumentParser):
13 | download_parser = parser.add_parser("download")
14 | download_parser.add_argument(
15 | "--cache-dir", type=str, default=None, help="Path to location to store the models"
16 | )
17 | download_parser.add_argument(
18 | "--force", action="store_true", help="Force the model to be download even if already in cache-dir"
19 | )
20 | download_parser.add_argument("model", type=str, help="Name of the model to download")
21 | download_parser.set_defaults(func=download_command_factory)
22 |
23 | def __init__(self, model: str, cache: str, force: bool):
24 | self._model = model
25 | self._cache = cache
26 | self._force = force
27 |
28 | def run(self):
29 | from transformers import AutoModel, AutoTokenizer
30 |
31 | AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
32 | AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
33 |
--------------------------------------------------------------------------------
/python/src/transformers/commands/env.py:
--------------------------------------------------------------------------------
1 | import platform
2 | from argparse import ArgumentParser
3 |
4 | from transformers import __version__ as version
5 | from transformers import is_tf_available, is_torch_available
6 | from transformers.commands import BaseTransformersCLICommand
7 |
8 |
9 | def info_command_factory(_):
10 | return EnvironmentCommand()
11 |
12 |
13 | class EnvironmentCommand(BaseTransformersCLICommand):
14 | @staticmethod
15 | def register_subcommand(parser: ArgumentParser):
16 | download_parser = parser.add_parser("env")
17 | download_parser.set_defaults(func=info_command_factory)
18 |
19 | def run(self):
20 | pt_version = "not installed"
21 | pt_cuda_available = "NA"
22 | if is_torch_available():
23 | import torch
24 |
25 | pt_version = torch.__version__
26 | pt_cuda_available = torch.cuda.is_available()
27 |
28 | tf_version = "not installed"
29 | tf_cuda_available = "NA"
30 | if is_tf_available():
31 | import tensorflow as tf
32 |
33 | tf_version = tf.__version__
34 | try:
35 | # deprecated in v2.1
36 | tf_cuda_available = tf.test.is_gpu_available()
37 | except AttributeError:
38 | # returns list of devices, convert to bool
39 | tf_cuda_available = bool(tf.config.list_physical_devices("GPU"))
40 |
41 | info = {
42 | "`transformers` version": version,
43 | "Platform": platform.platform(),
44 | "Python version": platform.python_version(),
45 | "PyTorch version (GPU?)": "{} ({})".format(pt_version, pt_cuda_available),
46 | "Tensorflow version (GPU?)": "{} ({})".format(tf_version, tf_cuda_available),
47 | "Using GPU in script?": "",
48 | "Using distributed or parallel set-up in script?": "",
49 | }
50 |
51 | print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
52 | print(self.format_dict(info))
53 |
54 | return info
55 |
56 | @staticmethod
57 | def format_dict(d):
58 | return "\n".join(["- {}: {}".format(prop, val) for prop, val in d.items()]) + "\n"
59 |
--------------------------------------------------------------------------------
/python/src/transformers/commands/run.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from argparse import ArgumentParser
3 |
4 | from transformers.commands import BaseTransformersCLICommand
5 | from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
6 |
7 |
8 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
9 |
10 |
11 | def try_infer_format_from_ext(path: str):
12 | if not path:
13 | return "pipe"
14 |
15 | for ext in PipelineDataFormat.SUPPORTED_FORMATS:
16 | if path.endswith(ext):
17 | return ext
18 |
19 | raise Exception(
20 | "Unable to determine file format from file extension {}. "
21 | "Please provide the format through --format {}".format(path, PipelineDataFormat.SUPPORTED_FORMATS)
22 | )
23 |
24 |
25 | def run_command_factory(args):
26 | nlp = pipeline(
27 | task=args.task,
28 | model=args.model if args.model else None,
29 | config=args.config,
30 | tokenizer=args.tokenizer,
31 | device=args.device,
32 | )
33 | format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format
34 | reader = PipelineDataFormat.from_str(
35 | format=format,
36 | output_path=args.output,
37 | input_path=args.input,
38 | column=args.column if args.column else nlp.default_input_names,
39 | overwrite=args.overwrite,
40 | )
41 | return RunCommand(nlp, reader)
42 |
43 |
44 | class RunCommand(BaseTransformersCLICommand):
45 | def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
46 | self._nlp = nlp
47 | self._reader = reader
48 |
49 | @staticmethod
50 | def register_subcommand(parser: ArgumentParser):
51 | run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
52 | run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run")
53 | run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
54 | run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
55 | run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
56 | run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
57 | run_parser.add_argument(
58 | "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
59 | )
60 | run_parser.add_argument(
61 | "--column",
62 | type=str,
63 | help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
64 | )
65 | run_parser.add_argument(
66 | "--format",
67 | type=str,
68 | default="infer",
69 | choices=PipelineDataFormat.SUPPORTED_FORMATS,
70 | help="Input format to read from",
71 | )
72 | run_parser.add_argument(
73 | "--device",
74 | type=int,
75 | default=-1,
76 | help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
77 | )
78 | run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
79 | run_parser.set_defaults(func=run_command_factory)
80 |
81 | def run(self):
82 | nlp, outputs = self._nlp, []
83 |
84 | for entry in self._reader:
85 | output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
86 | if isinstance(output, dict):
87 | outputs.append(output)
88 | else:
89 | outputs += output
90 |
91 | # Saving data
92 | if self._nlp.binary_output:
93 | binary_path = self._reader.save_binary(outputs)
94 | logger.warning("Current pipeline requires output to be in binary format, saving at {}".format(binary_path))
95 | else:
96 | self._reader.save(outputs)
97 |
--------------------------------------------------------------------------------
/python/src/transformers/commands/serving.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from argparse import ArgumentParser, Namespace
3 | from typing import Any, List, Optional
4 |
5 | from transformers import Pipeline
6 | from transformers.commands import BaseTransformersCLICommand
7 | from transformers.pipelines import SUPPORTED_TASKS, pipeline
8 |
9 |
10 | try:
11 | from uvicorn import run
12 | from fastapi import FastAPI, HTTPException, Body
13 | from fastapi.routing import APIRoute
14 | from pydantic import BaseModel
15 | from starlette.responses import JSONResponse
16 |
17 | _serve_dependencies_installed = True
18 | except (ImportError, AttributeError):
19 | BaseModel = object
20 |
21 | def Body(*x, **y):
22 | pass
23 |
24 | _serve_dependencies_installed = False
25 |
26 |
27 | logger = logging.getLogger("transformers-cli/serving")
28 |
29 |
30 | def serve_command_factory(args: Namespace):
31 | """
32 | Factory function used to instantiate serving server from provided command line arguments.
33 | :return: ServeCommand
34 | """
35 | nlp = pipeline(
36 | task=args.task,
37 | model=args.model if args.model else None,
38 | config=args.config,
39 | tokenizer=args.tokenizer,
40 | device=args.device,
41 | )
42 | return ServeCommand(nlp, args.host, args.port, args.workers)
43 |
44 |
45 | class ServeModelInfoResult(BaseModel):
46 | """
47 | Expose model information
48 | """
49 |
50 | infos: dict
51 |
52 |
53 | class ServeTokenizeResult(BaseModel):
54 | """
55 | Tokenize result model
56 | """
57 |
58 | tokens: List[str]
59 | tokens_ids: Optional[List[int]]
60 |
61 |
62 | class ServeDeTokenizeResult(BaseModel):
63 | """
64 | DeTokenize result model
65 | """
66 |
67 | text: str
68 |
69 |
70 | class ServeForwardResult(BaseModel):
71 | """
72 | Forward result model
73 | """
74 |
75 | output: Any
76 |
77 |
78 | class ServeCommand(BaseTransformersCLICommand):
79 | @staticmethod
80 | def register_subcommand(parser: ArgumentParser):
81 | """
82 | Register this command to argparse so it's available for the transformer-cli
83 | :param parser: Root parser to register command-specific arguments
84 | :return:
85 | """
86 | serve_parser = parser.add_parser(
87 | "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
88 | )
89 | serve_parser.add_argument(
90 | "--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
91 | )
92 | serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
93 | serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
94 | serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
95 | serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
96 | serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
97 | serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
98 | serve_parser.add_argument(
99 | "--device",
100 | type=int,
101 | default=-1,
102 | help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
103 | )
104 | serve_parser.set_defaults(func=serve_command_factory)
105 |
106 | def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
107 |
108 | self._pipeline = pipeline
109 |
110 | self.host = host
111 | self.port = port
112 | self.workers = workers
113 |
114 | if not _serve_dependencies_installed:
115 | raise RuntimeError(
116 | "Using serve command requires FastAPI and unicorn. "
117 | 'Please install transformers with [serving]: pip install "transformers[serving]".'
118 | "Or install FastAPI and unicorn separately."
119 | )
120 | else:
121 | logger.info("Serving model over {}:{}".format(host, port))
122 | self._app = FastAPI(
123 | routes=[
124 | APIRoute(
125 | "/",
126 | self.model_info,
127 | response_model=ServeModelInfoResult,
128 | response_class=JSONResponse,
129 | methods=["GET"],
130 | ),
131 | APIRoute(
132 | "/tokenize",
133 | self.tokenize,
134 | response_model=ServeTokenizeResult,
135 | response_class=JSONResponse,
136 | methods=["POST"],
137 | ),
138 | APIRoute(
139 | "/detokenize",
140 | self.detokenize,
141 | response_model=ServeDeTokenizeResult,
142 | response_class=JSONResponse,
143 | methods=["POST"],
144 | ),
145 | APIRoute(
146 | "/forward",
147 | self.forward,
148 | response_model=ServeForwardResult,
149 | response_class=JSONResponse,
150 | methods=["POST"],
151 | ),
152 | ],
153 | timeout=600,
154 | )
155 |
156 | def run(self):
157 | run(self._app, host=self.host, port=self.port, workers=self.workers)
158 |
159 | def model_info(self):
160 | return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
161 |
162 | def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
163 | """
164 | Tokenize the provided input and eventually returns corresponding tokens id:
165 | - **text_input**: String to tokenize
166 | - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer mapping.
167 | """
168 | try:
169 | tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
170 |
171 | if return_ids:
172 | tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
173 | return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
174 | else:
175 | return ServeTokenizeResult(tokens=tokens_txt)
176 |
177 | except Exception as e:
178 | raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
179 |
180 | def detokenize(
181 | self,
182 | tokens_ids: List[int] = Body(None, embed=True),
183 | skip_special_tokens: bool = Body(False, embed=True),
184 | cleanup_tokenization_spaces: bool = Body(True, embed=True),
185 | ):
186 | """
187 | Detokenize the provided tokens ids to readable text:
188 | - **tokens_ids**: List of tokens ids
189 | - **skip_special_tokens**: Flag indicating to not try to decode special tokens
190 | - **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones.
191 | """
192 | try:
193 | decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
194 | return ServeDeTokenizeResult(model="", text=decoded_str)
195 | except Exception as e:
196 | raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
197 |
198 | async def forward(self, inputs=Body(None, embed=True)):
199 | """
200 | **inputs**:
201 | **attention_mask**:
202 | **tokens_type_ids**:
203 | """
204 |
205 | # Check we don't have empty string
206 | if len(inputs) == 0:
207 | return ServeForwardResult(output=[], attention=[])
208 |
209 | try:
210 | # Forward through the model
211 | output = self._pipeline(inputs)
212 | return ServeForwardResult(output=output)
213 | except Exception as e:
214 | raise HTTPException(500, {"error": str(e)})
215 |
--------------------------------------------------------------------------------
/python/src/transformers/commands/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser, Namespace
3 | from logging import getLogger
4 |
5 | from transformers import SingleSentenceClassificationProcessor as Processor
6 | from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
7 | from transformers.commands import BaseTransformersCLICommand
8 |
9 |
10 | if not is_tf_available() and not is_torch_available():
11 | raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
12 |
13 | # TF training parameters
14 | USE_XLA = False
15 | USE_AMP = False
16 |
17 |
18 | def train_command_factory(args: Namespace):
19 | """
20 | Factory function used to instantiate a training command from the provided command line arguments.
21 | :return: TrainCommand
22 | """
23 | return TrainCommand(args)
24 |
25 |
26 | class TrainCommand(BaseTransformersCLICommand):
27 | @staticmethod
28 | def register_subcommand(parser: ArgumentParser):
29 | """
30 | Register this command to argparse so it's available for the transformer-cli
31 | :param parser: Root parser to register command-specific arguments
32 | :return:
33 | """
34 | train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")
35 |
36 | train_parser.add_argument(
37 | "--train_data",
38 | type=str,
39 | required=True,
40 | help="path to train (and optionally evaluation) dataset as a csv with "
41 | "tab separated labels and sentences.",
42 | )
43 | train_parser.add_argument(
44 | "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
45 | )
46 | train_parser.add_argument(
47 | "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
48 | )
49 | train_parser.add_argument(
50 | "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
51 | )
52 | train_parser.add_argument(
53 | "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
54 | )
55 |
56 | train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
57 | train_parser.add_argument(
58 | "--validation_split",
59 | type=float,
60 | default=0.1,
61 | help="if validation dataset is not provided, fraction of train dataset " "to use as validation dataset.",
62 | )
63 |
64 | train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")
65 |
66 | train_parser.add_argument(
67 | "--task", type=str, default="text_classification", help="Task to train the model on."
68 | )
69 | train_parser.add_argument(
70 | "--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model."
71 | )
72 | train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
73 | train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
74 | train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
75 | train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
76 | train_parser.set_defaults(func=train_command_factory)
77 |
78 | def __init__(self, args: Namespace):
79 | self.logger = getLogger("transformers-cli/training")
80 |
81 | self.framework = "tf" if is_tf_available() else "torch"
82 |
83 | os.makedirs(args.output, exist_ok=True)
84 | assert os.path.isdir(args.output)
85 | self.output = args.output
86 |
87 | self.column_label = args.column_label
88 | self.column_text = args.column_text
89 | self.column_id = args.column_id
90 |
91 | self.logger.info("Loading {} pipeline for {}".format(args.task, args.model))
92 | if args.task == "text_classification":
93 | self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
94 | elif args.task == "token_classification":
95 | raise NotImplementedError
96 | elif args.task == "question_answering":
97 | raise NotImplementedError
98 |
99 | self.logger.info("Loading dataset from {}".format(args.train_data))
100 | self.train_dataset = Processor.create_from_csv(
101 | args.train_data,
102 | column_label=args.column_label,
103 | column_text=args.column_text,
104 | column_id=args.column_id,
105 | skip_first_row=args.skip_first_row,
106 | )
107 | self.valid_dataset = None
108 | if args.validation_data:
109 | self.logger.info("Loading validation dataset from {}".format(args.validation_data))
110 | self.valid_dataset = Processor.create_from_csv(
111 | args.validation_data,
112 | column_label=args.column_label,
113 | column_text=args.column_text,
114 | column_id=args.column_id,
115 | skip_first_row=args.skip_first_row,
116 | )
117 |
118 | self.validation_split = args.validation_split
119 | self.train_batch_size = args.train_batch_size
120 | self.valid_batch_size = args.valid_batch_size
121 | self.learning_rate = args.learning_rate
122 | self.adam_epsilon = args.adam_epsilon
123 |
124 | def run(self):
125 | if self.framework == "tf":
126 | return self.run_tf()
127 | return self.run_torch()
128 |
129 | def run_torch(self):
130 | raise NotImplementedError
131 |
132 | def run_tf(self):
133 | self.pipeline.fit(
134 | self.train_dataset,
135 | validation_data=self.valid_dataset,
136 | validation_split=self.validation_split,
137 | learning_rate=self.learning_rate,
138 | adam_epsilon=self.adam_epsilon,
139 | train_batch_size=self.train_batch_size,
140 | valid_batch_size=self.valid_batch_size,
141 | )
142 |
143 | # Save trained pipeline
144 | self.pipeline.save_pretrained(self.output)
145 |
--------------------------------------------------------------------------------
/python/src/transformers/configuration_albert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ ALBERT model configuration """
17 |
18 | from .configuration_utils import PretrainedConfig
19 |
20 |
21 | ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
22 | "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
23 | "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
24 | "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
25 | "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
26 | "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
27 | "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
28 | "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
29 | "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
30 | }
31 |
32 |
33 | class AlbertConfig(PretrainedConfig):
34 | r"""
35 | This is the configuration class to store the configuration of an :class:`~transformers.AlbertModel`.
36 | It is used to instantiate an ALBERT model according to the specified arguments, defining the model
37 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
38 | the ALBERT `xxlarge `__ architecture.
39 |
40 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
41 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
42 | for more information.
43 |
44 |
45 | Args:
46 | vocab_size (:obj:`int`, optional, defaults to 30000):
47 | Vocabulary size of the ALBERT model. Defines the different tokens that
48 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.AlbertModel`.
49 | embedding_size (:obj:`int`, optional, defaults to 128):
50 | Dimensionality of vocabulary embeddings.
51 | hidden_size (:obj:`int`, optional, defaults to 4096):
52 | Dimensionality of the encoder layers and the pooler layer.
53 | num_hidden_layers (:obj:`int`, optional, defaults to 12):
54 | Number of hidden layers in the Transformer encoder.
55 | num_hidden_groups (:obj:`int`, optional, defaults to 1):
56 | Number of groups for the hidden layers, parameters in the same group are shared.
57 | num_attention_heads (:obj:`int`, optional, defaults to 64):
58 | Number of attention heads for each attention layer in the Transformer encoder.
59 | intermediate_size (:obj:`int`, optional, defaults to 16384):
60 | The dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
61 | inner_group_num (:obj:`int`, optional, defaults to 1):
62 | The number of inner repetition of attention and ffn.
63 | hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu_new"):
64 | The non-linear activation function (function or string) in the encoder and pooler.
65 | If string, "gelu", "relu", "swish" and "gelu_new" are supported.
66 | hidden_dropout_prob (:obj:`float`, optional, defaults to 0):
67 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
68 | attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0):
69 | The dropout ratio for the attention probabilities.
70 | max_position_embeddings (:obj:`int`, optional, defaults to 512):
71 | The maximum sequence length that this model might ever be used with. Typically set this to something
72 | large (e.g., 512 or 1024 or 2048).
73 | type_vocab_size (:obj:`int`, optional, defaults to 2):
74 | The vocabulary size of the `token_type_ids` passed into :class:`~transformers.AlbertModel`.
75 | initializer_range (:obj:`float`, optional, defaults to 0.02):
76 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
77 | layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
78 | The epsilon used by the layer normalization layers.
79 | classifier_dropout_prob (:obj:`float`, optional, defaults to 0.1):
80 | The dropout ratio for attached classifiers.
81 |
82 | Example::
83 |
84 | from transformers import AlbertConfig, AlbertModel
85 | # Initializing an ALBERT-xxlarge style configuration
86 | albert_xxlarge_configuration = AlbertConfig()
87 |
88 | # Initializing an ALBERT-base style configuration
89 | albert_base_configuration = AlbertConfig(
90 | hidden_size=768,
91 | num_attention_heads=12,
92 | intermediate_size=3072,
93 | )
94 |
95 | # Initializing a model from the ALBERT-xxlarge style configuration
96 | model = AlbertModel(albert_xxlarge_configuration)
97 |
98 | # Accessing the model configuration
99 | configuration = model.config
100 |
101 | Attributes:
102 | pretrained_config_archive_map (Dict[str, str]):
103 | A dictionary containing all the available pre-trained checkpoints.
104 | """
105 |
106 | pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
107 | model_type = "albert"
108 |
109 | def __init__(
110 | self,
111 | vocab_size=30000,
112 | embedding_size=128,
113 | hidden_size=4096,
114 | num_hidden_layers=12,
115 | num_hidden_groups=1,
116 | num_attention_heads=64,
117 | intermediate_size=16384,
118 | inner_group_num=1,
119 | hidden_act="gelu_new",
120 | hidden_dropout_prob=0,
121 | attention_probs_dropout_prob=0,
122 | max_position_embeddings=512,
123 | type_vocab_size=2,
124 | initializer_range=0.02,
125 | layer_norm_eps=1e-12,
126 | classifier_dropout_prob=0.1,
127 | **kwargs
128 | ):
129 | super().__init__(**kwargs)
130 |
131 | self.vocab_size = vocab_size
132 | self.embedding_size = embedding_size
133 | self.hidden_size = hidden_size
134 | self.num_hidden_layers = num_hidden_layers
135 | self.num_hidden_groups = num_hidden_groups
136 | self.num_attention_heads = num_attention_heads
137 | self.inner_group_num = inner_group_num
138 | self.hidden_act = hidden_act
139 | self.intermediate_size = intermediate_size
140 | self.hidden_dropout_prob = hidden_dropout_prob
141 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
142 | self.max_position_embeddings = max_position_embeddings
143 | self.type_vocab_size = type_vocab_size
144 | self.initializer_range = initializer_range
145 | self.layer_norm_eps = layer_norm_eps
146 | self.classifier_dropout_prob = classifier_dropout_prob
147 |
--------------------------------------------------------------------------------
/python/src/transformers/configuration_bart.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ BART configuration """
16 |
17 |
18 | import logging
19 |
20 | from .configuration_utils import PretrainedConfig
21 |
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26 | "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
27 | "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
28 | "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
29 | }
30 |
31 |
32 | class BartConfig(PretrainedConfig):
33 | r"""
34 | Configuration class for Bart. Parameters are renamed from the fairseq implementation
35 | """
36 | model_type = "bart"
37 | pretrained_config_archive_map = BART_PRETRAINED_CONFIG_ARCHIVE_MAP
38 |
39 | def __init__(
40 | self,
41 | activation_dropout=0.0,
42 | activation_function="gelu",
43 | vocab_size=50265,
44 | bos_token_id=0,
45 | pad_token_id=1,
46 | eos_token_ids=[2],
47 | d_model=1024,
48 | encoder_ffn_dim=4096,
49 | encoder_layers=12,
50 | encoder_attention_heads=16,
51 | decoder_ffn_dim=4096,
52 | decoder_layers=12,
53 | decoder_attention_heads=16,
54 | encoder_layerdrop=0.0,
55 | decoder_layerdrop=0.0,
56 | attention_dropout=0.0,
57 | dropout=0.1,
58 | max_position_embeddings=1024,
59 | init_std=0.02,
60 | classifier_dropout=0.0,
61 | output_past=False,
62 | num_labels=3,
63 | is_encoder_decoder=True,
64 | **common_kwargs
65 | ):
66 | r"""
67 | :class:`~transformers.BartConfig` is the configuration class for `BartModel`.
68 | Examples:
69 | config = BartConfig.from_pretrained('bart-large')
70 | model = BartModel(config)
71 | """
72 | super().__init__(
73 | num_labels=num_labels,
74 | output_past=output_past,
75 | pad_token_id=pad_token_id,
76 | bos_token_id=bos_token_id,
77 | eos_token_ids=eos_token_ids,
78 | is_encoder_decoder=is_encoder_decoder,
79 | **common_kwargs,
80 | )
81 | self.vocab_size = vocab_size
82 | self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
83 | self.encoder_ffn_dim = encoder_ffn_dim
84 | self.encoder_layers = self.num_hidden_layers = encoder_layers
85 | self.encoder_attention_heads = encoder_attention_heads
86 | self.encoder_layerdrop = encoder_layerdrop
87 | self.decoder_layerdrop = decoder_layerdrop
88 | self.decoder_ffn_dim = decoder_ffn_dim
89 | self.decoder_layers = decoder_layers
90 | self.decoder_attention_heads = decoder_attention_heads
91 | self.max_position_embeddings = max_position_embeddings
92 | self.init_std = init_std # Normal(0, this parameter)
93 | self.activation_function = activation_function
94 |
95 | # 3 Types of Dropout
96 | self.attention_dropout = attention_dropout
97 | self.activation_dropout = activation_dropout
98 | self.dropout = dropout
99 |
100 | # Classifier stuff
101 | self.classif_dropout = classifier_dropout
102 |
103 | @property
104 | def num_attention_heads(self):
105 | return self.encoder_attention_heads
106 |
107 | @property
108 | def hidden_size(self):
109 | return self.d_model
110 |
--------------------------------------------------------------------------------
/python/src/transformers/configuration_camembert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ CamemBERT configuration """
17 |
18 |
19 | import logging
20 |
21 | from .configuration_roberta import RobertaConfig
22 |
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27 | "camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
28 | "umberto-commoncrawl-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/Musixmatch/umberto-commoncrawl-cased-v1/config.json",
29 | "umberto-wikipedia-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/Musixmatch/umberto-wikipedia-uncased-v1/config.json",
30 | }
31 |
32 |
33 | class CamembertConfig(RobertaConfig):
34 | """
35 | This class overrides :class:`~transformers.RobertaConfig`. Please check the
36 | superclass for the appropriate documentation alongside usage examples.
37 | """
38 |
39 | pretrained_config_archive_map = CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
40 | model_type = "camembert"
41 |
--------------------------------------------------------------------------------
/python/src/transformers/configuration_ctrl.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Salesforce and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Salesforce CTRL configuration """
16 |
17 |
18 | import logging
19 |
20 | from .configuration_utils import PretrainedConfig
21 |
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}
26 |
27 |
28 | class CTRLConfig(PretrainedConfig):
29 | """
30 | This is the configuration class to store the configuration of a :class:`~transformers.CTRLModel`.
31 | It is used to instantiate a CTRL model according to the specified arguments, defining the model
32 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
33 | the `ctrl `__ architecture from Salesforce.
34 |
35 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
36 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
37 | for more information.
38 |
39 | Args:
40 | vocab_size (:obj:`int`, optional, defaults to 246534):
41 | Vocabulary size of the CTRL model. Defines the different tokens that
42 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.CTRLModel`.
43 | n_positions (:obj:`int`, optional, defaults to 256):
44 | The maximum sequence length that this model might ever be used with.
45 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
46 | n_ctx (:obj:`int`, optional, defaults to 256):
47 | Dimensionality of the causal mask (usually same as n_positions).
48 | n_embd (:obj:`int`, optional, defaults to 1280):
49 | Dimensionality of the embeddings and hidden states.
50 | dff (:obj:`int`, optional, defaults to 8192):
51 | Dimensionality of the inner layer of the feed-forward networks (FFN).
52 | n_layer (:obj:`int`, optional, defaults to 48):
53 | Number of hidden layers in the Transformer encoder.
54 | n_head (:obj:`int`, optional, defaults to 16):
55 | Number of attention heads for each attention layer in the Transformer encoder.
56 | resid_pdrop (:obj:`float`, optional, defaults to 0.1):
57 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
58 | embd_pdrop (:obj:`float`, optional, defaults to 0.1):
59 | The dropout ratio for the embeddings.
60 | attn_pdrop (:obj:`float`, optional, defaults to 0.1):
61 | The dropout ratio for the attention.
62 | layer_norm_epsilon (:obj:`float`, optional, defaults to 1e-6):
63 | The epsilon to use in the layer normalization layers.
64 | initializer_range (:obj:`float`, optional, defaults to 0.02):
65 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
66 |
67 | Example::
68 |
69 | from transformers import CTRLModel, CTRLConfig
70 |
71 | # Initializing a CTRL configuration
72 | configuration = CTRLConfig()
73 |
74 | # Initializing a model from the configuration
75 | model = CTRLModel(configuration)
76 |
77 | # Accessing the model configuration
78 | configuration = model.config
79 |
80 | Attributes:
81 | pretrained_config_archive_map (Dict[str, str]):
82 | A dictionary containing all the available pre-trained checkpoints.
83 | """
84 |
85 | pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
86 | model_type = "ctrl"
87 |
88 | def __init__(
89 | self,
90 | vocab_size=246534,
91 | n_positions=256,
92 | n_ctx=256,
93 | n_embd=1280,
94 | dff=8192,
95 | n_layer=48,
96 | n_head=16,
97 | resid_pdrop=0.1,
98 | embd_pdrop=0.1,
99 | attn_pdrop=0.1,
100 | layer_norm_epsilon=1e-6,
101 | initializer_range=0.02,
102 | summary_type="cls_index",
103 | summary_use_proj=True,
104 | summary_activation=None,
105 | summary_proj_to_labels=True,
106 | summary_first_dropout=0.1,
107 | **kwargs
108 | ):
109 | super().__init__(**kwargs)
110 | self.vocab_size = vocab_size
111 | self.n_ctx = n_ctx
112 | self.n_positions = n_positions
113 | self.n_embd = n_embd
114 | self.n_layer = n_layer
115 | self.n_head = n_head
116 | self.dff = dff
117 | self.resid_pdrop = resid_pdrop
118 | self.embd_pdrop = embd_pdrop
119 | self.attn_pdrop = attn_pdrop
120 | self.layer_norm_epsilon = layer_norm_epsilon
121 | self.initializer_range = initializer_range
122 |
123 | self.summary_type = summary_type
124 | self.summary_use_proj = summary_use_proj
125 | self.summary_activation = summary_activation
126 | self.summary_first_dropout = summary_first_dropout
127 | self.summary_proj_to_labels = summary_proj_to_labels
128 |
129 | @property
130 | def max_position_embeddings(self):
131 | return self.n_positions
132 |
133 | @property
134 | def hidden_size(self):
135 | return self.n_embd
136 |
137 | @property
138 | def num_attention_heads(self):
139 | return self.n_head
140 |
141 | @property
142 | def num_hidden_layers(self):
143 | return self.n_layer
144 |
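Any keyword accepted by __init__ can be overridden at construction time, and the trailing properties map CTRL's native names onto the generic accessors. A minimal sketch:

    from transformers import CTRLConfig

    config = CTRLConfig(n_layer=12)              # override a default
    assert config.num_hidden_layers == 12        # property alias for n_layer
    assert config.hidden_size == config.n_embd   # 1280 by default
    assert config.max_position_embeddings == config.n_positions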
--------------------------------------------------------------------------------
/python/src/transformers/configuration_distilbert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ DistilBERT model configuration """
16 |
17 |
18 | import logging
19 |
20 | from .configuration_utils import PretrainedConfig
21 |
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26 | "distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
27 | "distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
28 | "distilbert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-cased-config.json",
29 | "distilbert-base-cased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-cased-distilled-squad-config.json",
30 | "distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
31 | "distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
32 | "distilbert-base-uncased-finetuned-sst-2-english": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json",
33 | }
34 |
35 |
36 | class DistilBertConfig(PretrainedConfig):
37 | r"""
38 | This is the configuration class to store the configuration of a :class:`~transformers.DistilBertModel`.
39 | It is used to instantiate a DistilBERT model according to the specified arguments, defining the model
40 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
41 | the DistilBERT `distilbert-base-uncased `__ architecture.
42 |
43 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
44 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
45 | for more information.
46 |
47 |
48 | Args:
49 | vocab_size (:obj:`int`, optional, defaults to 30522):
50 | Vocabulary size of the DistilBERT model. Defines the different tokens that
51 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.DistilBertModel`.
52 | max_position_embeddings (:obj:`int`, optional, defaults to 512):
53 | The maximum sequence length that this model might ever be used with.
54 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
55 | sinusoidal_pos_embds (:obj:`boolean`, optional, defaults to :obj:`False`):
56 | Whether to use sinusoidal positional embeddings.
57 | n_layers (:obj:`int`, optional, defaults to 6):
58 | Number of hidden layers in the Transformer encoder.
59 | n_heads (:obj:`int`, optional, defaults to 12):
60 | Number of attention heads for each attention layer in the Transformer encoder.
61 | dim (:obj:`int`, optional, defaults to 768):
62 | Dimensionality of the encoder layers and the pooler layer.
63 | hidden_dim (:obj:`int`, optional, defaults to 3072):
64 | The size of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
65 | dropout (:obj:`float`, optional, defaults to 0.1):
66 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
67 | attention_dropout (:obj:`float`, optional, defaults to 0.1):
68 | The dropout ratio for the attention probabilities.
69 | activation (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
70 | The non-linear activation function (function or string) in the encoder and pooler.
71 | If string, "gelu", "relu", "swish" and "gelu_new" are supported.
72 | initializer_range (:obj:`float`, optional, defaults to 0.02):
73 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
74 | qa_dropout (:obj:`float`, optional, defaults to 0.1):
75 | The dropout probability used in the question answering model
76 | :class:`~transformers.DistilBertForQuestionAnswering`.
77 | seq_classif_dropout (:obj:`float`, optional, defaults to 0.2):
78 | The dropout probability used in the sequence classification model
79 | :class:`~transformers.DistilBertForSequenceClassification`.
80 |
81 | Example::
82 |
83 | from transformers import DistilBertModel, DistilBertConfig
84 |
85 | # Initializing a DistilBERT configuration
86 | configuration = DistilBertConfig()
87 |
88 | # Initializing a model from the configuration
89 | model = DistilBertModel(configuration)
90 |
91 | # Accessing the model configuration
92 | configuration = model.config
93 |
94 | Attributes:
95 | pretrained_config_archive_map (Dict[str, str]):
96 | A dictionary containing all the available pre-trained checkpoints.
97 | """
98 | pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
99 | model_type = "distilbert"
100 |
101 | def __init__(
102 | self,
103 | vocab_size=30522,
104 | max_position_embeddings=512,
105 | sinusoidal_pos_embds=False,
106 | n_layers=6,
107 | n_heads=12,
108 | dim=768,
109 | hidden_dim=4 * 768,
110 | dropout=0.1,
111 | attention_dropout=0.1,
112 | activation="gelu",
113 | initializer_range=0.02,
114 | qa_dropout=0.1,
115 | seq_classif_dropout=0.2,
116 | **kwargs
117 | ):
118 | super().__init__(**kwargs)
119 | self.vocab_size = vocab_size
120 | self.max_position_embeddings = max_position_embeddings
121 | self.sinusoidal_pos_embds = sinusoidal_pos_embds
122 | self.n_layers = n_layers
123 | self.n_heads = n_heads
124 | self.dim = dim
125 | self.hidden_dim = hidden_dim
126 | self.dropout = dropout
127 | self.attention_dropout = attention_dropout
128 | self.activation = activation
129 | self.initializer_range = initializer_range
130 | self.qa_dropout = qa_dropout
131 | self.seq_classif_dropout = seq_classif_dropout
132 |
133 | @property
134 | def hidden_size(self):
135 | return self.dim
136 |
137 | @property
138 | def num_attention_heads(self):
139 | return self.n_heads
140 |
141 | @property
142 | def num_hidden_layers(self):
143 | return self.n_layers
144 |
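A minimal sketch of the aliases defined above, using a smaller, purely illustrative configuration:

    from transformers import DistilBertConfig

    config = DistilBertConfig(dim=512, hidden_dim=4 * 512, n_heads=8, n_layers=4)
    assert config.hidden_size == 512
    assert config.num_attention_heads == 8
    assert config.num_hidden_layers == 4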
--------------------------------------------------------------------------------
/python/src/transformers/configuration_mmbt.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # Copyright (c) HuggingFace Inc. team.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ MMBT configuration """
17 |
18 |
19 | import logging
20 |
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | class MMBTConfig(object):
26 | """Configuration class to store the configuration of a `MMBT Model`.
27 |
28 | Args:
29 | config (:obj:`~transformers.PretrainedConfig`):
30 | Config of the underlying Transformer models. Its values are
31 | copied over to use a single config.
32 | num_labels (:obj:`int` or :obj:`None`, optional, defaults to `None`):
33 | Size of final Linear layer for classification.
34 | modal_hidden_size (:obj:`int`, optional, defaults to 2048):
35 | Embedding dimension of the non-text modality encoder.
36 | """
37 |
38 | def __init__(self, config, num_labels=None, modal_hidden_size=2048):
39 | self.__dict__ = config.__dict__
40 | self.modal_hidden_size = modal_hidden_size
41 | if num_labels:
42 | self.num_labels = num_labels
43 |
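MMBTConfig is not a PretrainedConfig subclass; it adopts the wrapped config's attribute dictionary, so the two objects share state. A minimal sketch wrapping a BertConfig (illustrative values only):

    from transformers import BertConfig
    from transformers.configuration_mmbt import MMBTConfig

    bert_config = BertConfig()
    mmbt_config = MMBTConfig(bert_config, num_labels=2, modal_hidden_size=2048)
    # Because __dict__ is shared, transformer fields are visible on the MMBT config ...
    assert mmbt_config.hidden_size == bert_config.hidden_size
    # ... and additions such as num_labels show up on the wrapped config as well.
    assert bert_config.num_labels == 2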
--------------------------------------------------------------------------------
/python/src/transformers/configuration_roberta.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ RoBERTa configuration """
17 |
18 |
19 | import logging
20 |
21 | from .configuration_bert import BertConfig
22 |
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27 | "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
28 | "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
29 | "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
30 | "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
31 | "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
32 | "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
33 | }
34 |
35 |
36 | class RobertaConfig(BertConfig):
37 | r"""
38 | This is the configuration class to store the configuration of a :class:`~transformers.RobertaModel`.
39 | It is used to instantiate a RoBERTa model according to the specified arguments, defining the model
40 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
41 | the BERT `bert-base-uncased `__ architecture.
42 |
43 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
44 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
45 | for more information.
46 |
47 | The :class:`~transformers.RobertaConfig` class directly inherits :class:`~transformers.BertConfig`.
48 | It reuses the same defaults. Please check the parent class for more information.
49 |
50 | Example::
51 |
52 | from transformers import RobertaConfig, RobertaModel
53 |
54 | # Initializing a RoBERTa configuration
55 | configuration = RobertaConfig()
56 |
57 | # Initializing a model from the configuration
58 | model = RobertaModel(configuration)
59 |
60 | # Accessing the model configuration
61 | configuration = model.config
62 |
63 | Attributes:
64 | pretrained_config_archive_map (Dict[str, str]):
65 | A dictionary containing all the available pre-trained checkpoints.
66 | """
67 | pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
68 | model_type = "roberta"
69 |
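Because RobertaConfig adds no fields of its own, it is interchangeable with BertConfig apart from the identifiers above. A minimal sketch:

    from transformers import BertConfig, RobertaConfig

    roberta_config = RobertaConfig()
    # Same defaults as BertConfig; only model_type and the archive map differ.
    assert roberta_config.hidden_size == BertConfig().hidden_size
    assert roberta_config.model_type == "roberta"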
--------------------------------------------------------------------------------
/python/src/transformers/configuration_t5.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2010, The T5 Authors and HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ T5 model configuration """
16 |
17 |
18 | import logging
19 |
20 | from .configuration_utils import PretrainedConfig
21 |
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26 | "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
27 | "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
28 | "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
29 | "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
30 | "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
31 | }
32 |
33 |
34 | class T5Config(PretrainedConfig):
35 | r"""
36 | :class:`~transformers.T5Config` is the configuration class to store the configuration of a
37 | `T5Model`.
38 |
39 |
40 | Arguments:
41 | vocab_size: Vocabulary size of `inputs_ids` in `T5Model`.
42 | n_positions: The maximum sequence length that this model might
43 | ever be used with. Typically set this to something large just in case
44 | (e.g., 512 or 1024 or 2048).
45 | d_model: Size of the encoder layers and the pooler layer.
46 | d_kv: Size of the key, query and value projections per attention head.
47 | d_ff: Size of the intermediate (i.e., feed-forward) layer in each
48 | Transformer block.
49 | num_layers: Number of hidden layers in the Transformer encoder.
50 | num_heads: Number of attention heads for each attention layer in
51 | the Transformer encoder.
52 | relative_attention_num_buckets: The number of buckets to use for each
53 | attention layer when computing relative position biases.
54 | dropout_rate: The dropout ratio applied throughout the model.
55 | layer_norm_epsilon: The epsilon used by the layer normalization layers.
56 | initializer_factor: A factor for initializing all weight matrices
57 | (should be kept to 1.0, used for initialization testing).
58 | is_encoder_decoder: Whether the model is used as an encoder/decoder.
59 | pad_token_id: The id of the padding token.
60 | eos_token_ids: The id(s) of the end-of-sequence token.
61 | """
62 | pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
63 | model_type = "t5"
64 |
65 | def __init__(
66 | self,
67 | vocab_size=32128,
68 | n_positions=512,
69 | d_model=512,
70 | d_kv=64,
71 | d_ff=2048,
72 | num_layers=6,
73 | num_heads=8,
74 | relative_attention_num_buckets=32,
75 | dropout_rate=0.1,
76 | layer_norm_epsilon=1e-6,
77 | initializer_factor=1.0,
78 | is_encoder_decoder=True,
79 | pad_token_id=0,
80 | eos_token_ids=[1],
81 | **kwargs
82 | ):
83 | super().__init__(
84 | is_encoder_decoder=is_encoder_decoder, **kwargs,
85 | )
86 | self.vocab_size = vocab_size
87 | self.n_positions = n_positions
88 | self.d_model = d_model
89 | self.d_kv = d_kv
90 | self.d_ff = d_ff
91 | self.num_layers = num_layers
92 | self.num_heads = num_heads
93 | self.relative_attention_num_buckets = relative_attention_num_buckets
94 | self.dropout_rate = dropout_rate
95 | self.layer_norm_epsilon = layer_norm_epsilon
96 | self.initializer_factor = initializer_factor
97 |
98 | @property
99 | def max_position_embeddings(self):
100 | return self.n_positions
101 |
102 | @property
103 | def hidden_size(self):
104 | return self.d_model
105 |
106 | @property
107 | def num_attention_heads(self):
108 | return self.num_heads
109 |
110 | @property
111 | def num_hidden_layers(self):
112 | return self.num_layers
113 |
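T5 keeps its original parameter names (d_model, num_layers, and so on); the trailing properties expose them under the generic accessors. A minimal sketch:

    from transformers import T5Config

    config = T5Config()
    assert config.hidden_size == config.d_model            # 512 by default
    assert config.num_hidden_layers == config.num_layers   # 6 by default
    assert config.max_position_embeddings == config.n_positions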
--------------------------------------------------------------------------------
/python/src/transformers/configuration_xlm_roberta.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ XLM-RoBERTa configuration """
17 |
18 |
19 | import logging
20 |
21 | from .configuration_roberta import RobertaConfig
22 |
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27 | "xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
28 | "xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
29 | "xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
30 | "xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
31 | "xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
32 | "xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
33 | }
34 |
35 |
36 | class XLMRobertaConfig(RobertaConfig):
37 | """
38 | This class overrides :class:`~transformers.RobertaConfig`. Please check the
39 | superclass for the appropriate documentation alongside usage examples.
40 | """
41 |
42 | pretrained_config_archive_map = XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
43 | model_type = "xlm-roberta"
44 |
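As with CamembertConfig, only the identifiers change. A minimal sketch:

    from transformers import XLMRobertaConfig

    config = XLMRobertaConfig()
    assert config.model_type == "xlm-roberta"
    assert "xlm-roberta-base" in config.pretrained_config_archive_map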
--------------------------------------------------------------------------------
/python/src/transformers/convert_albert_original_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert ALBERT checkpoint."""
16 |
17 |
18 | import argparse
19 | import logging
20 |
21 | import torch
22 |
23 | from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert
24 |
25 |
26 | logging.basicConfig(level=logging.INFO)
27 |
28 |
29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
30 | # Initialise PyTorch model
31 | config = AlbertConfig.from_json_file(albert_config_file)
32 | print("Building PyTorch model from configuration: {}".format(str(config)))
33 | model = AlbertForMaskedLM(config)
34 |
35 | # Load weights from tf checkpoint
36 | load_tf_weights_in_albert(model, config, tf_checkpoint_path)
37 |
38 | # Save pytorch-model
39 | print("Save PyTorch model to {}".format(pytorch_dump_path))
40 | torch.save(model.state_dict(), pytorch_dump_path)
41 |
42 |
43 | if __name__ == "__main__":
44 | parser = argparse.ArgumentParser()
45 | # Required parameters
46 | parser.add_argument(
47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
48 | )
49 | parser.add_argument(
50 | "--albert_config_file",
51 | default=None,
52 | type=str,
53 | required=True,
54 | help="The config json file corresponding to the pre-trained ALBERT model. \n"
55 | "This specifies the model architecture.",
56 | )
57 | parser.add_argument(
58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
59 | )
60 | args = parser.parse_args()
61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
62 |
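Besides the CLI, the converter can be invoked programmatically. The sketch below uses placeholder paths, and the same pattern applies to the BERT, GPT-2, OpenAI GPT, T5, Transformer-XL and XLNet converters that follow:

    from transformers.convert_albert_original_tf_checkpoint_to_pytorch import (
        convert_tf_checkpoint_to_pytorch,
    )

    # All three paths are hypothetical placeholders.
    convert_tf_checkpoint_to_pytorch(
        tf_checkpoint_path="/tmp/albert/model.ckpt-best",
        albert_config_file="/tmp/albert/albert_config.json",
        pytorch_dump_path="/tmp/albert/pytorch_model.bin",
    )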
--------------------------------------------------------------------------------
/python/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BART checkpoint."""
16 |
17 |
18 | import argparse
19 | import logging
20 | from pathlib import Path
21 |
22 | import fairseq
23 | import torch
24 | from packaging import version
25 |
26 | from transformers import (
27 | BartConfig,
28 | BartForConditionalGeneration,
29 | BartForSequenceClassification,
30 | BartModel,
31 | BartTokenizer,
32 | )
33 |
34 |
35 | FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn"]
36 |
37 | if version.parse(fairseq.__version__) < version.parse("0.9.0"):
38 | raise Exception("requires fairseq >= 0.9.0")
39 |
40 |
41 | logging.basicConfig(level=logging.INFO)
42 | logger = logging.getLogger(__name__)
43 |
44 | SAMPLE_TEXT = " Hello world! cécé herlolip"
45 |
46 | rename_keys = [
47 | ("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
48 | ("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
49 | ("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
50 | ("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
51 | ]
52 | IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version", "_float_tensor"]
53 |
54 |
55 | def rename_key(dct, old, new):
56 | val = dct.pop(old)
57 | dct[new] = val
58 |
59 |
60 | def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
61 | """
62 | Copy/paste/tweak model's weights to our BERT structure.
63 | """
64 | bart = torch.hub.load("pytorch/fairseq", checkpoint_path)
65 | bart.eval() # disable dropout
66 | bart.model.upgrade_state_dict(bart.model.state_dict())
67 | hf_model_name = checkpoint_path.replace(".", "-")
68 | config = BartConfig.from_pretrained(hf_model_name)
69 | tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
70 | tokens2 = BartTokenizer.from_pretrained(hf_model_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
71 | assert torch.eq(tokens, tokens2).all()
72 |
73 | if checkpoint_path in ["bart.large", "bart.large.cnn"]:
74 | state_dict = bart.model.state_dict()
75 | for k in IGNORE_KEYS:
76 | state_dict.pop(k, None)
77 | state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
78 | model = BartModel(config)
79 | their_output = bart.extract_features(tokens)
80 | else: # MNLI Case
81 | state_dict = bart.state_dict()
82 | for k in IGNORE_KEYS:
83 | state_dict.pop(k, None)
84 | state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
85 | for src, dest in rename_keys:
86 | rename_key(state_dict, src, dest)
87 | model = BartForSequenceClassification(config)
88 | their_output = bart.predict("mnli", tokens, return_logits=True)
89 |
90 | # Load state dict
91 | model.load_state_dict(state_dict)
92 | model.eval()
93 | # Check results
94 |
95 | if checkpoint_path == "bart.large.cnn":
96 | model = BartForConditionalGeneration(config, base_model=model)
97 | assert "lm_head.weight" in model.state_dict()
98 | assert model.lm_head.out_features == config.vocab_size
99 | model.eval()
100 | our_outputs = model.model(tokens)[0]
101 | else:
102 | our_outputs = model(tokens)[0]
103 | assert their_output.shape == our_outputs.shape
104 | assert (their_output == our_outputs).all().item()
105 | Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
106 | model.save_pretrained(pytorch_dump_folder_path)
107 |
108 |
109 | if __name__ == "__main__":
110 | parser = argparse.ArgumentParser()
111 | # Required parameters
112 | parser.add_argument("fairseq_path", choices=FAIRSEQ_MODELS, type=str, help="")
113 |
114 | parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
115 | args = parser.parse_args()
116 | convert_bart_checkpoint(
117 | args.fairseq_path, args.pytorch_dump_folder_path,
118 | )
119 |
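The state-dict surgery above comes down to popping keys and reinserting them under new names. A toy illustration of the rename_key helper, with its two-line body copied from above so the snippet runs without fairseq:

    def rename_key(dct, old, new):
        dct[new] = dct.pop(old)

    state = {"model.classification_heads.mnli.dense.weight": 0.0}
    rename_key(state, "model.classification_heads.mnli.dense.weight", "classification_head.dense.weight")
    assert list(state) == ["classification_head.dense.weight"]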
--------------------------------------------------------------------------------
/python/src/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 |
18 | import argparse
19 | import logging
20 |
21 | import torch
22 |
23 | from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
24 |
25 |
26 | logging.basicConfig(level=logging.INFO)
27 |
28 |
29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
30 | # Initialise PyTorch model
31 | config = BertConfig.from_json_file(bert_config_file)
32 | print("Building PyTorch model from configuration: {}".format(str(config)))
33 | model = BertForPreTraining(config)
34 |
35 | # Load weights from tf checkpoint
36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path)
37 |
38 | # Save pytorch-model
39 | print("Save PyTorch model to {}".format(pytorch_dump_path))
40 | torch.save(model.state_dict(), pytorch_dump_path)
41 |
42 |
43 | if __name__ == "__main__":
44 | parser = argparse.ArgumentParser()
45 | # Required parameters
46 | parser.add_argument(
47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
48 | )
49 | parser.add_argument(
50 | "--bert_config_file",
51 | default=None,
52 | type=str,
53 | required=True,
54 | help="The config json file corresponding to the pre-trained BERT model. \n"
55 | "This specifies the model architecture.",
56 | )
57 | parser.add_argument(
58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
59 | )
60 | args = parser.parse_args()
61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
62 |
--------------------------------------------------------------------------------
/python/src/transformers/convert_bert_pytorch_checkpoint_to_original_tf.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
17 |
18 | import argparse
19 | import os
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 | import torch
24 |
25 | from transformers import BertModel
26 |
27 |
28 | def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
29 |
30 | """
31 | :param model:BertModel Pytorch model instance to be converted
32 | :param ckpt_dir: Tensorflow model directory
33 | :param model_name: model name
34 | :return:
35 |
36 | Currently supported HF models:
37 | Y BertModel
38 | N BertForMaskedLM
39 | N BertForPreTraining
40 | N BertForMultipleChoice
41 | N BertForNextSentencePrediction
42 | N BertForSequenceClassification
43 | N BertForQuestionAnswering
44 | """
45 |
46 | tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
47 |
48 | var_map = (
49 | ("layer.", "layer_"),
50 | ("word_embeddings.weight", "word_embeddings"),
51 | ("position_embeddings.weight", "position_embeddings"),
52 | ("token_type_embeddings.weight", "token_type_embeddings"),
53 | (".", "/"),
54 | ("LayerNorm/weight", "LayerNorm/gamma"),
55 | ("LayerNorm/bias", "LayerNorm/beta"),
56 | ("weight", "kernel"),
57 | )
58 |
59 | if not os.path.isdir(ckpt_dir):
60 | os.makedirs(ckpt_dir)
61 |
62 | state_dict = model.state_dict()
63 |
64 | def to_tf_var_name(name: str):
65 | for patt, repl in iter(var_map):
66 | name = name.replace(patt, repl)
67 | return "bert/{}".format(name)
68 |
69 | def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
70 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
71 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
72 | session.run(tf.variables_initializer([tf_var]))
73 | session.run(tf_var)
74 | return tf_var
75 |
76 | tf.reset_default_graph()
77 | with tf.Session() as session:
78 | for var_name in state_dict:
79 | tf_name = to_tf_var_name(var_name)
80 | torch_tensor = state_dict[var_name].numpy()
81 | if any([x in var_name for x in tensors_to_transpose]):
82 | torch_tensor = torch_tensor.T
83 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
84 | tf.keras.backend.set_value(tf_var, torch_tensor)
85 | tf_weight = session.run(tf_var)
86 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))
87 |
88 | saver = tf.train.Saver(tf.trainable_variables())
89 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
90 |
91 |
92 | def main(raw_args=None):
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument("--model_name", type=str, required=True, help="model name e.g. bert-base-uncased")
95 | parser.add_argument(
96 | "--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
97 | )
98 | parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/.bin")
99 | parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
100 | args = parser.parse_args(raw_args)
101 |
102 | model = BertModel.from_pretrained(
103 | pretrained_model_name_or_path=args.model_name,
104 | state_dict=torch.load(args.pytorch_model_path),
105 | cache_dir=args.cache_dir,
106 | )
107 |
108 | convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
109 |
110 |
111 | if __name__ == "__main__":
112 | main()
113 |
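The heart of the conversion is the name mapping in var_map; tracing to_tf_var_name (reproduced below) on a sample parameter name shows the substitutions applied in order:

    var_map = (
        ("layer.", "layer_"),
        ("word_embeddings.weight", "word_embeddings"),
        ("position_embeddings.weight", "position_embeddings"),
        ("token_type_embeddings.weight", "token_type_embeddings"),
        (".", "/"),
        ("LayerNorm/weight", "LayerNorm/gamma"),
        ("LayerNorm/bias", "LayerNorm/beta"),
        ("weight", "kernel"),
    )

    def to_tf_var_name(name):
        for patt, repl in var_map:
            name = name.replace(patt, repl)
        return "bert/{}".format(name)

    print(to_tf_var_name("encoder.layer.0.attention.self.query.weight"))
    # -> bert/encoder/layer_0/attention/self/query/kernel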
--------------------------------------------------------------------------------
/python/src/transformers/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 |
6 | from transformers.file_utils import WEIGHTS_NAME
7 |
8 |
9 | DIALOGPT_MODELS = ["small", "medium", "large"]
10 |
11 | OLD_KEY = "lm_head.decoder.weight"
12 | NEW_KEY = "lm_head.weight"
13 |
14 |
15 | def convert_dialogpt_checkpoint(checkpoint_path: str, pytorch_dump_folder_path: str):
16 | d = torch.load(checkpoint_path)
17 | d[NEW_KEY] = d.pop(OLD_KEY)
18 | os.makedirs(pytorch_dump_folder_path, exist_ok=True)
19 | torch.save(d, os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME))
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--dialogpt_path", default=".", type=str)
25 | args = parser.parse_args()
26 | for MODEL in DIALOGPT_MODELS:
27 | checkpoint_path = os.path.join(args.dialogpt_path, f"{MODEL}_ft.pkl")
28 | pytorch_dump_folder_path = f"./DialoGPT-{MODEL}"
29 | convert_dialogpt_checkpoint(
30 | checkpoint_path, pytorch_dump_folder_path,
31 | )
32 |
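The DialoGPT conversion only renames a single key in the raw checkpoint dictionary; on a toy dict the operation looks like this:

    OLD_KEY, NEW_KEY = "lm_head.decoder.weight", "lm_head.weight"
    d = {OLD_KEY: "tensor", "transformer.wte.weight": "tensor"}
    d[NEW_KEY] = d.pop(OLD_KEY)
    assert NEW_KEY in d and OLD_KEY not in d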
--------------------------------------------------------------------------------
/python/src/transformers/convert_gpt2_original_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 |
18 | import argparse
19 | import logging
20 |
21 | import torch
22 |
23 | from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
24 |
25 |
26 | logging.basicConfig(level=logging.INFO)
27 |
28 |
29 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
30 | # Construct model
31 | if gpt2_config_file == "":
32 | config = GPT2Config()
33 | else:
34 | config = GPT2Config.from_json_file(gpt2_config_file)
35 | model = GPT2Model(config)
36 |
37 | # Load weights from numpy
38 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
39 |
40 | # Save pytorch-model
41 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
42 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
43 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
44 | torch.save(model.state_dict(), pytorch_weights_dump_path)
45 | print("Save configuration file to {}".format(pytorch_config_dump_path))
46 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
47 | f.write(config.to_json_string())
48 |
49 |
50 | if __name__ == "__main__":
51 | parser = argparse.ArgumentParser()
52 | # Required parameters
53 | parser.add_argument(
54 | "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
55 | )
56 | parser.add_argument(
57 | "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
58 | )
59 | parser.add_argument(
60 | "--gpt2_config_file",
61 | default="",
62 | type=str,
63 | help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
64 | "This specifies the model architecture.",
65 | )
66 | args = parser.parse_args()
67 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path)
68 |
--------------------------------------------------------------------------------
/python/src/transformers/convert_openai_original_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 |
18 | import argparse
19 | import logging
20 |
21 | import torch
22 |
23 | from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
24 |
25 |
26 | logging.basicConfig(level=logging.INFO)
27 |
28 |
29 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
30 | # Construct model
31 | if openai_config_file == "":
32 | config = OpenAIGPTConfig()
33 | else:
34 | config = OpenAIGPTConfig.from_json_file(openai_config_file)
35 | model = OpenAIGPTModel(config)
36 |
37 | # Load weights from numpy
38 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)
39 |
40 | # Save pytorch-model
41 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
42 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
43 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
44 | torch.save(model.state_dict(), pytorch_weights_dump_path)
45 | print("Save configuration file to {}".format(pytorch_config_dump_path))
46 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
47 | f.write(config.to_json_string())
48 |
49 |
50 | if __name__ == "__main__":
51 | parser = argparse.ArgumentParser()
52 | # Required parameters
53 | parser.add_argument(
54 | "--openai_checkpoint_folder_path",
55 | default=None,
56 | type=str,
57 | required=True,
58 | help="Path to the TensorFlow checkpoint path.",
59 | )
60 | parser.add_argument(
61 | "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
62 | )
63 | parser.add_argument(
64 | "--openai_config_file",
65 | default="",
66 | type=str,
67 | help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
68 | "This specifies the model architecture.",
69 | )
70 | args = parser.parse_args()
71 | convert_openai_checkpoint_to_pytorch(
72 | args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path
73 | )
74 |
--------------------------------------------------------------------------------
/python/src/transformers/convert_t5_original_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The T5 authors and HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert T5 checkpoint."""
16 |
17 |
18 | import argparse
19 | import logging
20 |
21 | import torch
22 |
23 | from transformers import T5Config, T5Model, load_tf_weights_in_t5
24 |
25 |
26 | logging.basicConfig(level=logging.INFO)
27 |
28 |
29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
30 | # Initialise PyTorch model
31 | config = T5Config.from_json_file(config_file)
32 | print("Building PyTorch model from configuration: {}".format(str(config)))
33 | model = T5Model(config)
34 |
35 | # Load weights from tf checkpoint
36 | load_tf_weights_in_t5(model, config, tf_checkpoint_path)
37 |
38 | # Save pytorch-model
39 | print("Save PyTorch model to {}".format(pytorch_dump_path))
40 | torch.save(model.state_dict(), pytorch_dump_path)
41 |
42 |
43 | if __name__ == "__main__":
44 | parser = argparse.ArgumentParser()
45 | # Required parameters
46 | parser.add_argument(
47 | "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
48 | )
49 | parser.add_argument(
50 | "--config_file",
51 | default=None,
52 | type=str,
53 | required=True,
54 | help="The config json file corresponding to the pre-trained T5 model. \n"
55 | "This specifies the model architecture.",
56 | )
57 | parser.add_argument(
58 | "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
59 | )
60 | args = parser.parse_args()
61 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
62 |
--------------------------------------------------------------------------------
/python/src/transformers/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert Transformer XL checkpoint and datasets."""
16 |
17 |
18 | import argparse
19 | import logging
20 | import os
21 | import pickle
22 | import sys
23 |
24 | import torch
25 |
26 | import transformers.tokenization_transfo_xl as data_utils
27 | from transformers import (
28 | CONFIG_NAME,
29 | WEIGHTS_NAME,
30 | TransfoXLConfig,
31 | TransfoXLLMHeadModel,
32 | load_tf_weights_in_transfo_xl,
33 | )
34 | from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
35 |
36 |
37 | logging.basicConfig(level=logging.INFO)
38 |
39 | # We do this to be able to load python 2 datasets pickles
40 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
41 | data_utils.Vocab = data_utils.TransfoXLTokenizer
42 | data_utils.Corpus = data_utils.TransfoXLCorpus
43 | sys.modules["data_utils"] = data_utils
44 | sys.modules["vocabulary"] = data_utils
45 |
46 |
47 | def convert_transfo_xl_checkpoint_to_pytorch(
48 | tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file
49 | ):
50 | if transfo_xl_dataset_file:
51 | # Convert a pre-processed corpus (see original TensorFlow repo)
52 | with open(transfo_xl_dataset_file, "rb") as fp:
53 | corpus = pickle.load(fp, encoding="latin1")
54 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
55 | pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["pretrained_vocab_file"]
56 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
57 | corpus_vocab_dict = corpus.vocab.__dict__
58 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
59 |
60 | corpus_dict_no_vocab = corpus.__dict__
61 | corpus_dict_no_vocab.pop("vocab", None)
62 | pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME
63 | print("Save dataset to {}".format(pytorch_dataset_dump_path))
64 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
65 |
66 | if tf_checkpoint_path:
67 | # Convert a pre-trained TensorFlow model
68 | config_path = os.path.abspath(transfo_xl_config_file)
69 | tf_path = os.path.abspath(tf_checkpoint_path)
70 |
71 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
72 | # Initialise PyTorch model
73 | if transfo_xl_config_file == "":
74 | config = TransfoXLConfig()
75 | else:
76 | config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
77 | print("Building PyTorch model from configuration: {}".format(str(config)))
78 | model = TransfoXLLMHeadModel(config)
79 |
80 | model = load_tf_weights_in_transfo_xl(model, config, tf_path)
81 | # Save pytorch-model
82 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
83 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
84 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
85 | torch.save(model.state_dict(), pytorch_weights_dump_path)
86 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
87 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
88 | f.write(config.to_json_string())
89 |
90 |
91 | if __name__ == "__main__":
92 | parser = argparse.ArgumentParser()
93 | parser.add_argument(
94 | "--pytorch_dump_folder_path",
95 | default=None,
96 | type=str,
97 | required=True,
98 | help="Path to the folder to store the PyTorch model or dataset/vocab.",
99 | )
100 | parser.add_argument(
101 | "--tf_checkpoint_path",
102 | default="",
103 | type=str,
104 | help="An optional path to a TensorFlow checkpoint path to be converted.",
105 | )
106 | parser.add_argument(
107 | "--transfo_xl_config_file",
108 | default="",
109 | type=str,
110 | help="An optional config json file corresponding to the pre-trained BERT model. \n"
111 | "This specifies the model architecture.",
112 | )
113 | parser.add_argument(
114 | "--transfo_xl_dataset_file",
115 | default="",
116 | type=str,
117 | help="An optional dataset file to be converted in a vocabulary.",
118 | )
119 | args = parser.parse_args()
120 | convert_transfo_xl_checkpoint_to_pytorch(
121 | args.tf_checkpoint_path,
122 | args.transfo_xl_config_file,
123 | args.pytorch_dump_folder_path,
124 | args.transfo_xl_dataset_file,
125 | )
126 |
--------------------------------------------------------------------------------
/python/src/transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 |
17 |
18 | import argparse
19 | import json
20 | import logging
21 |
22 | import numpy
23 | import torch
24 |
25 | from transformers import CONFIG_NAME, WEIGHTS_NAME
26 | from transformers.tokenization_xlm import VOCAB_FILES_NAMES
27 |
28 |
29 | logging.basicConfig(level=logging.INFO)
30 |
31 |
32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
33 | # Load checkpoint
34 | chkpt = torch.load(xlm_checkpoint_path, map_location="cpu")
35 |
36 | state_dict = chkpt["model"]
37 |
38 | # We have the base model one level deeper than the original XLM repository
39 | two_levels_state_dict = {}
40 | for k, v in state_dict.items():
41 | if "pred_layer" in k:
42 | two_levels_state_dict[k] = v
43 | else:
44 | two_levels_state_dict["transformer." + k] = v
45 |
46 | config = chkpt["params"]
47 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))
48 |
49 | vocab = chkpt["dico_word2id"]
50 | vocab = dict((s + "</w>" if s.find("@@") == -1 and i > 13 else s.replace("@@", ""), i) for s, i in vocab.items())
51 |
52 | # Save pytorch-model
53 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
54 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
55 | pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["vocab_file"]
56 |
57 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
58 | torch.save(two_levels_state_dict, pytorch_weights_dump_path)
59 |
60 | print("Save configuration file to {}".format(pytorch_config_dump_path))
61 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
62 | f.write(json.dumps(config, indent=2) + "\n")
63 |
64 | print("Save vocab file to {}".format(pytorch_config_dump_path))
65 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f:
66 | f.write(json.dumps(vocab, indent=2) + "\n")
67 |
68 |
69 | if __name__ == "__main__":
70 | parser = argparse.ArgumentParser()
71 | # Required parameters
72 | parser.add_argument(
73 | "--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
74 | )
75 | parser.add_argument(
76 | "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
77 | )
78 | args = parser.parse_args()
79 | convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
80 |
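The remapping above keeps the prediction-head keys untouched and prefixes everything else with "transformer.". The same rule on a toy state dict:

    state_dict = {"pred_layer.proj.weight": 1, "embeddings.weight": 2}
    two_levels = {(k if "pred_layer" in k else "transformer." + k): v for k, v in state_dict.items()}
    assert set(two_levels) == {"pred_layer.proj.weight", "transformer.embeddings.weight"}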
--------------------------------------------------------------------------------
/python/src/transformers/convert_xlnet_original_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 |
17 |
18 | import argparse
19 | import logging
20 | import os
21 |
22 | import torch
23 |
24 | from transformers import (
25 | CONFIG_NAME,
26 | WEIGHTS_NAME,
27 | XLNetConfig,
28 | XLNetForQuestionAnswering,
29 | XLNetForSequenceClassification,
30 | XLNetLMHeadModel,
31 | load_tf_weights_in_xlnet,
32 | )
33 |
34 |
35 | GLUE_TASKS_NUM_LABELS = {
36 | "cola": 2,
37 | "mnli": 3,
38 | "mrpc": 2,
39 | "sst-2": 2,
40 | "sts-b": 1,
41 | "qqp": 2,
42 | "qnli": 2,
43 | "rte": 2,
44 | "wnli": 2,
45 | }
46 |
47 |
48 | logging.basicConfig(level=logging.INFO)
49 |
50 |
51 | def convert_xlnet_checkpoint_to_pytorch(
52 |     tf_checkpoint_path, xlnet_config_file, pytorch_dump_folder_path, finetuning_task=None
53 | ):
54 | # Initialise PyTorch model
55 |     config = XLNetConfig.from_json_file(xlnet_config_file)
56 |
57 | finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
58 | if finetuning_task in GLUE_TASKS_NUM_LABELS:
59 | print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
60 | config.finetuning_task = finetuning_task
61 | config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
62 | model = XLNetForSequenceClassification(config)
63 | elif "squad" in finetuning_task:
64 | config.finetuning_task = finetuning_task
65 | model = XLNetForQuestionAnswering(config)
66 | else:
67 | model = XLNetLMHeadModel(config)
68 |
69 | # Load weights from tf checkpoint
70 | load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)
71 |
72 | # Save pytorch-model
73 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
74 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
75 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
76 | torch.save(model.state_dict(), pytorch_weights_dump_path)
77 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
78 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
79 | f.write(config.to_json_string())
80 |
81 |
82 | if __name__ == "__main__":
83 | parser = argparse.ArgumentParser()
84 | # Required parameters
85 | parser.add_argument(
86 |         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint."
87 | )
88 | parser.add_argument(
89 | "--xlnet_config_file",
90 | default=None,
91 | type=str,
92 | required=True,
93 | help="The config json file corresponding to the pre-trained XLNet model. \n"
94 | "This specifies the model architecture.",
95 | )
96 | parser.add_argument(
97 | "--pytorch_dump_folder_path",
98 | default=None,
99 | type=str,
100 | required=True,
101 | help="Path to the folder to store the PyTorch model or dataset/vocab.",
102 | )
103 | parser.add_argument(
104 | "--finetuning_task",
105 | default=None,
106 | type=str,
107 |         help="Name of a task on which the XLNet TensorFlow model was fine-tuned.",
108 | )
109 | args = parser.parse_args()
110 | print(args)
111 |
112 | convert_xlnet_checkpoint_to_pytorch(
113 | args.tf_checkpoint_path, args.xlnet_config_file, args.pytorch_dump_folder_path, args.finetuning_task
114 | )
115 |
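116 | # Example invocation (a sketch; the paths and the task name are illustrative):
117 | #
118 | #   python convert_xlnet_original_tf_checkpoint_to_pytorch.py \
119 | #       --tf_checkpoint_path ./xlnet_cased_L-24_H-1024_A-16/xlnet_model.ckpt \
120 | #       --xlnet_config_file ./xlnet_cased_L-24_H-1024_A-16/xlnet_config.json \
121 | #       --pytorch_dump_folder_path ./xlnet_pytorch_dump \
122 | #       --finetuning_task sts-b
123 | #
124 | # A GLUE task name selects XLNetForSequenceClassification with the matching number of labels,
125 | # a task containing "squad" selects XLNetForQuestionAnswering, and anything else falls back to
126 | # XLNetLMHeadModel.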
--------------------------------------------------------------------------------
/python/src/transformers/data/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while still preserving other warnings, so don't check this module at all.
4 |
5 | from .metrics import is_sklearn_available
6 | from .processors import (
7 | DataProcessor,
8 | InputExample,
9 | InputFeatures,
10 | SingleSentenceClassificationProcessor,
11 | SquadExample,
12 | SquadFeatures,
13 | SquadV1Processor,
14 | SquadV2Processor,
15 | glue_convert_examples_to_features,
16 | glue_output_modes,
17 | glue_processors,
18 | glue_tasks_num_labels,
19 | squad_convert_examples_to_features,
20 | xnli_output_modes,
21 | xnli_processors,
22 | xnli_tasks_num_labels,
23 | UMLMProcessor,
24 | umlm_convert_example_to_features,
25 | UOPProcessor,
26 | uop_convert_example_to_features,
27 | FriendsQAProcessor,
28 | friendsqa_convert_example_to_features,
29 | )
30 |
31 |
32 | if is_sklearn_available():
33 | from .metrics import glue_compute_metrics, xnli_compute_metrics
34 |
--------------------------------------------------------------------------------
/python/src/transformers/data/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | try:
18 | from scipy.stats import pearsonr, spearmanr
19 | from sklearn.metrics import matthews_corrcoef, f1_score
20 |
21 | _has_sklearn = True
22 | except (AttributeError, ImportError):
23 | _has_sklearn = False
24 |
25 |
26 | def is_sklearn_available():
27 | return _has_sklearn
28 |
29 |
30 | if _has_sklearn:
31 |
32 | def simple_accuracy(preds, labels):
33 | return (preds == labels).mean()
34 |
35 | def acc_and_f1(preds, labels):
36 | acc = simple_accuracy(preds, labels)
37 | f1 = f1_score(y_true=labels, y_pred=preds)
38 | return {
39 | "acc": acc,
40 | "f1": f1,
41 | "acc_and_f1": (acc + f1) / 2,
42 | }
43 |
44 | def pearson_and_spearman(preds, labels):
45 | pearson_corr = pearsonr(preds, labels)[0]
46 | spearman_corr = spearmanr(preds, labels)[0]
47 | return {
48 | "pearson": pearson_corr,
49 | "spearmanr": spearman_corr,
50 | "corr": (pearson_corr + spearman_corr) / 2,
51 | }
52 |
53 | def glue_compute_metrics(task_name, preds, labels):
54 | assert len(preds) == len(labels)
55 | if task_name == "cola":
56 | return {"mcc": matthews_corrcoef(labels, preds)}
57 | elif task_name == "sst-2":
58 | return {"acc": simple_accuracy(preds, labels)}
59 | elif task_name == "mrpc":
60 | return acc_and_f1(preds, labels)
61 | elif task_name == "sts-b":
62 | return pearson_and_spearman(preds, labels)
63 | elif task_name == "qqp":
64 | return acc_and_f1(preds, labels)
65 | elif task_name == "mnli":
66 | return {"acc": simple_accuracy(preds, labels)}
67 | elif task_name == "mnli-mm":
68 | return {"acc": simple_accuracy(preds, labels)}
69 | elif task_name == "qnli":
70 | return {"acc": simple_accuracy(preds, labels)}
71 | elif task_name == "rte":
72 | return {"acc": simple_accuracy(preds, labels)}
73 | elif task_name == "wnli":
74 | return {"acc": simple_accuracy(preds, labels)}
75 | elif task_name == "hans":
76 | return {"acc": simple_accuracy(preds, labels)}
77 | elif task_name == "umlm":
78 | return {"acc": simple_accuracy(preds, labels)}
79 | elif task_name == "uop":
80 | return {"acc": simple_accuracy(preds, labels)}
81 | else:
82 | raise KeyError(task_name)
83 |
84 | def xnli_compute_metrics(task_name, preds, labels):
85 | assert len(preds) == len(labels)
86 | if task_name == "xnli":
87 | return {"acc": simple_accuracy(preds, labels)}
88 | else:
89 | raise KeyError(task_name)
90 |
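91 | # Minimal usage sketch (assumes scipy and scikit-learn are installed so that the functions above
92 | # are defined):
93 | #
94 | #   import numpy as np
95 | #   preds = np.array([1, 0, 1, 1])
96 | #   labels = np.array([1, 0, 0, 1])
97 | #   glue_compute_metrics("mrpc", preds, labels)  # {'acc': 0.75, 'f1': 0.8, 'acc_and_f1': 0.775}
98 | #   glue_compute_metrics("uop", preds, labels)   # {'acc': 0.75}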
--------------------------------------------------------------------------------
/python/src/transformers/data/processors/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while still preserving other warnings, so don't check this module at all.
4 |
5 | from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels
6 | from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
7 | from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor
8 | from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
9 | from .umlm import UMLMProcessor, umlm_convert_example_to_features
10 | from .uop import UOPExample, UOPFeatures, UOPProcessor, uop_convert_example_to_features
11 | from .friendsqa import FriendsQAProcessor, friendsqa_convert_example_to_features
12 |
--------------------------------------------------------------------------------
/python/src/transformers/data/processors/umlm.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from .utils import DataProcessor, InputFeatures, InputExample
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class UMLMProcessor(DataProcessor):
10 |
11 | def __init__(self, vocab_list):
12 | self.vocab_list = vocab_list
13 |
14 | def get_example_from_tensor_dict(self, tensor_dict):
15 | raise NotImplementedError()
16 |
17 | def get_labels(self):
18 | return self.vocab_list
19 |
20 | def get_train_examples(self, data_dir):
21 | with open(os.path.join(data_dir, "umlm_train.json"), "r", encoding="utf-8") as reader:
22 | input_data = json.load(reader)
23 | return self._create_examples(input_data, "train")
24 |
25 | def get_dev_examples(self, data_dir):
26 | with open(os.path.join(data_dir, "umlm_dev.json"), "r", encoding="utf-8") as reader:
27 | input_data = json.load(reader)
28 | return self._create_examples(input_data, "dev")
29 |
30 | def _create_examples(self, lines, set_type):
31 | examples = []
32 | for (i, line) in enumerate(lines):
33 | guid = "%s-%s" % (set_type, i)
34 | text_a = line[0]
35 | label = line[1]
36 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
37 | return examples
38 |
39 |
40 | def umlm_convert_example_to_features(examples, tokenizer, max_length=512, pad_token=0, pad_token_segment_id=0):
41 | processor = UMLMProcessor(tokenizer.vocab)
42 | label_list = processor.get_labels()
43 | label_map = {label: i for i, label in enumerate(label_list)}
44 | features = []
45 | for (ex_index, example) in enumerate(examples):
46 | len_examples = len(examples)
47 | if ex_index % 10000 == 0:
48 | logger.info("Writing example %d/%d" % (ex_index, len_examples))
49 | inputs = tokenizer.encode_plus(
50 | example.text_a, None, add_special_tokens=True, max_length=max_length, return_token_type_ids=True,
51 | )
52 | input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
53 | attention_mask = [1] * len(input_ids)
54 | padding_length = max_length - len(input_ids)
55 | input_ids = input_ids + ([pad_token] * padding_length)
56 | attention_mask = attention_mask + ([0] * padding_length)
57 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
58 | assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
59 | assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
60 | len(attention_mask), max_length
61 | )
62 | assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
63 | len(token_type_ids), max_length
64 | )
65 | label = label_map[example.label]
66 | if ex_index < 5:
67 | logger.info("*** Example ***")
68 | logger.info("guid: %s" % (example.guid))
69 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
70 | logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
71 | logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
72 | logger.info("label: %s (id = %d)" % (example.label, label))
73 | features.append(
74 | InputFeatures(
75 | input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label
76 | )
77 | )
78 | return features
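79 |
80 | # Minimal usage sketch (hypothetical paths; assumes a BERT-style tokenizer whose `vocab`
81 | # attribute maps tokens to ids, which is what UMLMProcessor expects):
82 | #
83 | #   from transformers import BertTokenizer
84 | #   tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
85 | #   processor = UMLMProcessor(tokenizer.vocab)
86 | #   examples = processor.get_train_examples("./dat")  # expects <data_dir>/umlm_train.json
87 | #   features = umlm_convert_example_to_features(examples, tokenizer, max_length=128)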
--------------------------------------------------------------------------------
/python/src/transformers/data/processors/uop.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from functools import partial
5 | from multiprocessing import Pool, cpu_count
6 | from collections import OrderedDict
7 | from collections import Counter
8 | import operator
9 |
10 | import numpy as np
11 | from tqdm import tqdm
12 |
13 | from ...file_utils import is_tf_available, is_torch_available
14 | from ...tokenization_utils import PreTrainedTokenizer
15 | from .utils import DataProcessor
16 | import copy
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class UOPProcessor(DataProcessor):
22 |
23 | def get_example_from_tensor_dict(self, tensor_dict):
24 | raise NotImplementedError()
25 |
26 | def get_labels(self):
27 | return ["Yes", "No"]
28 |
29 | def get_train_examples(self, data_dir):
30 | with open(os.path.join(data_dir, "uop_train.json"), "r", encoding="utf-8") as reader:
31 | input_data = json.load(reader)
32 | return self._create_examples(input_data, "train")
33 |
34 | def get_dev_examples(self, data_dir):
35 | with open(os.path.join(data_dir, "uop_dev.json"), "r", encoding="utf-8") as reader:
36 | input_data = json.load(reader)
37 | return self._create_examples(input_data, "dev")
38 |
39 | def _create_examples(self, input_data, set_type):
40 | examples = []
41 | for i in tqdm(range(len(input_data))):
42 | guid = "%s-%s" % (set_type, str(i))
43 | utterances = input_data[i]["utterances"]
44 | label = input_data[i]["is_correct_order"]
45 | examples.append(UOPExample(guid=guid, contents=utterances, label=label))
46 | return examples
47 |
48 |
49 | def uop_convert_example_to_features(examples, tokenizer, max_line_length=128, max_line_number=107):
50 | processor = UOPProcessor()
51 | label_list = processor.get_labels()
52 | label_map = {label: i for i, label in enumerate(label_list)}
53 | features = []
54 | for (ex_index, example) in enumerate(examples):
55 | len_examples = len(examples)
56 | if ex_index % 100 == 0:
57 | logger.info("Writing example %d/%d" % (ex_index, len_examples))
58 | contents = example.contents
59 | lines_input_ids = []
60 | attention_masks = []
61 | for content in contents:
62 | tokens = tokenizer.tokenize(content)
63 | tokens = tokens[:max_line_length - 2]
64 |             inputs = tokenizer.encode_plus(
65 |                 " ".join(tokens), None, add_special_tokens=True, max_length=max_line_length)
66 | input_ids = inputs["input_ids"]
67 | attention_mask = [1] * len(input_ids)
68 | padding_length = max_line_length - len(input_ids)
69 | input_ids = input_ids + ([tokenizer.pad_token_id] * padding_length)
70 | attention_mask = attention_mask + ([0] * padding_length)
71 | attention_masks.append(attention_mask)
72 | lines_input_ids.append(input_ids)
73 | label = label_map[example.label]
74 | if ex_index < 5:
75 | logger.info("*** Example ***")
76 | logger.info("guid: %s" % (example.guid))
77 | logger.info("lines_input_ids: %s" % " ".join([str(x) for x in lines_input_ids]))
78 | logger.info("attention_masks: %s" % " ".join([str(x) for x in attention_masks]))
79 | logger.info("number of lines: %s" % str(len(lines_input_ids)))
80 | logger.info("label: %s (id = %d)" % (example.label, label))
81 | if len(lines_input_ids) > max_line_number:
82 | lines_input_ids = lines_input_ids[:max_line_number]
83 | attention_masks = attention_masks[:max_line_number]
84 | features.append(UOPFeatures(lines_input_ids=lines_input_ids, attention_masks=attention_masks, label=label))
85 | return features
86 |
87 |
88 | class UOPExample(object):
89 | def __init__(self, guid, contents, label=None):
90 | self.guid = guid
91 | self.contents = contents
92 | self.label = label
93 |
94 | def __repr__(self):
95 | return str(self.to_json_string())
96 |
97 | def to_dict(self):
98 | """Serializes this instance to a Python dictionary."""
99 | output = copy.deepcopy(self.__dict__)
100 | return output
101 |
102 | def to_json_string(self):
103 | """Serializes this instance to a JSON string."""
104 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
105 |
106 |
107 | class UOPFeatures(object):
108 | def __init__(self, lines_input_ids, attention_masks=None, label=None):
109 | self.lines_input_ids = lines_input_ids
110 | self.attention_masks = attention_masks
111 | self.label = label
112 |
113 | def __repr__(self):
114 | return str(self.to_json_string())
115 |
116 | def to_dict(self):
117 | """Serializes this instance to a Python dictionary."""
118 | output = copy.deepcopy(self.__dict__)
119 | return output
120 |
121 | def to_json_string(self):
122 | """Serializes this instance to a JSON string."""
123 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
124 |
125 |
126 |
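127 | # Minimal usage sketch (hypothetical paths; per UOPProcessor above, uop_train.json is expected to
128 | # hold a list of {"utterances": [...], "is_correct_order": "Yes"/"No"} records):
129 | #
130 | #   from transformers import BertTokenizer
131 | #   tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
132 | #   processor = UOPProcessor()
133 | #   examples = processor.get_train_examples("./dat")  # expects <data_dir>/uop_train.json
134 | #   features = uop_convert_example_to_features(examples, tokenizer, max_line_length=128, max_line_number=107)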
--------------------------------------------------------------------------------
/python/src/transformers/data/processors/xnli.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ XNLI utils (dataset loading and evaluation) """
17 |
18 |
19 | import logging
20 | import os
21 |
22 | from .utils import DataProcessor, InputExample
23 |
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | class XnliProcessor(DataProcessor):
29 | """Processor for the XNLI dataset.
30 | Adapted from https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207"""
31 |
32 | def __init__(self, language, train_language=None):
33 | self.language = language
34 | self.train_language = train_language
35 |
36 | def get_train_examples(self, data_dir):
37 | """See base class."""
38 | lg = self.language if self.train_language is None else self.train_language
39 | lines = self._read_tsv(os.path.join(data_dir, "XNLI-MT-1.0/multinli/multinli.train.{}.tsv".format(lg)))
40 | examples = []
41 | for (i, line) in enumerate(lines):
42 | if i == 0:
43 | continue
44 | guid = "%s-%s" % ("train", i)
45 | text_a = line[0]
46 | text_b = line[1]
47 | label = "contradiction" if line[2] == "contradictory" else line[2]
48 | assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
49 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
50 | return examples
51 |
52 | def get_test_examples(self, data_dir):
53 | """See base class."""
54 | lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
55 | examples = []
56 | for (i, line) in enumerate(lines):
57 | if i == 0:
58 | continue
59 | language = line[0]
60 | if language != self.language:
61 | continue
62 | guid = "%s-%s" % ("test", i)
63 | text_a = line[6]
64 | text_b = line[7]
65 | label = line[1]
66 | assert isinstance(text_a, str) and isinstance(text_b, str) and isinstance(label, str)
67 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
68 | return examples
69 |
70 | def get_labels(self):
71 | """See base class."""
72 | return ["contradiction", "entailment", "neutral"]
73 |
74 |
75 | xnli_processors = {
76 | "xnli": XnliProcessor,
77 | }
78 |
79 | xnli_output_modes = {
80 | "xnli": "classification",
81 | }
82 |
83 | xnli_tasks_num_labels = {
84 | "xnli": 3,
85 | }
86 |
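87 | # Minimal usage sketch (assumes the XNLI 1.0 and XNLI-MT 1.0 releases have been unpacked under a
88 | # hypothetical ./xnli_data directory, matching the paths hard-coded above):
89 | #
90 | #   processor = XnliProcessor(language="de", train_language="en")
91 | #   train_examples = processor.get_train_examples("./xnli_data")  # XNLI-MT-1.0/multinli/multinli.train.en.tsv
92 | #   test_examples = processor.get_test_examples("./xnli_data")    # XNLI-1.0/xnli.test.tsv
93 | #   print(processor.get_labels())  # ['contradiction', 'entailment', 'neutral']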
--------------------------------------------------------------------------------
/python/src/transformers/modeling_camembert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Inria, Facebook AI Research and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch CamemBERT model. """
17 |
18 | import logging
19 |
20 | from .configuration_camembert import CamembertConfig
21 | from .file_utils import add_start_docstrings
22 | from .modeling_roberta import (
23 | RobertaForMaskedLM,
24 | RobertaForMultipleChoice,
25 | RobertaForQuestionAnswering,
26 | RobertaForSequenceClassification,
27 | RobertaForTokenClassification,
28 | RobertaModel,
29 | )
30 |
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 | CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
35 | "camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-pytorch_model.bin",
36 | "umberto-commoncrawl-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/Musixmatch/umberto-commoncrawl-cased-v1/pytorch_model.bin",
37 | "umberto-wikipedia-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/Musixmatch/umberto-wikipedia-uncased-v1/pytorch_model.bin",
38 | }
39 |
40 | CAMEMBERT_START_DOCSTRING = r"""
41 |
42 |     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
43 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
44 | usage and behavior.
45 |
46 | Parameters:
47 | config (:class:`~transformers.CamembertConfig`): Model configuration class with all the parameters of the
48 | model. Initializing with a config file does not load the weights associated with the model, only the
49 | configuration.
50 | Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
51 | """
52 |
53 |
54 | @add_start_docstrings(
55 | "The bare CamemBERT Model transformer outputting raw hidden-states without any specific head on top.",
56 | CAMEMBERT_START_DOCSTRING,
57 | )
58 | class CamembertModel(RobertaModel):
59 | """
60 | This class overrides :class:`~transformers.RobertaModel`. Please check the
61 | superclass for the appropriate documentation alongside usage examples.
62 | """
63 |
64 | config_class = CamembertConfig
65 | pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
66 |
67 |
68 | @add_start_docstrings(
69 | """CamemBERT Model with a `language modeling` head on top. """, CAMEMBERT_START_DOCSTRING,
70 | )
71 | class CamembertForMaskedLM(RobertaForMaskedLM):
72 | """
73 | This class overrides :class:`~transformers.RobertaForMaskedLM`. Please check the
74 | superclass for the appropriate documentation alongside usage examples.
75 | """
76 |
77 | config_class = CamembertConfig
78 | pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
79 |
80 |
81 | @add_start_docstrings(
82 | """CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer
83 | on top of the pooled output) e.g. for GLUE tasks. """,
84 | CAMEMBERT_START_DOCSTRING,
85 | )
86 | class CamembertForSequenceClassification(RobertaForSequenceClassification):
87 | """
88 | This class overrides :class:`~transformers.RobertaForSequenceClassification`. Please check the
89 | superclass for the appropriate documentation alongside usage examples.
90 | """
91 |
92 | config_class = CamembertConfig
93 | pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
94 |
95 |
96 | @add_start_docstrings(
97 | """CamemBERT Model with a multiple choice classification head on top (a linear layer on top of
98 | the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
99 | CAMEMBERT_START_DOCSTRING,
100 | )
101 | class CamembertForMultipleChoice(RobertaForMultipleChoice):
102 | """
103 | This class overrides :class:`~transformers.RobertaForMultipleChoice`. Please check the
104 | superclass for the appropriate documentation alongside usage examples.
105 | """
106 |
107 | config_class = CamembertConfig
108 | pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
109 |
110 |
111 | @add_start_docstrings(
112 | """CamemBERT Model with a token classification head on top (a linear layer on top of
113 | the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
114 | CAMEMBERT_START_DOCSTRING,
115 | )
116 | class CamembertForTokenClassification(RobertaForTokenClassification):
117 | """
118 | This class overrides :class:`~transformers.RobertaForTokenClassification`. Please check the
119 | superclass for the appropriate documentation alongside usage examples.
120 | """
121 |
122 | config_class = CamembertConfig
123 | pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
124 |
125 |
126 | @add_start_docstrings(
127 | """CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD
128 |     (linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
129 | CAMEMBERT_START_DOCSTRING,
130 | )
131 | class CamembertForQuestionAnswering(RobertaForQuestionAnswering):
132 | """
133 | This class overrides :class:`~transformers.RobertaForQuestionAnswering`. Please check the
134 | superclass for the appropriate documentation alongside usage examples.
135 | """
136 |
137 | config_class = CamembertConfig
138 | pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
139 |
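140 | # Minimal usage sketch (downloads pretrained weights on first use; "camembert-base" comes from
141 | # CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP above):
142 | #
143 | #   import torch
144 | #   from transformers import CamembertTokenizer
145 | #   tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
146 | #   model = CamembertModel.from_pretrained("camembert-base")
147 | #   input_ids = torch.tensor([tokenizer.encode("J'aime le camembert !")])
148 | #   last_hidden_state = model(input_ids)[0]  # (batch, seq_len, hidden_size)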
--------------------------------------------------------------------------------
/python/src/transformers/modeling_tf_camembert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ TF 2.0 CamemBERT model. """
17 |
18 |
19 | import logging
20 |
21 | from .configuration_camembert import CamembertConfig
22 | from .file_utils import add_start_docstrings
23 | from .modeling_tf_roberta import (
24 | TFRobertaForMaskedLM,
25 | TFRobertaForSequenceClassification,
26 | TFRobertaForTokenClassification,
27 | TFRobertaModel,
28 | )
29 |
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 | TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {}
34 |
35 |
36 | CAMEMBERT_START_DOCSTRING = r"""
37 |
38 | .. note::
39 |
40 |         TF 2.0 models accept two formats as inputs:
41 |
42 |         - having all inputs as keyword arguments (like PyTorch models), or
43 |         - having all inputs as a list, tuple or dict in the first positional argument.
44 |
45 |         This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
46 |         all the tensors in the first argument of the model call function: :obj:`model(inputs)`.
47 |
48 |         If you choose this second option, there are three possibilities you can use to gather all the input Tensors
49 |         in the first positional argument:
50 |
51 |         - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
52 | - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
53 | :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
54 | - a dictionary with one or several input Tensors associated to the input names given in the docstring:
55 | :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
56 |
57 | Parameters:
58 | config (:class:`~transformers.CamembertConfig`): Model configuration class with all the parameters of the
59 | model. Initializing with a config file does not load the weights associated with the model, only the configuration.
60 | Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
61 | """
62 |
63 |
64 | @add_start_docstrings(
65 | "The bare CamemBERT Model transformer outputting raw hidden-states without any specific head on top.",
66 | CAMEMBERT_START_DOCSTRING,
67 | )
68 | class TFCamembertModel(TFRobertaModel):
69 | """
70 | This class overrides :class:`~transformers.TFRobertaModel`. Please check the
71 | superclass for the appropriate documentation alongside usage examples.
72 | """
73 |
74 | config_class = CamembertConfig
75 | pretrained_model_archive_map = TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
76 |
77 |
78 | @add_start_docstrings(
79 | """CamemBERT Model with a `language modeling` head on top. """, CAMEMBERT_START_DOCSTRING,
80 | )
81 | class TFCamembertForMaskedLM(TFRobertaForMaskedLM):
82 | """
83 | This class overrides :class:`~transformers.TFRobertaForMaskedLM`. Please check the
84 | superclass for the appropriate documentation alongside usage examples.
85 | """
86 |
87 | config_class = CamembertConfig
88 | pretrained_model_archive_map = TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
89 |
90 |
91 | @add_start_docstrings(
92 | """CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer
93 | on top of the pooled output) e.g. for GLUE tasks. """,
94 | CAMEMBERT_START_DOCSTRING,
95 | )
96 | class TFCamembertForSequenceClassification(TFRobertaForSequenceClassification):
97 | """
98 | This class overrides :class:`~transformers.TFRobertaForSequenceClassification`. Please check the
99 | superclass for the appropriate documentation alongside usage examples.
100 | """
101 |
102 | config_class = CamembertConfig
103 | pretrained_model_archive_map = TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
104 |
105 |
106 | @add_start_docstrings(
107 | """CamemBERT Model with a token classification head on top (a linear layer on top of
108 | the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
109 | CAMEMBERT_START_DOCSTRING,
110 | )
111 | class TFCamembertForTokenClassification(TFRobertaForTokenClassification):
112 | """
113 | This class overrides :class:`~transformers.TFRobertaForTokenClassification`. Please check the
114 | superclass for the appropriate documentation alongside usage examples.
115 | """
116 |
117 | config_class = CamembertConfig
118 | pretrained_model_archive_map = TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
119 |
--------------------------------------------------------------------------------
/python/src/transformers/modeling_tf_transfo_xl_utilities.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ A TF 2.0 Adaptive Softmax for the Transformer-XL model.
17 | """
18 |
19 |
20 | import tensorflow as tf
21 |
22 | from .modeling_tf_utils import shape_list
23 |
24 |
25 | class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
26 | def __init__(self, vocab_size, d_embed, d_proj, cutoffs, div_val=1, keep_order=False, **kwargs):
27 | super().__init__(**kwargs)
28 |
29 | self.vocab_size = vocab_size
30 | self.d_embed = d_embed
31 | self.d_proj = d_proj
32 |
33 | self.cutoffs = cutoffs + [vocab_size]
34 | self.cutoff_ends = [0] + self.cutoffs
35 | self.div_val = div_val
36 |
37 | self.shortlist_size = self.cutoffs[0]
38 | self.n_clusters = len(self.cutoffs) - 1
39 | self.head_size = self.shortlist_size + self.n_clusters
40 | self.keep_order = keep_order
41 |
42 | self.out_layers = []
43 | self.out_projs = []
44 |
45 | def build(self, input_shape):
46 | if self.n_clusters > 0:
47 | self.cluster_weight = self.add_weight(
48 | shape=(self.n_clusters, self.d_embed), initializer="zeros", trainable=True, name="cluster_weight"
49 | )
50 | self.cluster_bias = self.add_weight(
51 | shape=(self.n_clusters,), initializer="zeros", trainable=True, name="cluster_bias"
52 | )
53 |
54 | if self.div_val == 1:
55 | for i in range(len(self.cutoffs)):
56 | if self.d_proj != self.d_embed:
57 | weight = self.add_weight(
58 | shape=(self.d_embed, self.d_proj),
59 | initializer="zeros",
60 | trainable=True,
61 | name="out_projs_._{}".format(i),
62 | )
63 | self.out_projs.append(weight)
64 | else:
65 | self.out_projs.append(None)
66 | weight = self.add_weight(
67 | shape=(self.vocab_size, self.d_embed,),
68 | initializer="zeros",
69 | trainable=True,
70 | name="out_layers_._{}_._weight".format(i),
71 | )
72 | bias = self.add_weight(
73 | shape=(self.vocab_size,),
74 | initializer="zeros",
75 | trainable=True,
76 | name="out_layers_._{}_._bias".format(i),
77 | )
78 | self.out_layers.append((weight, bias))
79 | else:
80 | for i in range(len(self.cutoffs)):
81 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
82 | d_emb_i = self.d_embed // (self.div_val ** i)
83 |
84 | weight = self.add_weight(
85 | shape=(d_emb_i, self.d_proj), initializer="zeros", trainable=True, name="out_projs_._{}".format(i)
86 | )
87 | self.out_projs.append(weight)
88 | weight = self.add_weight(
89 | shape=(r_idx - l_idx, d_emb_i,),
90 | initializer="zeros",
91 | trainable=True,
92 | name="out_layers_._{}_._weight".format(i),
93 | )
94 | bias = self.add_weight(
95 | shape=(r_idx - l_idx,),
96 | initializer="zeros",
97 | trainable=True,
98 | name="out_layers_._{}_._bias".format(i),
99 | )
100 | self.out_layers.append((weight, bias))
101 | super().build(input_shape)
102 |
103 | @staticmethod
104 | def _logit(x, W, b, proj=None):
105 | y = x
106 | if proj is not None:
107 | y = tf.einsum("ibd,ed->ibe", y, proj)
108 | return tf.einsum("ibd,nd->ibn", y, W) + b
109 |
110 | @staticmethod
111 | def _gather_logprob(logprob, target):
112 | lp_size = shape_list(logprob)
113 | r = tf.range(lp_size[0])
114 | idx = tf.stack([r, target], 1)
115 | return tf.gather_nd(logprob, idx)
116 |
117 | def call(self, inputs, return_mean=True, training=False):
118 | hidden, target = inputs
119 | head_logprob = 0
120 | if self.n_clusters == 0:
121 | output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
122 | if target is not None:
123 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output)
124 | out = tf.nn.log_softmax(output, axis=-1)
125 | else:
126 | hidden_sizes = shape_list(hidden)
127 | out = []
128 | loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
129 | for i in range(len(self.cutoffs)):
130 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
131 | if target is not None:
132 | mask = (target >= l_idx) & (target < r_idx)
133 | mask_idx = tf.where(mask)
134 | cur_target = tf.boolean_mask(target, mask) - l_idx
135 |
136 | if self.div_val == 1:
137 | cur_W = self.out_layers[0][0][l_idx:r_idx]
138 | cur_b = self.out_layers[0][1][l_idx:r_idx]
139 | else:
140 | cur_W = self.out_layers[i][0]
141 | cur_b = self.out_layers[i][1]
142 |
143 | if i == 0:
144 | cur_W = tf.concat([cur_W, self.cluster_weight], 0)
145 | cur_b = tf.concat([cur_b, self.cluster_bias], 0)
146 |
147 | head_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[0])
148 | head_logprob = tf.nn.log_softmax(head_logit)
149 | out.append(head_logprob[..., : self.cutoffs[0]])
150 | if target is not None:
151 | cur_head_logprob = tf.boolean_mask(head_logprob, mask)
152 | cur_logprob = self._gather_logprob(cur_head_logprob, cur_target)
153 | else:
154 | tail_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[i])
155 | tail_logprob = tf.nn.log_softmax(tail_logit)
156 | cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster
157 | logprob_i = head_logprob[..., cluster_prob_idx, None] + tail_logprob
158 | out.append(logprob_i)
159 | if target is not None:
160 | cur_head_logprob = tf.boolean_mask(head_logprob, mask)
161 | cur_tail_logprob = tf.boolean_mask(tail_logprob, mask)
162 | cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
163 | cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
164 | if target is not None:
165 | loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64))
166 | out = tf.concat(out, axis=-1)
167 |
168 | if target is not None:
169 | if return_mean:
170 | loss = tf.reduce_mean(loss)
171 | # Add the training-time loss value to the layer using `self.add_loss()`.
172 | self.add_loss(loss)
173 |
174 | # Log the loss as a metric (we could log arbitrary metrics,
175 |             # including different metrics for training and inference).
176 | self.add_metric(loss, name=self.name, aggregation="mean" if return_mean else "")
177 |
178 | return out
179 |
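180 | # Standalone instantiation sketch (illustrative hyper-parameters mirroring the WT103 Transformer-XL
181 | # setup; in practice this layer is constructed inside the TF Transformer-XL LM head model rather
182 | # than used on its own):
183 | #
184 | #   import tensorflow as tf
185 | #   crit = TFAdaptiveSoftmaxMask(vocab_size=267735, d_embed=1024, d_proj=1024,
186 | #                                cutoffs=[20000, 40000, 200000], div_val=4)
187 | #   hidden = tf.random.normal((32, 4, 1024))  # [tgt_len, batch, d_proj]
188 | #   target = tf.random.uniform((32, 4), maxval=267735, dtype=tf.int32)
189 | #   log_probs = crit([hidden, target], return_mean=True)  # also registers the LM loss via add_loss
190 | #   print(log_probs.shape)  # (32, 4, 267735)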
--------------------------------------------------------------------------------
/python/src/transformers/modeling_tf_xlm_roberta.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ TF 2.0 XLM-RoBERTa model. """
17 |
18 |
19 | import logging
20 |
21 | from .configuration_xlm_roberta import XLMRobertaConfig
22 | from .file_utils import add_start_docstrings
23 | from .modeling_tf_roberta import (
24 | TFRobertaForMaskedLM,
25 | TFRobertaForSequenceClassification,
26 | TFRobertaForTokenClassification,
27 | TFRobertaModel,
28 | )
29 |
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 | TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {}
34 |
35 |
36 | XLM_ROBERTA_START_DOCSTRING = r"""
37 |
38 | .. note::
39 |
40 |         TF 2.0 models accept two formats as inputs:
41 |
42 |         - having all inputs as keyword arguments (like PyTorch models), or
43 |         - having all inputs as a list, tuple or dict in the first positional argument.
44 |
45 |         This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
46 |         all the tensors in the first argument of the model call function: :obj:`model(inputs)`.
47 |
48 |         If you choose this second option, there are three possibilities you can use to gather all the input Tensors
49 |         in the first positional argument:
50 |
51 |         - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
52 | - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
53 | :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
54 | - a dictionary with one or several input Tensors associated to the input names given in the docstring:
55 | :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
56 |
57 | Parameters:
58 | config (:class:`~transformers.XLMRobertaConfig`): Model configuration class with all the parameters of the
59 | model. Initializing with a config file does not load the weights associated with the model, only the configuration.
60 | Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
61 | """
62 |
63 |
64 | @add_start_docstrings(
65 | "The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
66 | XLM_ROBERTA_START_DOCSTRING,
67 | )
68 | class TFXLMRobertaModel(TFRobertaModel):
69 | """
70 | This class overrides :class:`~transformers.TFRobertaModel`. Please check the
71 | superclass for the appropriate documentation alongside usage examples.
72 | """
73 |
74 | config_class = XLMRobertaConfig
75 | pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
76 |
77 |
78 | @add_start_docstrings(
79 | """XLM-RoBERTa Model with a `language modeling` head on top. """, XLM_ROBERTA_START_DOCSTRING,
80 | )
81 | class TFXLMRobertaForMaskedLM(TFRobertaForMaskedLM):
82 | """
83 | This class overrides :class:`~transformers.TFRobertaForMaskedLM`. Please check the
84 | superclass for the appropriate documentation alongside usage examples.
85 | """
86 |
87 | config_class = XLMRobertaConfig
88 | pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
89 |
90 |
91 | @add_start_docstrings(
92 | """XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer
93 | on top of the pooled output) e.g. for GLUE tasks. """,
94 | XLM_ROBERTA_START_DOCSTRING,
95 | )
96 | class TFXLMRobertaForSequenceClassification(TFRobertaForSequenceClassification):
97 | """
98 | This class overrides :class:`~transformers.TFRobertaForSequenceClassification`. Please check the
99 | superclass for the appropriate documentation alongside usage examples.
100 | """
101 |
102 | config_class = XLMRobertaConfig
103 | pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
104 |
105 |
106 | @add_start_docstrings(
107 | """XLM-RoBERTa Model with a token classification head on top (a linear layer on top of
108 | the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
109 | XLM_ROBERTA_START_DOCSTRING,
110 | )
111 | class TFXLMRobertaForTokenClassification(TFRobertaForTokenClassification):
112 | """
113 | This class overrides :class:`~transformers.TFRobertaForTokenClassification`. Please check the
114 | superclass for the appropriate documentation alongside usage examples.
115 | """
116 |
117 | config_class = XLMRobertaConfig
118 | pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
119 |
--------------------------------------------------------------------------------
/python/src/transformers/modeling_xlm_roberta.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """PyTorch XLM-RoBERTa model. """
17 |
18 |
19 | import logging
20 |
21 | from .configuration_xlm_roberta import XLMRobertaConfig
22 | from .file_utils import add_start_docstrings
23 | from .modeling_roberta import (
24 | RobertaForMaskedLM,
25 | RobertaForMultipleChoice,
26 | RobertaForSequenceClassification,
27 | RobertaForTokenClassification,
28 | RobertaModel,
29 | )
30 |
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 | XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
35 | "xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-pytorch_model.bin",
36 | "xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-pytorch_model.bin",
37 | "xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-pytorch_model.bin",
38 | "xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-pytorch_model.bin",
39 | "xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-pytorch_model.bin",
40 | "xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-pytorch_model.bin",
41 | }
42 |
43 |
44 | XLM_ROBERTA_START_DOCSTRING = r"""
45 |
46 |     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
47 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
48 | usage and behavior.
49 |
50 | Parameters:
51 | config (:class:`~transformers.XLMRobertaConfig`): Model configuration class with all the parameters of the
52 | model. Initializing with a config file does not load the weights associated with the model, only the configuration.
53 | Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
54 | """
55 |
56 |
57 | @add_start_docstrings(
58 | "The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
59 | XLM_ROBERTA_START_DOCSTRING,
60 | )
61 | class XLMRobertaModel(RobertaModel):
62 | """
63 | This class overrides :class:`~transformers.RobertaModel`. Please check the
64 | superclass for the appropriate documentation alongside usage examples.
65 | """
66 |
67 | config_class = XLMRobertaConfig
68 | pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
69 |
70 |
71 | @add_start_docstrings(
72 | """XLM-RoBERTa Model with a `language modeling` head on top. """, XLM_ROBERTA_START_DOCSTRING,
73 | )
74 | class XLMRobertaForMaskedLM(RobertaForMaskedLM):
75 | """
76 | This class overrides :class:`~transformers.RobertaForMaskedLM`. Please check the
77 | superclass for the appropriate documentation alongside usage examples.
78 | """
79 |
80 | config_class = XLMRobertaConfig
81 | pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
82 |
83 |
84 | @add_start_docstrings(
85 | """XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer
86 | on top of the pooled output) e.g. for GLUE tasks. """,
87 | XLM_ROBERTA_START_DOCSTRING,
88 | )
89 | class XLMRobertaForSequenceClassification(RobertaForSequenceClassification):
90 | """
91 | This class overrides :class:`~transformers.RobertaForSequenceClassification`. Please check the
92 | superclass for the appropriate documentation alongside usage examples.
93 | """
94 |
95 | config_class = XLMRobertaConfig
96 | pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
97 |
98 |
99 | @add_start_docstrings(
100 | """XLM-RoBERTa Model with a multiple choice classification head on top (a linear layer on top of
101 | the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
102 | XLM_ROBERTA_START_DOCSTRING,
103 | )
104 | class XLMRobertaForMultipleChoice(RobertaForMultipleChoice):
105 | """
106 | This class overrides :class:`~transformers.RobertaForMultipleChoice`. Please check the
107 | superclass for the appropriate documentation alongside usage examples.
108 | """
109 |
110 | config_class = XLMRobertaConfig
111 | pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
112 |
113 |
114 | @add_start_docstrings(
115 | """XLM-RoBERTa Model with a token classification head on top (a linear layer on top of
116 | the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
117 | XLM_ROBERTA_START_DOCSTRING,
118 | )
119 | class XLMRobertaForTokenClassification(RobertaForTokenClassification):
120 | """
121 | This class overrides :class:`~transformers.RobertaForTokenClassification`. Please check the
122 | superclass for the appropriate documentation alongside usage examples.
123 | """
124 |
125 | config_class = XLMRobertaConfig
126 | pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
127 |
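128 | # Minimal usage sketch ("xlm-roberta-base" comes from XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
129 | # above; num_labels=3 is illustrative, e.g. for an XNLI-style task):
130 | #
131 | #   import torch
132 | #   from transformers import XLMRobertaTokenizer
133 | #   tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
134 | #   model = XLMRobertaForSequenceClassification.from_pretrained("xlm-roberta-base", num_labels=3)
135 | #   input_ids = torch.tensor([tokenizer.encode("Hello world!")])
136 | #   logits = model(input_ids)[0]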
--------------------------------------------------------------------------------
/python/src/transformers/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import logging
18 | import math
19 |
20 | import torch
21 | from torch.optim import Optimizer
22 | from torch.optim.lr_scheduler import LambdaLR
23 |
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | def get_constant_schedule(optimizer, last_epoch=-1):
29 | """ Create a schedule with a constant learning rate.
30 | """
31 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
32 |
33 |
34 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
35 | """ Create a schedule with a constant learning rate preceded by a warmup
36 | period during which the learning rate increases linearly between 0 and 1.
37 | """
38 |
39 | def lr_lambda(current_step):
40 | if current_step < num_warmup_steps:
41 | return float(current_step) / float(max(1.0, num_warmup_steps))
42 | return 1.0
43 |
44 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
45 |
46 |
47 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
48 | """ Create a schedule with a learning rate that decreases linearly after
49 | linearly increasing during a warmup period.
50 | """
51 |
52 | def lr_lambda(current_step):
53 | if current_step < num_warmup_steps:
54 | return float(current_step) / float(max(1, num_warmup_steps))
55 | return max(
56 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
57 | )
58 |
59 | return LambdaLR(optimizer, lr_lambda, last_epoch)
60 |
61 |
62 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
63 | """ Create a schedule with a learning rate that decreases following the
64 | values of the cosine function between 0 and `pi * cycles` after a warmup
65 | period during which it increases linearly between 0 and 1.
66 | """
67 |
68 | def lr_lambda(current_step):
69 | if current_step < num_warmup_steps:
70 | return float(current_step) / float(max(1, num_warmup_steps))
71 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
72 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
73 |
74 | return LambdaLR(optimizer, lr_lambda, last_epoch)
75 |
76 |
77 | def get_cosine_with_hard_restarts_schedule_with_warmup(
78 | optimizer, num_warmup_steps, num_training_steps, num_cycles=1.0, last_epoch=-1
79 | ):
80 | """ Create a schedule with a learning rate that decreases following the
81 | values of the cosine function with several hard restarts, after a warmup
82 | period during which it increases linearly between 0 and 1.
83 | """
84 |
85 | def lr_lambda(current_step):
86 | if current_step < num_warmup_steps:
87 | return float(current_step) / float(max(1, num_warmup_steps))
88 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
89 | if progress >= 1.0:
90 | return 0.0
91 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
92 |
93 | return LambdaLR(optimizer, lr_lambda, last_epoch)
94 |
95 |
96 | class AdamW(Optimizer):
97 | """ Implements Adam algorithm with weight decay fix.
98 |
99 | Parameters:
100 |         lr (float): learning rate. Default: 1e-3.
101 |         betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999)
102 |         eps (float): Adam's epsilon. Default: 1e-6
103 |         weight_decay (float): Weight decay. Default: 0.0
104 |         correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g., as in the BERT TF repository). Default: True.
105 | """
106 |
107 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
108 | if lr < 0.0:
109 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
110 | if not 0.0 <= betas[0] < 1.0:
111 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
112 | if not 0.0 <= betas[1] < 1.0:
113 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
114 | if not 0.0 <= eps:
115 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
116 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
117 | super().__init__(params, defaults)
118 |
119 | def step(self, closure=None):
120 | """Performs a single optimization step.
121 |
122 | Arguments:
123 | closure (callable, optional): A closure that reevaluates the model
124 | and returns the loss.
125 | """
126 | loss = None
127 | if closure is not None:
128 | loss = closure()
129 |
130 | for group in self.param_groups:
131 | for p in group["params"]:
132 | if p.grad is None:
133 | continue
134 | grad = p.grad.data
135 | if grad.is_sparse:
136 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
137 |
138 | state = self.state[p]
139 |
140 | # State initialization
141 | if len(state) == 0:
142 | state["step"] = 0
143 | # Exponential moving average of gradient values
144 | state["exp_avg"] = torch.zeros_like(p.data)
145 | # Exponential moving average of squared gradient values
146 | state["exp_avg_sq"] = torch.zeros_like(p.data)
147 |
148 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
149 | beta1, beta2 = group["betas"]
150 |
151 | state["step"] += 1
152 |
153 | # Decay the first and second moment running average coefficient
154 | # In-place operations to update the averages at the same time
155 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
156 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
157 | denom = exp_avg_sq.sqrt().add_(group["eps"])
158 |
159 | step_size = group["lr"]
160 | if group["correct_bias"]: # No bias correction for Bert
161 | bias_correction1 = 1.0 - beta1 ** state["step"]
162 | bias_correction2 = 1.0 - beta2 ** state["step"]
163 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
164 |
165 | p.data.addcdiv_(-step_size, exp_avg, denom)
166 |
167 | # Just adding the square of the weights to the loss function is *not*
168 | # the correct way of using L2 regularization/weight decay with Adam,
169 | # since that will interact with the m and v parameters in strange ways.
170 | #
171 | # Instead we want to decay the weights in a manner that doesn't interact
172 | # with the m/v parameters. This is equivalent to adding the square
173 | # of the weights to the loss with plain (non-momentum) SGD.
174 | # Add weight decay at the end (fixed version)
175 | if group["weight_decay"] > 0.0:
176 | p.data.add_(-group["lr"] * group["weight_decay"], p.data)
177 |
178 | return loss
179 |
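
Usage sketch (not part of the source file): a minimal example of wiring AdamW to the cosine-with-hard-restarts schedule defined above. The model, hyperparameters, and step counts are placeholders, and the imports go directly to this module rather than relying on the package's __init__ exports.

# Hypothetical wiring of AdamW with the warmup/cosine schedule above;
# the model and all hyperparameters are placeholders.
import torch
from src.transformers.optimization import (
    AdamW,
    get_cosine_with_hard_restarts_schedule_with_warmup,
)

model = torch.nn.Linear(10, 2)  # stand-in for a real transformer model
no_decay = ["bias", "LayerNorm.weight"]
grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = AdamW(grouped_parameters, lr=3e-5, correct_bias=False)  # BERT-style: no bias correction
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000, num_cycles=2.0
)

for step in range(1000):
    # forward pass and loss.backward() would go here
    optimizer.step()       # apply the AdamW update (decoupled weight decay)
    scheduler.step()       # advance the warmup/cosine learning-rate schedule
    optimizer.zero_grad()
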
--------------------------------------------------------------------------------
/python/src/transformers/tokenization_bart.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from .tokenization_roberta import RobertaTokenizer
17 |
18 |
19 | # vocab and merges same as roberta
20 | vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
21 | merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
22 | _all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn"]
23 |
24 |
25 | class BartTokenizer(RobertaTokenizer):
26 | # merges and vocab same as Roberta
27 | max_model_input_sizes = {m: 1024 for m in _all_bart_models}
28 | pretrained_vocab_files_map = {
29 | "vocab_file": {m: vocab_url for m in _all_bart_models},
30 | "merges_file": {m: merges_url for m in _all_bart_models},
31 | }
32 |
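
Short sketch (not from the source) of what the class above implies: BartTokenizer only remaps RoBERTa's vocab and merges to the BART shortcut names, so the two tokenizers produce identical ids. It assumes the download URLs in the maps above are reachable.

# Sketch: BartTokenizer reuses RoBERTa's vocab/merges, so encodings match.
from src.transformers.tokenization_bart import BartTokenizer
from src.transformers.tokenization_roberta import RobertaTokenizer

bart_tok = BartTokenizer.from_pretrained("bart-large")
roberta_tok = RobertaTokenizer.from_pretrained("roberta-large")

text = "Where did Joey go after the party?"
print(bart_tok.encode(text) == roberta_tok.encode(text))  # True - shared BPE vocabulary
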
--------------------------------------------------------------------------------
/python/src/transformers/tokenization_distilbert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for DistilBERT."""
16 |
17 |
18 | import logging
19 |
20 | from .tokenization_bert import BertTokenizer, BertTokenizerFast
21 |
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
26 |
27 | PRETRAINED_VOCAB_FILES_MAP = {
28 | "vocab_file": {
29 | "distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
30 | "distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
31 | "distilbert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
32 | "distilbert-base-cased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
33 | "distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-vocab.txt",
34 | "distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
35 | }
36 | }
37 |
38 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
39 | "distilbert-base-uncased": 512,
40 | "distilbert-base-uncased-distilled-squad": 512,
41 | "distilbert-base-cased": 512,
42 | "distilbert-base-cased-distilled-squad": 512,
43 | "distilbert-base-german-cased": 512,
44 | "distilbert-base-multilingual-cased": 512,
45 | }
46 |
47 |
48 | PRETRAINED_INIT_CONFIGURATION = {
49 | "distilbert-base-uncased": {"do_lower_case": True},
50 | "distilbert-base-uncased-distilled-squad": {"do_lower_case": True},
51 | "distilbert-base-cased": {"do_lower_case": False},
52 | "distilbert-base-cased-distilled-squad": {"do_lower_case": False},
53 | "distilbert-base-german-cased": {"do_lower_case": False},
54 | "distilbert-base-multilingual-cased": {"do_lower_case": False},
55 | }
56 |
57 |
58 | class DistilBertTokenizer(BertTokenizer):
59 | r"""
60 | Constructs a DistilBertTokenizer.
61 | :class:`~transformers.DistilBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
62 | tokenization: punctuation splitting + wordpiece.
63 |
64 | Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
65 | parameters.
66 | """
67 |
68 | vocab_files_names = VOCAB_FILES_NAMES
69 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
70 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
71 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
72 | model_input_names = ["attention_mask"]
73 |
74 |
75 | class DistilBertTokenizerFast(BertTokenizerFast):
76 | vocab_files_names = VOCAB_FILES_NAMES
77 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
78 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
79 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
80 | model_input_names = ["attention_mask"]
81 |
--------------------------------------------------------------------------------
/python/src/transformers/tokenization_flaubert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for Flaubert, based on XLM."""
16 |
17 |
18 | import logging
19 | import unicodedata
20 |
21 | import six
22 |
23 | from .tokenization_xlm import XLMTokenizer
24 |
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 | VOCAB_FILES_NAMES = {
29 | "vocab_file": "vocab.json",
30 | "merges_file": "merges.txt",
31 | }
32 |
33 | PRETRAINED_VOCAB_FILES_MAP = {
34 | "vocab_file": {
35 | "flaubert-small-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/vocab.json",
36 | "flaubert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/vocab.json",
37 | "flaubert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/vocab.json",
38 | "flaubert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/vocab.json",
39 | },
40 | "merges_file": {
41 | "flaubert-small-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/merges.txt",
42 | "flaubert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/merges.txt",
43 | "flaubert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/merges.txt",
44 | "flaubert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/merges.txt",
45 | },
46 | }
47 |
48 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
49 | "flaubert-small-cased": 512,
50 | "flaubert-base-uncased": 512,
51 | "flaubert-base-cased": 512,
52 | "flaubert-large-cased": 512,
53 | }
54 |
55 | PRETRAINED_INIT_CONFIGURATION = {
56 | "flaubert-small-cased": {"do_lowercase": False},
57 | "flaubert-base-uncased": {"do_lowercase": True},
58 | "flaubert-base-cased": {"do_lowercase": False},
59 | "flaubert-large-cased": {"do_lowercase": False},
60 | }
61 |
62 |
63 | def convert_to_unicode(text):
64 | """
65 | Converts `text` to Unicode (if it's not already), assuming UTF-8 input.
66 | """
67 | # six_ensure_text is copied from https://github.com/benjaminp/six
68 | def six_ensure_text(s, encoding="utf-8", errors="strict"):
69 | if isinstance(s, six.binary_type):
70 | return s.decode(encoding, errors)
71 | elif isinstance(s, six.text_type):
72 | return s
73 | else:
74 | raise TypeError("not expecting type '%s'" % type(s))
75 |
76 | return six_ensure_text(text, encoding="utf-8", errors="ignore")
77 |
78 |
79 | class FlaubertTokenizer(XLMTokenizer):
80 | """
81 | BPE tokenizer for Flaubert
82 |
83 | - Moses preprocessing & tokenization
84 |         - Normalizes all input text
85 |         - The argument ``special_tokens`` and the function ``set_special_tokens`` can be used to add additional symbols \
86 |         (e.g. "__classify__") to the vocabulary
87 |         - ``do_lowercase`` controls lowercasing (automatically set for pretrained vocabularies)
88 |
89 | This tokenizer inherits from :class:`~transformers.XLMTokenizer`. Please check the superclass for usage examples
90 | and documentation regarding arguments.
91 | """
92 |
93 | vocab_files_names = VOCAB_FILES_NAMES
94 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
95 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
96 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
97 |
98 | def __init__(self, do_lowercase=False, **kwargs):
99 | super().__init__(**kwargs)
100 | self.do_lowercase = do_lowercase
101 | self.do_lowercase_and_remove_accent = False
102 |
103 | def preprocess_text(self, text):
104 | text = text.replace("``", '"').replace("''", '"')
105 | text = convert_to_unicode(text)
106 | text = unicodedata.normalize("NFC", text)
107 |
108 | if self.do_lowercase:
109 | text = text.lower()
110 |
111 | return text
112 |
113 | def _tokenize(self, text, bypass_tokenizer=False):
114 | """
115 | Tokenize a string given language code using Moses.
116 |
117 | Details of tokenization:
118 | - [sacremoses](https://github.com/alvations/sacremoses): port of Moses
119 | - Install with `pip install sacremoses`
120 |
121 | Args:
122 |             - bypass_tokenizer (bool, default False): allow users to preprocess and tokenize the sentences externally; if True, only BPE is applied.
123 |
124 | Returns:
125 | List of tokens.
126 | """
127 | lang = "fr"
128 | if lang and self.lang2id and lang not in self.lang2id:
129 | logger.error(
130 | "Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model."
131 | )
132 |
133 | if bypass_tokenizer:
134 | text = text.split()
135 | else:
136 | text = self.preprocess_text(text)
137 | text = self.moses_pipeline(text, lang=lang)
138 | text = self.moses_tokenize(text, lang=lang)
139 |
140 | split_tokens = []
141 | for token in text:
142 | if token:
143 | split_tokens.extend([t for t in self.bpe(token).split(" ")])
144 |
145 | return split_tokens
146 |
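
A small illustration (not part of the source file) of the preprocessing and tokenization path described in the docstrings above. It assumes sacremoses is installed and that the pretrained files for the shortcut name can be downloaded; the exact BPE pieces depend on the downloaded merges.

# Sketch: uncased Flaubert lowercases in preprocess_text, then applies Moses + BPE.
from src.transformers.tokenization_flaubert import FlaubertTokenizer

tok = FlaubertTokenizer.from_pretrained("flaubert-base-uncased")  # do_lowercase=True

print(tok.preprocess_text("C'est ``génial''"))  # -> c'est "génial"
print(tok.tokenize("C'est génial !"))           # Moses-tokenized text split into BPE pieces
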
--------------------------------------------------------------------------------
/python/src/transformers/tokenization_t5.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 T5 Authors and HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Tokenization class for model T5."""
16 |
17 |
18 | import logging
19 | import os
20 | import re
21 | from shutil import copyfile
22 |
23 | from .tokenization_utils import PreTrainedTokenizer
24 |
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 | SPIECE_UNDERLINE = "▁"
29 |
30 | ####################################################
31 | # Mapping from the keyword arguments names of Tokenizer `__init__`
32 | # to file names for serializing Tokenizer instances
33 | ####################################################
34 | VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
35 |
36 | ####################################################
37 | # Mapping from the keyword arguments names of Tokenizer `__init__`
38 | # to pretrained vocabulary URL for all the model shortcut names.
39 | ####################################################
40 | PRETRAINED_VOCAB_FILES_MAP = {
41 | "vocab_file": {
42 | "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
43 | "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
44 | "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
45 | "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
46 | "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
47 | }
48 | }
49 |
50 | ####################################################
51 | # Mapping from model shortcut names to max length of inputs
52 | ####################################################
53 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
54 | "t5-small": 512,
55 | "t5-base": 512,
56 | "t5-large": 512,
57 | "t5-3b": 512,
58 | "t5-11b": 512,
59 | }
60 |
61 |
62 | class T5Tokenizer(PreTrainedTokenizer):
63 | """
64 | SentencePiece based tokenizer. Peculiarities:
65 |
66 |         - requires `SentencePiece <https://github.com/google/sentencepiece>`_
67 |         - `extra_ids` adds a number of extra ids to the end of the vocabulary, for use as sentinels.
68 |           These tokens are accessible as `<extra_id_{%d}>` where `{%d}` is a number between 0 and extra_ids-1.
69 |           Extra tokens are indexed from the end of the vocabulary up to the beginning (`<extra_id_0>` is the last token in the vocabulary)
70 | (like in T5 preprocessing
71 | see: https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)
72 | """
73 |
74 | vocab_files_names = VOCAB_FILES_NAMES
75 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
76 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
77 |
78 | def __init__(
79 | self,
80 | vocab_file,
81 |         eos_token="</s>",
82 |         unk_token="<unk>",
83 |         pad_token="<pad>",
84 | extra_ids=100,
85 | additional_special_tokens=None,
86 | **kwargs
87 | ):
88 | # Add extra_ids to the special token list
89 | if extra_ids > 0:
90 | if additional_special_tokens is None:
91 | additional_special_tokens = []
92 |             additional_special_tokens.extend(["<extra_id_{}>".format(i) for i in range(extra_ids)])
93 |
94 | super().__init__(
95 | eos_token=eos_token,
96 | unk_token=unk_token,
97 | pad_token=pad_token,
98 | additional_special_tokens=additional_special_tokens,
99 | **kwargs,
100 | )
101 | self.max_len_single_sentence = (
102 | self.max_len
103 | ) # no default special tokens - you can update this value if you add special tokens
104 | self.max_len_sentences_pair = (
105 | self.max_len
106 | ) # no default special tokens - you can update this value if you add special tokens
107 |
108 | try:
109 | import sentencepiece as spm
110 | except ImportError:
111 | logger.warning(
112 | "You need to install SentencePiece to use T5Tokenizer:"
113 | "https://github.com/google/sentencepiece"
114 | "pip install sentencepiece"
115 | )
116 | raise
117 |
118 | self.vocab_file = vocab_file
119 | self._extra_ids = extra_ids
120 |
121 | self.sp_model = spm.SentencePieceProcessor()
122 | self.sp_model.Load(vocab_file)
123 |
124 | @property
125 | def vocab_size(self):
126 | return self.sp_model.get_piece_size() + self._extra_ids
127 |
128 | def get_vocab(self):
129 | vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
130 | vocab.update(self.added_tokens_encoder)
131 | return vocab
132 |
133 | def __getstate__(self):
134 | state = self.__dict__.copy()
135 | state["sp_model"] = None
136 | return state
137 |
138 | def __setstate__(self, d):
139 | self.__dict__ = d
140 | try:
141 | import sentencepiece as spm
142 | except ImportError:
143 | logger.warning(
144 | "You need to install SentencePiece to use T5Tokenizer: https://github.com/google/sentencepiece"
145 | "pip install sentencepiece"
146 | )
147 | raise
148 | self.sp_model = spm.SentencePieceProcessor()
149 | self.sp_model.Load(self.vocab_file)
150 |
151 | def _tokenize(self, text, sample=False):
152 | """ Take as input a string and return a list of strings (tokens) for words/sub-words
153 | """
154 | if not sample:
155 | pieces = self.sp_model.EncodeAsPieces(text)
156 | else:
157 | pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
158 | return pieces
159 |
160 | def _convert_token_to_id(self, token):
161 | """ Converts a token (str) in an id using the vocab. """
162 |         if token.startswith("<extra_id_"):
163 |             match = re.match(r"<extra_id_(\d+)>", token)
164 |             num = int(match.group(1))
165 | return self.vocab_size - num - 1
166 | return self.sp_model.piece_to_id(token)
167 |
168 | def _convert_id_to_token(self, index):
169 | """Converts an index (integer) in a token (str) using the vocab."""
170 | if index < self.sp_model.get_piece_size():
171 | token = self.sp_model.IdToPiece(index)
172 | else:
173 |             token = "<extra_id_{}>".format(self.vocab_size - 1 - index)
174 | return token
175 |
176 | def convert_tokens_to_string(self, tokens):
177 | """ Converts a sequence of tokens (string) in a single string. """
178 | out_string = self.sp_model.decode_pieces(tokens)
179 | return out_string
180 |
181 | def save_vocabulary(self, save_directory):
182 | """ Save the sentencepiece vocabulary (copy original file) and special tokens file
183 | to a directory.
184 | """
185 | if not os.path.isdir(save_directory):
186 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
187 | return
188 | out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
189 |
190 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
191 | copyfile(self.vocab_file, out_vocab_file)
192 |
193 | return (out_vocab_file,)
194 |
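
A worked illustration (not from the source) of the sentinel-id arithmetic in _convert_token_to_id / _convert_id_to_token above. It assumes sentencepiece is installed and the t5-small spiece.model can be downloaded; with the default extra_ids=100, the sentinels occupy the last 100 ids of the vocabulary.

# Sketch: <extra_id_*> sentinels are indexed from the end of the vocabulary.
from src.transformers.tokenization_t5 import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-small")  # extra_ids=100 by default

print(tok.vocab_size)                                # SentencePiece size + 100
print(tok._convert_token_to_id("<extra_id_0>"))      # vocab_size - 1 (last id)
print(tok._convert_token_to_id("<extra_id_99>"))     # vocab_size - 100
print(tok._convert_id_to_token(tok.vocab_size - 1))  # "<extra_id_0>"
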
--------------------------------------------------------------------------------
/python/src/transformers/utils_encoder_decoder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Classes to support Encoder-Decoder architectures """
16 |
17 |
18 | def prepare_encoder_decoder_model_kwargs(**kwargs):
19 | """ Prepare the encoder and decoder's keyword arguments.
20 |
21 | Keyword arguments come in 3 flavors:
22 | - encoder-specific (prefixed by `encoder_`)
23 | - decoder-specific (prefixed by `decoder_`)
24 |         - those that apply to the model as a whole.
25 |
26 | We let the specific kwargs override the common ones in case of
27 | conflict.
28 | """
29 |
30 | kwargs_common = {
31 | argument: value
32 | for argument, value in kwargs.items()
33 | if not argument.startswith("encoder_") and not argument.startswith("decoder_")
34 | }
35 | if "input_ids" in kwargs_common:
36 | kwargs["encoder_input_ids"] = kwargs_common.pop("input_ids")
37 |
38 | decoder_kwargs = kwargs_common.copy()
39 | encoder_kwargs = kwargs_common.copy()
40 | encoder_kwargs.update(
41 | {argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")}
42 | )
43 | decoder_kwargs.update(
44 | {argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")}
45 | )
46 | decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
47 | return encoder_kwargs, decoder_kwargs
48 |
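
A small sketch (not part of the source) showing how the function above splits keyword arguments: a plain input_ids is routed to the encoder, prefixed kwargs override the shared ones, and the decoder additionally receives the encoder's attention mask.

# Sketch of the kwargs split performed by prepare_encoder_decoder_model_kwargs.
from src.transformers.utils_encoder_decoder import prepare_encoder_decoder_model_kwargs

encoder_kwargs, decoder_kwargs = prepare_encoder_decoder_model_kwargs(
    input_ids=[[1, 2, 3]],
    attention_mask=[[1, 1, 1]],
    decoder_input_ids=[[4, 5]],
    decoder_attention_mask=[[1, 1]],
)

print(encoder_kwargs)
# {'attention_mask': [[1, 1, 1]], 'input_ids': [[1, 2, 3]]}
print(decoder_kwargs)
# {'attention_mask': [[1, 1]], 'input_ids': [[4, 5]], 'encoder_attention_mask': [[1, 1, 1]]}
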
--------------------------------------------------------------------------------
/python/src/utils/analysis.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 Changmao Li
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import json
16 | import string
17 | import re
18 | import sys
19 | from collections import Counter
20 |
21 |
22 | def normalize_answer(s):
23 | """Lower text and remove punctuation, articles and extra whitespace."""
24 | def remove_articles(text):
25 | return re.sub(r'\b(a|an|the)\b', ' ', text)
26 |
27 | def white_space_fix(text):
28 | return ' '.join(text.split())
29 |
30 | def remove_punc(text):
31 | exclude = set(string.punctuation)
32 | return ''.join(ch for ch in text if ch not in exclude)
33 |
34 | def lower(text):
35 | return text.lower()
36 |
37 | def remove_underline(text):
38 | return text.replace('_', ' ')
39 |
40 | return remove_underline(white_space_fix(remove_articles(remove_punc(lower(s)))))
41 |
42 |
43 | def f1_score(prediction, ground_truth):
44 | prediction_tokens = normalize_answer(prediction).split()
45 | ground_truth_tokens = normalize_answer(ground_truth).split()
46 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
47 | num_same = sum(common.values())
48 | if num_same == 0:
49 | return 0
50 | precision = 1.0 * num_same / len(prediction_tokens)
51 | recall = 1.0 * num_same / len(ground_truth_tokens)
52 | f1 = (2 * precision * recall) / (precision + recall)
53 | return f1
54 |
55 |
56 | def exact_match_score(prediction, ground_truth):
57 | return normalize_answer(prediction) == normalize_answer(ground_truth)
58 |
59 |
60 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
61 | scores_for_ground_truths = []
62 | for ground_truth in ground_truths:
63 | score = metric_fn(prediction, ground_truth)
64 | scores_for_ground_truths.append(score)
65 | return max(scores_for_ground_truths)
66 |
67 |
68 | # evaluate F1/EM within one question-word category (e.g. all "What" questions)
69 | def evaluate(dataset):
70 | f1 = exact_match = total = 0
71 | # for article in dataset:
72 | # for paragraph in article['paragraphs']:
73 | # for qa in paragraph['qas']:
74 | # total += 1
75 | # if qa['id'] not in predictions:
76 | # message = 'Unanswered question ' + qa['id'] + \
77 | # ' will receive score 0.'
78 | # print(message, file=sys.stderr)
79 | # continue
80 | # ground_truths = list(map(lambda x: x['text'], qa['answers']))
81 | # prediction = predictions[qa['id']]
82 | # exact_match += metric_max_over_ground_truths(
83 | # exact_match_score, prediction, ground_truths)
84 | # f1 += metric_max_over_ground_truths(
85 | # f1_score, prediction, ground_truths)
86 |
87 | for id, qa in dataset.items():
88 | total += 1
89 | ground_truths = list(map(lambda x: x['text'], qa['answers']))
90 | prediction = qa['prediction']
91 | exact_match += metric_max_over_ground_truths(
92 | exact_match_score, prediction, ground_truths)
93 | f1 += metric_max_over_ground_truths(
94 | f1_score, prediction, ground_truths)
95 | exact_match = 100.0 * exact_match / total
96 | f1 = 100.0 * f1 / total
97 | return {'exact_match': exact_match, 'f1': f1}
98 |
99 |
100 | def print_categorize(categorized_file):
101 | with open(categorized_file) as json_in:
102 | data = json.load(json_in)
103 | # print(data)
104 | for key, value in data.items():
105 | print(key, evaluate(value))
106 |
107 |
108 |
109 |
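
A worked example (hypothetical prediction and answers, not from the source) of the metric helpers above; it assumes the src.utils modules are importable from the project root.

# Sketch: answer normalization, token-overlap F1, and exact match over ground truths.
from src.utils.analysis import (
    normalize_answer, f1_score, exact_match_score, metric_max_over_ground_truths,
)

prediction = "the Central Perk"
ground_truths = ["Central Perk!", "at Central Perk"]

print(normalize_answer(prediction))  # -> "central perk" (lowercased, article and punctuation removed)
print(metric_max_over_ground_truths(f1_score, prediction, ground_truths))           # 1.0
print(metric_max_over_ground_truths(exact_match_score, prediction, ground_truths))  # True
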
--------------------------------------------------------------------------------
/python/src/utils/categorizing.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 Changmao Li
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import json
17 | import copy
18 | import re
19 |
20 | categorized = {}
21 | utter_id_regex = re.compile('u[0-9][0-9][0-9]')
22 | q_words = ['What', 'When', 'Who', 'Where', 'Why', 'How']
23 |
24 |
25 | def find_utter(context, text):
26 | index = context.find(text)
27 | if index == -1:
28 | return ""
29 | index -= 4 # u001 format of utter id
30 | while index >= 0:
31 | if utter_id_regex.match(context[index:index + 4]):
32 | return context[index:index + 4]
33 | index -= 1
34 | return ""
35 |
36 |
37 | def categorizing(data_file, prediction_file, categorized_file):
38 | with open(prediction_file) as pred, open(data_file) as dev:
39 | pred_json = json.load(pred)
40 | dev_json = json.load(dev)
41 | total = same_utter = 0
42 | # what_count = when_count = who_count = where_count = why_count = how_count = 0
43 | # what_count_same = when_count_same = who_count_same = where_count_same = why_count_same = how_count_same = 0
44 | count_dict = {}
45 | for q in q_words:
46 | categorized[q] = {}
47 | count_dict[q] = 0
48 | count_dict[q + "_same"] = 0
49 |
50 | for para in dev_json['data']:
51 | # print(len(para['paragraphs']))
52 | qas = para['paragraphs'][0]['qas']
53 | context = para['paragraphs'][0]['context']
54 | # print(context)
55 | for qa in qas:
56 | # #of total questions
57 | total += 1
58 |
59 | # actual prediction
60 | if qa['id'] not in pred_json:
61 | continue
62 | prediction = pred_json[qa['id']]
63 | qa['prediction'] = prediction
64 |
65 | # predicted utter id
66 | pred_utter = find_utter(context, prediction)
67 | qa['predicted_utterance'] = pred_utter
68 | # print("prediction: ", pred_utter)
69 |
70 | # question word
71 | q_word = qa['id'].replace("_Paraphrased", "").split('_')[-1]
72 |
73 | count_dict[q_word] += 1
74 | same = False
75 | for a in qa['answers']:
76 | one_utter = find_utter(context, a['text'])
77 | # print("correct: ", one_utter)
78 |
79 | if one_utter == pred_utter:
80 | same = True
81 | same_utter += 1
82 | count_dict[q_word + "_same"] += 1
83 | break
84 | # if not same:
85 | # print("context: ", context)
86 | # print("pred utter: ", pred_utter)
87 | # print("pred: ", prediction)
88 | # print("answer: ", qa['answers'])
89 | # print("q: ", qa['question'])
90 | # print(qa['id'])
91 | # print('=' * 20)
92 | # print(qa['id'].replace("_Paraphrased", "").split('_')[-1])
93 |
94 | categorized[q_word][qa['id']] = qa
95 | # categorized[qa['id'].replace("_Paraphrased", "").split('_')[-1]].append(qa)
96 | # print(qa['id'])
97 | # print(qa)
98 | print(same_utter)
99 | print(total)
100 | print(count_dict)
101 | total_q = 0.0
102 | total_q_same = 0.0
103 | for key, value in count_dict.items():
104 | if key + "_same" in count_dict:
105 | print(key + ":", count_dict[key + "_same"] * 100.0 / count_dict[key])
106 | total_q += count_dict[key]
107 | total_q_same += count_dict[key + "_same"]
108 | print("overall: ", total_q_same * 100.0 / total_q)
109 | print(total_q)
110 | print(categorized)
111 | print(dev_json)
112 | with open(categorized_file, 'w') as out:
113 | json.dump(categorized, out, indent=2)
114 |
115 |
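
Usage sketch (file names are placeholders, not from the source): categorizing() expects a whole-context, SQuAD-style data file whose question ids end in the question word, plus a {question_id: answer_text} prediction file; the categorized output can then be scored per question word with analysis.print_categorize.

# Sketch: bucket predictions by question word, then score each bucket.
from src.utils.categorizing import categorizing
from src.utils.analysis import print_categorize

# Hypothetical paths; categorizing() also prints per-question-word utterance accuracy.
categorizing("dev_whole_context.json", "predictions.json", "categorized_dev.json")
print_categorize("categorized_dev.json")  # per-category exact_match / f1
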
--------------------------------------------------------------------------------
/python/src/utils/evaluate_split_context.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 Changmao Li
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import print_function
17 | from collections import Counter
18 | import string
19 | import re
20 | import argparse
21 | import json
22 | import os
23 | from tqdm import tqdm
24 |
25 |
26 | def eval_best(data_file, result_file):
27 | with open(os.path.join(data_file), "r", encoding="utf-8") as reader:
28 | input_data = json.load(reader)
29 | with open(os.path.join(result_file), "r", encoding="utf-8") as reader:
30 | result_data = json.load(reader)
31 | data = input_data["data"]
32 | f1 = exact_match = total = 0
33 | for i in tqdm(range(len(data))):
34 | content_list = []
35 | utterances = data[i]["paragraphs"][0]["utterances:"]
36 | qas = data[i]["paragraphs"][0]["qas"]
37 | for ui, utterance in enumerate(utterances):
38 | speaker = utterance["speakers"][0].split(" ")
39 | if len(speaker) >= 2:
40 | speaker = speaker[0] + "_" + speaker[1]
41 | else:
42 | speaker = speaker[0]
43 | u_text = "u" + str(ui) + " " + speaker + " " + utterance["utterance"]
44 | content_list.append(u_text)
45 | for qa in qas:
46 | total += 1
47 | q_id = qa["id"]
48 | result = result_data[q_id]
49 | pred_uid = result["uid"]
50 | pred_left = result["inner_left"]
51 | pred_right = result["inner_right"]
52 | pred_utterance = None
53 | pred_answer_text = None
54 | answers = qa["answers"]
55 | answer_texts = []
56 | answer_uids = []
57 | for answer in answers:
58 | answer_texts.append(answer["answer_text"])
59 | answer_uids.append(answer["utterance_id"])
60 | if 0 <= pred_uid < len(content_list) and pred_uid in answer_uids:
61 | pred_utterance = content_list[pred_uid]
62 | if pred_utterance:
63 | pred_u_tokens = pred_utterance.split(" ")
64 | if pred_left <= pred_right and 0 <= pred_left < len(pred_u_tokens) and 0 <= pred_right < len(pred_u_tokens):
65 | pred_answer_text = " ".join(pred_u_tokens[pred_left:pred_right + 1])
66 | else:
67 | pred_answer_text = pred_utterance
68 |
69 | if pred_answer_text:
70 | f1 += metric_max_over_ground_truths(f1_score, pred_answer_text, answer_texts)
71 | exact_match += metric_max_over_ground_truths(
72 | exact_match_score, pred_answer_text, answer_texts)
73 | exact_match = 100.0 * exact_match / total
74 | f1 = 100.0 * f1 / total
75 | return {'exact_match': exact_match, 'f1': f1}
76 |
77 |
78 | def eval_n_best(data_file, result_file):
79 | with open(os.path.join(data_file), "r", encoding="utf-8") as reader:
80 | input_data = json.load(reader)
81 | with open(os.path.join(result_file), "r", encoding="utf-8") as reader:
82 | result_data = json.load(reader)
83 | data = input_data["data"]
84 | f1 = exact_match = total = 0
85 | for i in tqdm(range(len(data))):
86 | content_list = []
87 | utterances = data[i]["paragraphs"][0]["utterances:"]
88 | qas = data[i]["paragraphs"][0]["qas"]
89 | for ui, utterance in enumerate(utterances):
90 | speaker = utterance["speakers"][0].split(" ")
91 | if len(speaker) >= 2:
92 | speaker = speaker[0] + "_" + speaker[1]
93 | else:
94 | speaker = speaker[0]
95 | u_text = "u" + str(ui) + " " + speaker + " " + utterance["utterance"]
96 | content_list.append(u_text)
97 | for qa in qas:
98 | total += 1
99 | q_id = qa["id"]
100 | results = result_data[q_id]
101 | f1s = []
102 | exact_matches = []
103 | for result in results:
104 | pred_uid = result["uid"]
105 | pred_left = result["inner_left"]
106 | pred_right = result["inner_right"]
107 | pred_utterance = None
108 | pred_answer_text = None
109 | answers = qa["answers"]
110 | answer_texts = []
111 | answer_uids = []
112 | for answer in answers:
113 | answer_texts.append(answer["answer_text"])
114 | answer_uids.append(answer["utterance_id"])
115 | if 0 <= pred_uid < len(content_list) and pred_uid in answer_uids:
116 | pred_utterance = content_list[pred_uid]
117 | if pred_utterance:
118 | pred_u_tokens = pred_utterance.split(" ")
119 | if pred_left <= pred_right and 0 <= pred_left < len(pred_u_tokens) and 0 <= pred_right < len(
120 | pred_u_tokens):
121 | pred_answer_text = " ".join(pred_u_tokens[pred_left:pred_right + 1])
122 | else:
123 | pred_answer_text = pred_utterance
124 | if pred_answer_text:
125 | f1s.append(metric_max_over_ground_truths(f1_score, pred_answer_text, answer_texts))
126 | exact_matches.append(metric_max_over_ground_truths(
127 | exact_match_score, pred_answer_text, answer_texts))
128 | if len(f1s) > 0:
129 | f1 += max(f1s)
130 | exact_match += max(exact_matches)
131 | exact_match = 100.0 * exact_match / total
132 | f1 = 100.0 * f1 / total
133 | return {'exact_match': exact_match, 'f1': f1}
134 |
135 |
136 | def normalize_answer(s):
137 | """Lower text and remove punctuation, articles and extra whitespace."""
138 | def remove_articles(text):
139 | return re.sub(r'\b(a|an|the)\b', ' ', text)
140 |
141 | def white_space_fix(text):
142 | return ' '.join(text.split())
143 |
144 | def remove_punc(text):
145 | exclude = set(string.punctuation)
146 | return ''.join(ch for ch in text if ch not in exclude or ch == '_')
147 |
148 | def lower(text):
149 | return text.lower()
150 |
151 | def remove_underline(text):
152 | return text.replace('_', ' ')
153 |
154 | return remove_underline(white_space_fix(remove_articles(remove_punc(lower(s)))))
155 |
156 |
157 | def f1_score(prediction, ground_truth):
158 | prediction_tokens = normalize_answer(prediction).split()
159 | ground_truth_tokens = normalize_answer(ground_truth).split()
160 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
161 | num_same = sum(common.values())
162 | if num_same == 0:
163 | return 0
164 | precision = 1.0 * num_same / len(prediction_tokens)
165 | recall = 1.0 * num_same / len(ground_truth_tokens)
166 | f1 = (2 * precision * recall) / (precision + recall)
167 | return f1
168 |
169 |
170 | def exact_match_score(prediction, ground_truth):
171 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
172 |
173 |
174 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
175 | scores_for_ground_truths = []
176 | for ground_truth in ground_truths:
177 | score = metric_fn(prediction, ground_truth)
178 | scores_for_ground_truths.append(score)
179 | return max(scores_for_ground_truths)
180 |
181 |
182 | if __name__ == '__main__':
183 | parser = argparse.ArgumentParser(description='Evaluation for FriendsQA split context ')
184 | parser.add_argument('dataset_file', help='Dataset file')
185 | parser.add_argument('prediction_file', help='Prediction File')
186 |     parser.add_argument("--do_best_eval", action="store_true", help="evaluate only the single best prediction per question (default: n-best list)")
187 | args = parser.parse_args()
188 | if args.do_best_eval:
189 | print(json.dumps(eval_best(args.dataset_file, args.prediction_file)))
190 | else:
191 | print(json.dumps(eval_n_best(args.dataset_file, args.prediction_file)))
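
Usage sketch (file names are placeholders): the two entry points above expect different prediction formats, inferred from the code: eval_best reads one {"uid", "inner_left", "inner_right"} record per question id, while eval_n_best reads a list of such records.

# Sketch: calling the two evaluation modes directly (hypothetical file names).
from src.utils.evaluate_split_context import eval_best, eval_n_best

# predictions_best.json:  {q_id: {"uid": ..., "inner_left": ..., "inner_right": ...}}
print(eval_best("dev_split_context.json", "predictions_best.json"))

# predictions_nbest.json: {q_id: [{"uid": ..., "inner_left": ..., "inner_right": ...}, ...]}
print(eval_n_best("dev_split_context.json", "predictions_nbest.json"))
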
--------------------------------------------------------------------------------
/python/src/utils/evaluate_whole_context.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Official evaluation script for v1.1 of the SQuAD dataset.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import print_function
17 | from collections import Counter
18 | import string
19 | import re
20 | import argparse
21 | import json
22 | import sys
23 |
24 | def normalize_answer(s):
25 | """Lower text and remove punctuation, articles and extra whitespace."""
26 | def remove_articles(text):
27 | return re.sub(r'\b(a|an|the)\b', ' ', text)
28 |
29 | def white_space_fix(text):
30 | return ' '.join(text.split())
31 |
32 | def remove_punc(text):
33 | exclude = set(string.punctuation)
34 | return ''.join(ch for ch in text if ch not in exclude or ch == '_')
35 |
36 | def lower(text):
37 | return text.lower()
38 |
39 | def remove_underline(text):
40 | return text.replace('_', ' ')
41 |
42 | return remove_underline(white_space_fix(remove_articles(remove_punc(lower(s)))))
43 |
44 |
45 | def f1_score(prediction, ground_truth):
46 | prediction_tokens = normalize_answer(prediction).split()
47 | ground_truth_tokens = normalize_answer(ground_truth).split()
48 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
49 | num_same = sum(common.values())
50 | if num_same == 0:
51 | return 0
52 | precision = 1.0 * num_same / len(prediction_tokens)
53 | recall = 1.0 * num_same / len(ground_truth_tokens)
54 | f1 = (2 * precision * recall) / (precision + recall)
55 | return f1
56 |
57 |
58 | def exact_match_score(prediction, ground_truth):
59 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
60 |
61 |
62 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
63 | scores_for_ground_truths = []
64 | for ground_truth in ground_truths:
65 | score = metric_fn(prediction, ground_truth)
66 | scores_for_ground_truths.append(score)
67 | return max(scores_for_ground_truths)
68 |
69 |
70 | def evaluate(dataset, predictions):
71 | f1 = exact_match = total = 0
72 | for article in dataset:
73 | for paragraph in article['paragraphs']:
74 | for qa in paragraph['qas']:
75 | total += 1
76 | if qa['id'] not in predictions:
77 | message = 'Unanswered question ' + qa['id'] + \
78 | ' will receive score 0.'
79 | print(message, file=sys.stderr)
80 | continue
81 | ground_truths = list(map(lambda x: x['text'], qa['answers']))
82 | prediction = predictions[qa['id']]
83 | exact_match += metric_max_over_ground_truths(
84 | exact_match_score, prediction, ground_truths)
85 | f1 += metric_max_over_ground_truths(
86 | f1_score, prediction, ground_truths)
87 |
88 | exact_match = 100.0 * exact_match / total
89 | f1 = 100.0 * f1 / total
90 |
91 | return {'exact_match': exact_match, 'f1': f1}
92 |
93 |
94 | if __name__ == '__main__':
95 | expected_version = '1.1'
96 | parser = argparse.ArgumentParser(
97 | description='Evaluation for FriendsQA whole context ' + expected_version)
98 | parser.add_argument('dataset_file', help='Dataset file')
99 | parser.add_argument('prediction_file', help='Prediction File')
100 | args = parser.parse_args()
101 | with open(args.dataset_file) as dataset_file:
102 | dataset_json = json.load(dataset_file)
103 | if (dataset_json['version'] != expected_version):
104 | print('Evaluation expects v-' + expected_version +
105 | ', but got dataset with v-' + dataset_json['version'],
106 | file=sys.stderr)
107 | dataset = dataset_json['data']
108 | with open(args.prediction_file) as prediction_file:
109 | predictions = json.load(prediction_file)
110 | print(json.dumps(evaluate(dataset, predictions)))
111 |
--------------------------------------------------------------------------------
/python/src/utils/test_model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 Changmao Li
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | import numpy as np
18 | from src.transformers import BertConfig, BertForUtteranceLanguageModeling, BertForUtteranceOrderPrediction, \
19 | BertForDialogueSpanQuestionAnswering, RobertaForUtteranceOrderPrediction, RobertaForUtteranceLanguageModeling, \
20 | RobertaForDialogueSpanQuestionAnswering, RobertaConfig
21 | from torchsummaryX import summary
22 |
23 |
24 | def test_BertForUtteranceLanguageModeling():
25 | num_samples = 5
26 | seq_len = 64
27 | config = BertConfig(max_position_embeddings=512)
28 | model = BertForUtteranceLanguageModeling(config)
29 | input_ids = np.ones(shape=(num_samples, seq_len), dtype=np.int32)
30 | attention_masks = np.ones(shape=(num_samples, seq_len), dtype=np.int32)
31 | label_ids = np.full(shape=(num_samples,), dtype=np.int32, fill_value=0)
32 | model.cpu()
33 | model.float()
34 | summary(model, torch.tensor(input_ids.astype(np.int64)),
35 | torch.tensor(attention_masks.astype(np.int64)),
36 | torch.tensor(label_ids.astype(np.int64))
37 | )
38 |
39 | def test_BertForUtteranceOrderPrediction():
40 | num_samples = 5
41 | num_utterances = 6
42 | seq_len = 64
43 | max_utterances = 10
44 | config = BertConfig(max_position_embeddings=512)
45 | utterance_config = BertConfig(max_position_embeddings=max_utterances+1, num_hidden_layers=2)
46 | model = BertForUtteranceOrderPrediction(config, utterance_config, max_utterances)
47 | utterances_input_ids = np.ones(shape=(num_samples, num_utterances, seq_len), dtype=np.int32)
48 | attention_masks = np.ones(shape=(num_samples, num_utterances, seq_len), dtype=np.int32)
49 | label_ids = np.full(shape=(num_samples, ), dtype=np.int32, fill_value=0)
50 | model.cpu()
51 | model.float()
52 | summary(model, torch.tensor(utterances_input_ids.astype(np.int64)),
53 | torch.tensor(attention_masks.astype(np.int64)),
54 | torch.tensor(label_ids.astype(np.int64))
55 | )
56 |
57 | def test_BertForDialogueSpanQuestionAnswering():
58 | num_samples = 5
59 | num_utterances = 6
60 | seq_len = 64
61 | max_utterances = 10
62 | config = BertConfig(max_position_embeddings=512)
63 | utterance_config = BertConfig(max_position_embeddings=max_utterances+1, num_hidden_layers=2)
64 | model = BertForDialogueSpanQuestionAnswering(config, utterance_config, max_utterances, seq_len)
65 | utterances_input_ids = np.ones(shape=(num_samples, num_utterances, seq_len), dtype=np.int32)
66 | attention_masks = np.ones(shape=(num_samples, num_utterances, seq_len), dtype=np.int32)
67 | question_input_ids = np.ones(shape=(num_samples, 512), dtype=np.int32)
68 | question_attention_masks = np.ones(shape=(num_samples, 512), dtype=np.int32)
69 | left_ids = np.full(shape=(num_samples, num_utterances), dtype=np.int32, fill_value=1)
70 | right_ids = np.full(shape=(num_samples, num_utterances), dtype=np.int32, fill_value=1)
71 | label_ids = np.full(shape=(num_samples,), dtype=np.int32, fill_value=1)
72 | model.cpu()
73 | model.float()
74 | summary(model,
75 | torch.tensor(question_input_ids.astype(np.int64)),
76 | torch.tensor(utterances_input_ids.astype(np.int64)),
77 | torch.tensor(question_attention_masks.astype(np.int64)),
78 | torch.tensor(attention_masks.astype(np.int64)),
79 | torch.tensor(label_ids.astype(np.int64)),
80 | torch.tensor(left_ids.astype(np.int64)),
81 | torch.tensor(right_ids.astype(np.int64)),
82 | )
83 |
84 | def test_RobertaForUtteranceLanguageModeling():
85 | num_samples = 5
86 | seq_len = 64
87 | config = RobertaConfig(max_position_embeddings=512)
88 | model = RobertaForUtteranceLanguageModeling(config)
89 | input_ids = np.ones(shape=(num_samples, seq_len), dtype=np.int32)
90 | attention_masks = np.ones(shape=(num_samples, seq_len), dtype=np.int32)
91 | label_ids = np.full(shape=(num_samples,), dtype=np.int32, fill_value=0)
92 | model.cpu()
93 | model.float()
94 | summary(model, torch.tensor(input_ids.astype(np.int64)),
95 | torch.tensor(attention_masks.astype(np.int64)),
96 | torch.tensor(label_ids.astype(np.int64))
97 | )
98 |
99 |
100 | def test_RobertaForUtteranceOrderPrediction():
101 | num_samples = 5
102 | num_utterances = 6
103 | seq_len = 64
104 | max_utterances = 10
105 | config = RobertaConfig(max_position_embeddings=512)
106 | utterance_config = BertConfig(max_position_embeddings=max_utterances+1, num_hidden_layers=2)
107 | model = RobertaForUtteranceOrderPrediction(config, utterance_config, max_utterances)
108 | utterances_input_ids = np.ones(shape=(num_samples, num_utterances, seq_len), dtype=np.int32)
109 | attention_masks = np.ones(shape=(num_samples, num_utterances, seq_len), dtype=np.int32)
110 | label_ids = np.full(shape=(num_samples, ), dtype=np.int32, fill_value=0)
111 | model.cpu()
112 | model.float()
113 | summary(model, torch.tensor(utterances_input_ids.astype(np.int64)),
114 | torch.tensor(attention_masks.astype(np.int64)),
115 | torch.tensor(label_ids.astype(np.int64))
116 | )
117 |
118 | def test_RobertaForDialogueSpanQuestionAnswering():
119 | num_samples = 5
120 | num_utterances = 6
121 | seq_len = 64
122 | max_utterances = 10
123 | config = RobertaConfig(max_position_embeddings=512)
124 | utterance_config = BertConfig(max_position_embeddings=max_utterances+1, num_hidden_layers=2)
125 | model = RobertaForDialogueSpanQuestionAnswering(config, utterance_config, max_utterances, seq_len)
126 | utterances_input_ids = np.ones(shape=(num_samples, num_utterances, seq_len), dtype=np.int32)
127 | attention_masks = np.ones(shape=(num_samples, num_utterances, seq_len), dtype=np.int32)
128 | question_input_ids = np.ones(shape=(num_samples, 512), dtype=np.int32)
129 | question_attention_masks = np.ones(shape=(num_samples, 512), dtype=np.int32)
130 | left_ids = np.full(shape=(num_samples, num_utterances), dtype=np.int32, fill_value=1)
131 | right_ids = np.full(shape=(num_samples, num_utterances), dtype=np.int32, fill_value=1)
132 | label_ids = np.full(shape=(num_samples,), dtype=np.int32, fill_value=1)
133 | model.cpu()
134 | model.float()
135 | summary(model,
136 | torch.tensor(question_input_ids.astype(np.int64)),
137 | torch.tensor(utterances_input_ids.astype(np.int64)),
138 | torch.tensor(question_attention_masks.astype(np.int64)),
139 | torch.tensor(attention_masks.astype(np.int64)),
140 | torch.tensor(label_ids.astype(np.int64)),
141 | torch.tensor(left_ids.astype(np.int64)),
142 | torch.tensor(right_ids.astype(np.int64)),
143 | )
144 |
145 |
146 | if __name__ == "__main__":
147 | test_BertForUtteranceLanguageModeling()
148 | test_BertForUtteranceOrderPrediction()
149 | test_BertForDialogueSpanQuestionAnswering()
150 | test_RobertaForUtteranceOrderPrediction()
151 | test_RobertaForUtteranceLanguageModeling()
152 | test_RobertaForDialogueSpanQuestionAnswering()
--------------------------------------------------------------------------------