├── .idea ├── CCF-BDCI-ABSA.iml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── Further_pretraining ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── create_ernie_pretraining_data.py ├── create_pretraining_data.py ├── extract_features.py ├── modeling.py ├── modeling_test.py ├── multilingual.md ├── optimization.py ├── optimization_test.py ├── predicting_movie_reviews_with_bert_on_tf_hub.ipynb ├── requirements.txt ├── run_classifier.py ├── run_classifier_with_tfhub.py ├── run_pretrain_11_8.out ├── run_pretraining.py ├── run_squad.py ├── sample_text.txt ├── tokenization.py └── tokenization_test.py ├── Fusion_model.py ├── README.md ├── model_1_bert_att_drop_42.py ├── model_2_bert_att_drop_further_pretrain.py ├── model_3_roberte_wwm_ext_att_drop_42.py ├── model_4_bert_att_drop_420.py ├── model_5_bert_att_drop_1001001.py ├── postprocess.py ├── predict_model_1_bert_att_drop_42.py ├── predict_model_2_bert_att_drop_further_pretrain.py ├── predict_model_3_roberte_wwm_ext_att_drop_42.py ├── predict_model_4_bert_att_drop_420.py ├── predict_model_5_bert_att_drop_1001001.py ├── preprocess.py ├── shell ├── get_pretrain_data.sh └── run_pretrain.sh ├── transformers ├── __init__.py ├── __main__.py ├── configuration_auto.py ├── configuration_bert.py ├── configuration_camembert.py ├── configuration_ctrl.py ├── configuration_distilbert.py ├── configuration_gpt2.py ├── configuration_openai.py ├── configuration_roberta.py ├── configuration_transfo_xl.py ├── configuration_utils.py ├── configuration_xlm.py ├── configuration_xlnet.py ├── convert_bert_original_tf_checkpoint_to_pytorch.py ├── convert_bert_pytorch_checkpoint_to_original_tf.py ├── convert_gpt2_original_tf_checkpoint_to_pytorch.py ├── convert_openai_original_tf_checkpoint_to_pytorch.py ├── convert_pytorch_checkpoint_to_tf2.py ├── convert_roberta_original_pytorch_checkpoint_to_pytorch.py ├── convert_transfo_xl_original_tf_checkpoint_to_pytorch.py ├── convert_xlm_original_pytorch_checkpoint_to_pytorch.py ├── convert_xlnet_original_tf_checkpoint_to_pytorch.py ├── data │ ├── __init__.py │ ├── metrics │ │ └── __init__.py │ └── processors │ │ ├── __init__.py │ │ ├── glue.py │ │ └── utils.py ├── file_utils.py ├── modeling_auto.py ├── modeling_beam_search.py ├── modeling_bert.py ├── modeling_camembert.py ├── modeling_ctrl.py ├── modeling_distilbert.py ├── modeling_encoder_decoder.py ├── modeling_gpt2.py ├── modeling_openai.py ├── modeling_roberta.py ├── modeling_tf_auto.py ├── modeling_tf_bert.py ├── modeling_tf_ctrl.py ├── modeling_tf_distilbert.py ├── modeling_tf_gpt2.py ├── modeling_tf_openai.py ├── modeling_tf_pytorch_utils.py ├── modeling_tf_roberta.py ├── modeling_tf_transfo_xl.py ├── modeling_tf_transfo_xl_utilities.py ├── modeling_tf_utils.py ├── modeling_tf_xlm.py ├── modeling_tf_xlnet.py ├── modeling_transfo_xl.py ├── modeling_transfo_xl_utilities.py ├── modeling_utils.py ├── modeling_xlm.py ├── modeling_xlnet.py ├── optimization.py ├── tests │ ├── __init__.py │ ├── configuration_common_test.py │ ├── conftest.py │ ├── fixtures │ │ ├── input.txt │ │ ├── sample_text.txt │ │ └── test_sentencepiece.model │ ├── modeling_auto_test.py │ ├── modeling_bert_test.py │ ├── modeling_common_test.py │ ├── modeling_ctrl_test.py │ ├── modeling_distilbert_test.py │ ├── modeling_encoder_decoder_test.py │ ├── modeling_gpt2_test.py │ ├── modeling_openai_test.py │ ├── modeling_roberta_test.py │ ├── modeling_tf_auto_test.py │ ├── modeling_tf_bert_test.py │ ├── modeling_tf_common_test.py │ ├── modeling_tf_ctrl_test.py │ ├── 
modeling_tf_distilbert_test.py │ ├── modeling_tf_gpt2_test.py │ ├── modeling_tf_openai_gpt_test.py │ ├── modeling_tf_roberta_test.py │ ├── modeling_tf_transfo_xl_test.py │ ├── modeling_tf_xlm_test.py │ ├── modeling_tf_xlnet_test.py │ ├── modeling_transfo_xl_test.py │ ├── modeling_xlm_test.py │ ├── modeling_xlnet_test.py │ ├── optimization_test.py │ ├── tokenization_auto_test.py │ ├── tokenization_bert_test.py │ ├── tokenization_ctrl_test.py │ ├── tokenization_distilbert_test.py │ ├── tokenization_gpt2_test.py │ ├── tokenization_openai_test.py │ ├── tokenization_roberta_test.py │ ├── tokenization_tests_commons.py │ ├── tokenization_transfo_xl_test.py │ ├── tokenization_utils_test.py │ ├── tokenization_xlm_test.py │ └── tokenization_xlnet_test.py ├── tokenization_auto.py ├── tokenization_bert.py ├── tokenization_camembert.py ├── tokenization_ctrl.py ├── tokenization_distilbert.py ├── tokenization_gpt2.py ├── tokenization_openai.py ├── tokenization_roberta.py ├── tokenization_transfo_xl.py ├── tokenization_utils.py ├── tokenization_xlm.py └── tokenization_xlnet.py └── 【2019 CCF BDCI】-负面判定-登峰造极-答辩PPT-最终版.pptx /.idea/CCF-BDCI-ABSA.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Further_pretraining/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | BERT needs to maintain permanent compatibility with the pre-trained model files, 4 | so we do not plan to make any major changes to this library (other than what was 5 | promised in the README). However, we can accept small patches related to 6 | re-factoring and documentation. To submit contributes, there are just a few 7 | small guidelines you need to follow. 8 | 9 | ## Contributor License Agreement 10 | 11 | Contributions to this project must be accompanied by a Contributor License 12 | Agreement. You (or your employer) retain the copyright to your contribution; 13 | this simply gives us permission to use and redistribute your contributions as 14 | part of the project. Head over to to see 15 | your current agreements on file or to sign a new one. 16 | 17 | You generally only need to submit a CLA once, so if you've already submitted one 18 | (even if it was for a different project), you probably don't need to do it 19 | again. 20 | 21 | ## Code reviews 22 | 23 | All submissions, including submissions by project members, require review. We 24 | use GitHub pull requests for this purpose. Consult 25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 26 | information on using pull requests. 27 | 28 | ## Community Guidelines 29 | 30 | This project follows 31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 
32 | -------------------------------------------------------------------------------- /Further_pretraining/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /Further_pretraining/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 
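  # Added annotation (not part of the original file): with power=1.0 and
  # end_learning_rate=0.0, the schedule assembled above works out, per step t, to
  #   lr(t) = init_lr * t / num_warmup_steps               while t < num_warmup_steps
  #   lr(t) = init_lr * (1 - t / num_train_steps)          afterwards (linear decay to 0)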
59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 
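      # Added annotation (not part of the original file): for a parameter w the resulting
      # update is  w <- w - lr * (m / (sqrt(v) + eps) + weight_decay_rate * w),
      # i.e. the decay term bypasses the Adam moments (the AdamW formulation), unlike
      # L2 regularization folded into the gradient before the m/v updates.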
146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /Further_pretraining/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import optimization 20 | import tensorflow as tf 21 | 22 | 23 | class OptimizationTest(tf.test.TestCase): 24 | 25 | def test_adam(self): 26 | with self.test_session() as sess: 27 | w = tf.get_variable( 28 | "w", 29 | shape=[3], 30 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 31 | x = tf.constant([0.4, 0.2, -0.5]) 32 | loss = tf.reduce_mean(tf.square(x - w)) 33 | tvars = tf.trainable_variables() 34 | grads = tf.gradients(loss, tvars) 35 | global_step = tf.train.get_or_create_global_step() 36 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 37 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 38 | init_op = tf.group(tf.global_variables_initializer(), 39 | tf.local_variables_initializer()) 40 | sess.run(init_op) 41 | for _ in range(100): 42 | sess.run(train_op) 43 | w_np = sess.run(w) 44 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 45 | 46 | 47 | if __name__ == "__main__": 48 | tf.test.main() 49 | -------------------------------------------------------------------------------- /Further_pretraining/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow >= 1.11.0 # CPU Version of TensorFlow. 2 | # tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. 
3 | tqdm 4 | jieba 5 | -------------------------------------------------------------------------------- /Further_pretraining/run_pretrain_11_8.out: -------------------------------------------------------------------------------- 1 | nohup: ignoring input 2 | 2019-11-08 13:17:32.430055: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA 3 | 2019-11-08 13:17:32.765105: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1432] Found device 0 with properties: 4 | name: GeForce RTX 2080 Ti major: 7 minor: 5 memoryClockRate(GHz): 1.605 5 | pciBusID: 0000:85:00.0 6 | totalMemory: 10.73GiB freeMemory: 10.53GiB 7 | 2019-11-08 13:17:32.765160: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1511] Adding visible gpu devices: 0 8 | 2019-11-08 13:17:35.777063: I tensorflow/core/common_runtime/gpu/gpu_device.cc:982] Device interconnect StreamExecutor with strength 1 edge matrix: 9 | 2019-11-08 13:17:35.777110: I tensorflow/core/common_runtime/gpu/gpu_device.cc:988] 0 10 | 2019-11-08 13:17:35.777119: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1001] 0: N 11 | 2019-11-08 13:17:35.778313: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10168 MB memory) -> physical GPU (device: 0, name: GeForce RTX 2080 Ti, pci bus id: 0000:85:00.0, compute capability: 7.5) 12 | Traceback (most recent call last): 13 | File "/usr/local/lib/python3.6/dist-packages/absl/flags/_flagvalues.py", line 528, in _assert_validators 14 | validator.verify(self) 15 | File "/usr/local/lib/python3.6/dist-packages/absl/flags/_validators.py", line 81, in verify 16 | raise _exceptions.ValidationError(self.message) 17 | absl.flags._exceptions.ValidationError: Flag --input_file must have a value other than None. 18 | 19 | During handling of the above exception, another exception occurred: 20 | 21 | Traceback (most recent call last): 22 | File "run_pretraining.py", line 499, in 23 | tf.app.run() 24 | File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/platform/app.py", line 119, in run 25 | argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True) 26 | File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/platform/flags.py", line 112, in __call__ 27 | return self.__dict__['__wrapped'].__call__(*args, **kwargs) 28 | File "/usr/local/lib/python3.6/dist-packages/absl/flags/_flagvalues.py", line 636, in __call__ 29 | self._assert_all_validators() 30 | File "/usr/local/lib/python3.6/dist-packages/absl/flags/_flagvalues.py", line 510, in _assert_all_validators 31 | self._assert_validators(all_validators) 32 | File "/usr/local/lib/python3.6/dist-packages/absl/flags/_flagvalues.py", line 531, in _assert_validators 33 | raise _exceptions.IllegalFlagValueError('%s: %s' % (message, str(e))) 34 | absl.flags._exceptions.IllegalFlagValueError: flag --input_file=None: Flag --input_file must have a value other than None. 35 | -------------------------------------------------------------------------------- /Further_pretraining/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 
4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /Further_pretraining/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | import tokenization 22 | import six 23 | import tensorflow as tf 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | if six.PY2: 35 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 36 | else: 37 | vocab_writer.write("".join( 38 | [x + "\n" for x in vocab_tokens]).encode("utf-8")) 39 | 40 | vocab_file = vocab_writer.name 41 | 42 | tokenizer = tokenization.FullTokenizer(vocab_file) 43 | os.unlink(vocab_file) 44 | 45 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 46 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 47 | 48 | self.assertAllEqual( 49 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 50 | 51 | def test_chinese(self): 52 | tokenizer = tokenization.BasicTokenizer() 53 | 54 | self.assertAllEqual( 55 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 56 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 57 | 58 | def test_basic_tokenizer_lower(self): 59 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 60 | 61 | self.assertAllEqual( 62 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 63 | ["hello", "!", "how", "are", "you", "?"]) 64 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 65 | 66 | def test_basic_tokenizer_no_lower(self): 67 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 68 | 69 | self.assertAllEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 71 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 72 | 73 | def test_wordpiece_tokenizer(self): 74 | vocab_tokens = [ 75 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 76 | "##ing" 77 | ] 78 | 79 | vocab = {} 80 | for (i, token) in enumerate(vocab_tokens): 81 | vocab[token] = i 82 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 83 | 84 | self.assertAllEqual(tokenizer.tokenize(""), []) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwanted running"), 88 | ["un", "##want", "##ed", "runn", "##ing"]) 89 | 90 | self.assertAllEqual( 91 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 92 | 93 | def test_convert_tokens_to_ids(self): 94 | vocab_tokens = [ 95 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 96 | "##ing" 97 | ] 98 | 99 | vocab = {} 100 | for (i, token) in enumerate(vocab_tokens): 101 | vocab[token] = i 102 | 103 | self.assertAllEqual( 104 | tokenization.convert_tokens_to_ids( 105 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 106 | 107 | def test_is_whitespace(self): 108 | self.assertTrue(tokenization._is_whitespace(u" ")) 109 | self.assertTrue(tokenization._is_whitespace(u"\t")) 110 | self.assertTrue(tokenization._is_whitespace(u"\r")) 111 | self.assertTrue(tokenization._is_whitespace(u"\n")) 112 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 113 | 114 | self.assertFalse(tokenization._is_whitespace(u"A")) 115 | self.assertFalse(tokenization._is_whitespace(u"-")) 116 | 117 | def test_is_control(self): 118 | self.assertTrue(tokenization._is_control(u"\u0005")) 119 | 120 | self.assertFalse(tokenization._is_control(u"A")) 121 | self.assertFalse(tokenization._is_control(u" ")) 122 | self.assertFalse(tokenization._is_control(u"\t")) 123 | self.assertFalse(tokenization._is_control(u"\r")) 124 | 125 | def test_is_punctuation(self): 126 | self.assertTrue(tokenization._is_punctuation(u"-")) 127 | self.assertTrue(tokenization._is_punctuation(u"$")) 128 | self.assertTrue(tokenization._is_punctuation(u"`")) 129 | self.assertTrue(tokenization._is_punctuation(u".")) 130 | 131 | self.assertFalse(tokenization._is_punctuation(u"A")) 132 | self.assertFalse(tokenization._is_punctuation(u" ")) 133 | 134 | 135 | if __name__ == "__main__": 136 | tf.test.main() 137 | -------------------------------------------------------------------------------- /Fusion_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | 该代码主要执行以下五个模型的融合: 6 | 1.model_1_bert_att_drop_42.py 7 | 2.model_2_bert_att_drop_further_pretrain.py 8 | 3.model_3_roberte_wwm_ext_att_drop_42.py.py 9 | 4.model_4_bert_att_drop_420.py 10 | 5.model_5_bert_att_drop_1001001.py 11 | 12 | 融合方法:五个模型概率求平均 13 | """ 14 | 15 | 16 | import numpy as np 17 | from tqdm import tqdm 18 | import time 19 | import logging 20 | import os 21 | import pandas as pd 22 | from sklearn.metrics import f1_score 23 | 24 | # 创建一个logger 25 | file_path = './log/' 26 | logger = logging.getLogger('mylogger') 27 | logger.setLevel(logging.DEBUG) 28 | timestamp = time.strftime("%Y.%m.%d_%H.%M.%S", time.localtime()) 29 | fh = logging.FileHandler(file_path + 'log_fusion_model.txt') 30 | fh.setLevel(logging.DEBUG) 31 | ch = logging.StreamHandler() 32 | ch.setLevel(logging.DEBUG) 33 | formatter = logging.Formatter('[%(asctime)s][%(levelname)s] ## %(message)s') 34 | fh.setFormatter(formatter) 35 | ch.setFormatter(formatter) 36 | logger.addHandler(fh) 37 | 
logger.addHandler(ch) 38 | 39 | file_name = 'Fusion_model_6' 40 | 41 | 42 | class InputExample(object): 43 | """A single training/test example for simple sequence classification.""" 44 | 45 | def __init__(self, id, text, entity=None, label=None): 46 | """Constructs a InputExample. 47 | Args: 48 | guid: Unique id for the example. 49 | text_a: string. The untokenized text of the first sequence. For single 50 | sequence tasks, only this sequence must be specified. 51 | text_b: (Optional) string. The untokenized text of the second sequence. 52 | Only must be specified for sequence pair tasks. 53 | label: (Optional) string. The label of the example. This should be 54 | specified for train and dev examples, but not for test examples. 55 | """ 56 | self.id = id 57 | self.text = text 58 | self.entity = entity 59 | self.label = label 60 | 61 | 62 | def read_examples(input_file, is_training): 63 | df = pd.read_csv(input_file) 64 | if not is_training: 65 | df['negative'] = np.zeros(len(df), dtype=np.int64) 66 | examples = [] 67 | for val in df[['id', 'text', 'entity', 'negative']].values: 68 | examples.append(InputExample(id=val[0], text=val[1], entity=val[2], label=val[3])) 69 | return examples, df 70 | 71 | 72 | def postprocess(raw, df, prefix=''): 73 | """ 74 | 将多条预测结果数据拼接成一条 75 | :param raw: 76 | :param df: 77 | :param prefix: 78 | :return: 79 | """ 80 | negatives = [] 81 | key_entities = [] 82 | 83 | for raw_id in tqdm(raw['id'].tolist()): 84 | result = df[df['id'] == raw_id] 85 | if len(result) > 0: 86 | negative = 0 87 | key_entity = [] 88 | for n, e in zip(result[prefix+'negative'].tolist(), result['entity']): 89 | if '?' in e: 90 | n = 1 91 | if n == 1: 92 | negative = 1 93 | repeat = False 94 | for k_e in key_entity.copy(): 95 | if e in k_e: 96 | repeat = True 97 | break 98 | elif k_e in e: 99 | key_entity.remove(k_e) 100 | key_entity.append(e) 101 | repeat = True 102 | break 103 | if not repeat: 104 | key_entity.append(e) 105 | negatives.append(negative) 106 | key_entities.append(';'.join(key_entity)) 107 | else: 108 | negatives.append(0) 109 | key_entities.append('') 110 | 111 | raw[prefix+'negative'] = negatives 112 | raw[prefix+'key_entity'] = key_entities 113 | return raw 114 | 115 | 116 | def metric(train): 117 | negative_true = train['negative'].tolist() 118 | negative_pred = train['pred_negative'].tolist() 119 | negative_f1 = f1_score(negative_true, negative_pred) 120 | 121 | key_entities_true = train['key_entity'].tolist() 122 | key_entities_pred = train['pred_key_entity'].tolist() 123 | A, B, C = 1e-10, 1e-10, 1e-10 124 | for e_true, e_pred in zip(key_entities_true, key_entities_pred): 125 | if type(e_true) == float: 126 | e_true = '' 127 | if type(e_pred) == float: 128 | e_pred = '' 129 | e_true = set(e_true.split(';')) 130 | e_pred = set(e_pred.split(';')) 131 | A += len(e_true & e_pred) 132 | B += len(e_pred) 133 | C += len(e_true) 134 | entities_f1 = 2 * A / (B + C) 135 | logger.info('precission: %.8f, recall: %.8f, f1: %.8f' % (A/B, A/C, entities_f1)) 136 | return 0.4*negative_f1, 0.6*entities_f1, 0.4*negative_f1 + 0.6*entities_f1 137 | 138 | 139 | if __name__ == '__main__': 140 | 141 | # 加载数据 142 | train_examples, train_df = read_examples('./datasets/preprocess_round_1_2_train_data.csv', is_training=True) 143 | test_examples, test_df = read_examples('./datasets/preprocess_round2_test.csv', is_training=False) 144 | raw_train = pd.read_csv('./datasets/round_1_2_train_data.csv') 145 | raw_test = pd.read_csv('./datasets/round2_test.csv') 146 | 147 | # 计算训练集的平均融合的概率 148 | 
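    # Added annotation (not part of the original file): the loop below sums the saved
    # per-model probability matrices and divides by a hard-coded 5, which assumes exactly
    # five files in ./submit/train_prob; dividing by the number of files actually read
    # would make the averaging robust to extra or missing probability files.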
oof_train_total = 0. 149 | for i, file_name in enumerate(sorted(os.listdir('./submit/train_prob'))): 150 | file = os.path.join('./submit/train_prob', file_name) 151 | oof_train = np.loadtxt(file) 152 | oof_train_total += oof_train 153 | oof_train_ave = oof_train_total / 5 154 | 155 | # 计算测试集的平均融合的概率 156 | oof_test_total = 0. 157 | for i, file_name in enumerate(sorted(os.listdir('./submit/test_prob'))): 158 | file = os.path.join('./submit/test_prob', file_name) 159 | oof_test = np.loadtxt(file) 160 | oof_test_total += oof_test 161 | oof_test_ave = oof_test_total / 5 162 | 163 | labels = train_df['negative'].astype(int).values 164 | train_df['pred_negative'] = np.argmax(oof_train_ave, axis=1) 165 | test_df['negative'] = np.argmax(oof_test_ave, axis=1) 166 | 167 | pred_train = postprocess(raw_train, train_df, prefix='pred_') 168 | pred_train.to_csv('./submit/train_5_model_ave_predict.csv', index=False) 169 | negative_f1, entity_f1, weight_f1 = metric(pred_train) 170 | logger.info('negative_f1: %.8f, entity_f1: %.8f, weight_f1: %.8f\n' % 171 | (negative_f1, entity_f1, weight_f1)) 172 | 173 | submit = postprocess(raw_test, test_df) 174 | submit[['id', 'negative', 'key_entity']].to_csv('./submit/Fusion_model_test_predict.csv', index=False) 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CCF2019 BDCI 金融信息负面及主体判定 (登封造极 团队第三名方案) 2 | 3 | ## 整理不易,烦请点个star~ 本人研究方向为NLP,欢迎交流~ ~有时间会写个总结~,可以关注[我的知乎](https://www.zhihu.com/people/chen-feng-91-57/posts) 4 | 5 | * 队伍:登峰造极 6 | * Chevalier 7 | * 42Kjs 8 | * zhkkk 9 | * 队友好棒棒 10 | * Wizare 11 | 12 | ## 代码运行环境: 13 | * python3 (最好是python3.6.7) 14 | * pytorch-transformers 1.2.0 15 | * torch 1.1.0 16 | * tensorflow-gpu 1.12.0 17 | * numpy 1.16.3 18 | * tqdm 4.31.1 19 | * scikit-learn 0.20.3 20 | * pandas 0.24.2 21 | 22 | 23 | ## 代码运行系统环境: 24 | * 系统: Linux version 4.4.0-138-generic (buildd@lcy01-amd64-006) (gcc version 5.4.0 20160609 (Ubuntu 5.4.0-6ubuntu1~16.04.10) ) #164-Ubuntu SMP Tue Oct 2 17:16:02 UTC 2018 25 | * CPU: Intel(R) Xeon(R) CPU E5-2637 v4 @ 3.50GHz 26 | * GPU: 4*2080Ti 11G 27 | * CUDA Version 10.0.130 28 | * cudnn 7.5.0 29 | 30 | 31 | ## 方案概述: 32 | * 我们将本赛题简化成一个任务——基于金融实体的情感分析,在3种预训练模型上进行微调,包括3个不同种子的bert、1个在比赛数据集再次预训练的bert、1个roberta_wwm_ext模型。模型选型上考虑互补性,并考虑模型效率,只使用了base模型,使其更适合在真实环境中落地使用。 33 | * 数据预处理新颖,将样本转化为“["[CLS]"]文本["[SEP]"]实体n-10;...;<实体n>;...;实体n+10["[SEP]"]”的格式,使其能够考虑相邻实体之间相关性。 34 | * 预训练模型简洁有效,考虑到bert每一层能够捕捉到输入文本不同维度的特征,我们通过attention集成每一层的输出以得到更好的文本表示。 35 | * 同时我们通过Multi-Sample Dropout提升模型的泛化性。 36 | * 最终我们平均融合5个预训练模型,融合后的模型在线上已经能取得较好的成绩。 37 | * 另外,考虑到模型不能解决所有问题,因此我们在模型融合的基础上进行了后处理,提升了总体性能。 38 | 39 | 40 | ## 代码框架: 41 | * datasets/: 存放原始的数据集,以及预处理后的数据集 42 | * preprocess_round_1_2_train_data.csv: 对初赛和复赛合并之后的训练集进行预处理后的文件 43 | * preprocess_round2_test.csv: 对复赛测试集进行预处理后的文件 44 | * pretrain_data.txt: 用于further pretrain的训练文件 45 | * round_1_2_train_data.csv: 初赛和复赛合并之后的训练集 46 | * round2_test.csv: 复赛提交的测试集 47 | * Round2_train.csv: 复赛提供的训练集 48 | * Train_Data.csv: 初赛提供的训练集 49 | * transformers: 用于将tensorflow的预训练权重转换为pytorch的预训练权重 50 | * Further_pretraining/: 根据现有数据集,对bert模型进行pretrain训练 51 | * pretrain_weight/: 预训练模型权重 52 | * bert 53 | * bert_further_pretrain 54 | * roberta_wwm_ext 55 | * model_save/: 从头开始训练,存放模型训练的最优权重 56 | * best_model_save/: 直接预测用的模型最优权重 57 | * log/: 日志存放 58 | 59 | * Fusion_model.py: 模型融合脚本 60 | 61 | * model_1_bert_att_drop_42.py: 在bert模型的基础上,添加attention和dropout层作为整体训练模型,以随机种子为42进行训练 62 | 63 | * 
model_2_bert_att_drop_further_pretrain.py: 先根据现有数据集对bert模型进行further pretrain,得到新bert的模型权重。在bert模型的基础上,添加attention和dropout层作为整体训练模型 64 | 65 | * model_3_roberte_wwm_ext_att_drop_42.py: 在roberte_wwm_ext模型的基础上,添加attention和dropout层作为整体训练模型,以随机种子为42进行训练 66 | 67 | * model_4_bert_att_drop_420.py: 在bert模型的基础上,添加attention和dropout层作为整体训练模型,以随机种子为420进行训练 68 | 69 | * model_5_bert_att_drop_1001001.py: 在bert模型的基础上,添加attention和dropout层作为整体训练模型,以随机种子为1001001进行训练 70 | 71 | * predict_model_1_bert_att_drop_42.py: 无须训练,加载最优模型直接预测 72 | 73 | * predict_model_2_bert_att_drop_further_pretrain.py: 无须训练,加载最优模型直接预测 74 | 75 | * predict_model_3_roberte_wwm_ext_att_drop_42.py: 无须训练,加载最优模型直接预测 76 | 77 | * predict_model_4_bert_att_drop_420.py: 无须训练,加载最优模型直接预测 78 | 79 | * predict_model_5_bert_att_drop_1001001.py: 无须训练,加载最优模型直接预测 80 | 81 | 82 | 83 | * preprocess.py: 数据预处理脚本 84 | 85 | * postprocess.py: 模型预测结果后处理脚本 86 | 87 | 88 | 89 | ## 复现:One Step: 90 | * 因为训练模型比较久而且模型比较大,所以我们提供了所有模型对OOF和测试集的预测结果(./submit/train_prob和./submit/test_prob),只需要简单的做一下概率平均,然后运行一下后处理就可以得到我们提交的最好结果。 91 | 92 | ``` 93 | python Fusion_model.py 94 | python postprocess.py 95 | ``` 96 | 97 | 最后生成的./submit/best_result.csv即可用于提交。 98 | * 当然如果想要从头复现,可以看下面的说明: 99 | 100 | ## 复现:step by step 101 | ## 1. 预处理模块: 102 | * 该文件为预处理文件,主要进行以下几个预处理: 103 | 1.清除无用的信息 104 | 2.如果待预测实体不在文本的前512中,将预测实体所在的文本提前到前512中 105 | 3.将文本中出现的实体,添加上“<”,“>”,来突出实体 106 | 4.将含有多条实体的数据切分成多条只预测一个实体的数据 107 | 5.截断文本(取前512) 108 | 得到"./datasets/preprocess_round_1_2_train_data.csv"和"preprocess_round2_test.csv" 109 | 110 | 这里我们使用的是初赛和复赛合并之后的训练集数据集,完全复现请使用合并后的数据集("./datasets/round_1_2_train_data.csv")。 111 | ``` 112 | python preprocess.py 113 | ``` 114 | 如果是使用新数据集(更改对应参数),使用以下: 115 | ``` 116 | python preprocess.py ./datasets/round_1_2_train_data.csv ./datasets/round2_test.csv 117 | ``` 118 | 119 | ## 2. 预训练权重 120 | *Ps: 如果嫌检查预训练权重麻烦,可以跳过该步骤,我们已经提供了pytorch版本的bert权重、再次预训练的bert权重、roberta_wwm_ext权重 121 | * "./pretrain_weight"下有三个预训练权重:(1)bert-base(2)roberta_wwm_ext(3)bert_further_pretrain,我们已经放在该文件下,文件来源如下: 122 | 1.[BERT-Base, Chinese](https://github.com/google-research/bert#pre-trained-models),这里只提供tensorflow版本,还需转换成pytorch版本。 123 | 2.[roberta_wwm_ext](https://github.com/ymcui/Chinese-BERT-wwm),通过讯飞云下载pytoch版本。 124 | 3.bert_further_pretrain,其中bert_further_pretrain预训练权重为bert-base通过在该比赛数据集再次预训练得到。由于训练时间比较长,我们提供已经further-pretrain好的权重供下载。 125 | * 如果你想Further pretrain Bert, 可以执行一下脚本: 126 | ``` 127 | sh ./shell/get_pretrain_data.sh 128 | sh ./shell/run_pretrain.sh 129 | ``` 130 | *Ps:你自己从官网下载的BERT-Base, Chinese和通过脚本再次预训练得到的bert-base-further-pretrain,得到的是tensorflow的权重,还需要转换为pytorch的bert权重,可以执行以下脚本或者参考[tensorflow-bert权重转pytorch](https://www.lizenghai.com/archives/32772.html) 131 | 132 | ``` 133 | cd transformers 134 | export BERT_BASE_DIR=#tensorflow权重的绝对路径# 135 | python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt --bert_config_file $BERT_BASE_DIR/bert_config.json --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin 136 | ``` 137 | Ps:还需要把bert_config.json文件重命名为config.json 138 | 139 | 140 | ## 3. 
模型训练 141 | * 该模块是主要的模型训练及模型在测试集上的预测。 142 | * 模型采用七折交叉训练。 143 | * 首先需要从百度云下载预训练权重,copy到"./pretrain_weight/"下 144 | * 执行脚本训练模型,每个模型训练的时间在15个小时左右。 145 | * 各个模型的权重在训练完后将保存在"./model_save"下,概率文件将保存在"./submit/train_prob"和"./submit/test_prob"下。 146 | 依次执行代码训练五个模型如下: 147 | ``` 148 | python model_1_bert_att_drop_42.py 149 | python model_2_bert_att_drop_further_pretrain.py 150 | python model_3_roberte_wwm_ext_att_drop_42.py.py 151 | python model_4_bert_att_drop_420.py 152 | python model_5_bert_att_drop_1001001.py 153 | ``` 154 | Ps:如果GPU指定报错,在脚本中可以修改GPU参数 155 | Ps:如果嫌模型训练时间过长,可执行以下代码直接预测 156 | ``` 157 | python predict_model_1_bert_att_drop_42.py 158 | python predict_model_2_bert_att_drop_further_pretrain.py 159 | python predict_model_3_roberte_wwm_ext_att_drop_42.py.py 160 | python predict_model_4_bert_att_drop_420.py 161 | python predict_model_5_bert_att_drop_1001001.py 162 | ``` 163 | 164 | 165 | ## 4.模型融合 166 | 该模块将五个模型的概率文件平均融合。该结果在线上已经能取得一个不错的成绩。 167 | ``` 168 | python Fusion_model.py 169 | ``` 170 | 171 | ## 5.后处理 172 | * 该模块主要根据训练集中的一些实体共现频率提取的规则,处理了下并列实体的情况,以及根据训练集的先验知识,补充部分短实体。 173 | * 运行得到最终的提交文件best_result.csv。 174 | ``` 175 | python postprocess.py 176 | ``` 177 | 178 | ## 6.提交: 179 | * 在submit目录下, 提交best_result.csv。 180 | 181 | ## Concat: 182 | email:scut_chenfeng@163.com 183 | 184 | ## 特别鸣谢 185 | * https://github.com/GeneZC/BERTFinanceNeg 186 | * https://github.com/guoday/CCF-BDCI-Sentiment-Analysis-Baseline 187 | 188 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import csv 5 | import re 6 | import distance 7 | import sys 8 | import pandas as pd 9 | 10 | """ 11 | 该文件为预处理文件,主要进行以下几个预处理: 12 | 1.清除无用的信息 13 | 2.如果待预测实体不在文本的前512中,将预测实体所在的文本提前到前512中 14 | 3.将文本中出现的实体,添加上“<”,“>”,来突出实体 15 | 4.将含有多条实体的数据切分成多条只预测一个实体的数据 16 | 5.截断文本(取前512) 17 | """ 18 | 19 | max_seq_length = 512 20 | 21 | 22 | def clean_space(text): 23 | """" 24 | 处理多余的空格 25 | """ 26 | match_regex = re.compile(u'[\u4e00-\u9fa5。\.,,::《》、\(\)()]{1} +(?', '', text) 51 | text = re.sub(r'http(s?):[/a-zA-Z0-9.=&?_#]+', '', text) 52 | text = re.sub(r'“', '', text) 53 | text = re.sub(r'”', '', text) 54 | text = re.sub(r'—{5,}', '', text) 55 | 56 | text = re.sub(r'?{2,}', '', text) 57 | text = re.sub(r'●', '', text) 58 | text = re.sub(r'【图】', '', text) 59 | text = re.sub(r'[0-9]+[-|.|/|年][0-9]{2}[-|.|/|月][0-9]{2}日?', '', text) 60 | text = re.sub(r' ', '', text) 61 | text = re.sub(r'[0-9]{15,}', '', text) 62 | text = re.sub(r'"', '', text) 63 | return text 64 | 65 | 66 | def match(title, text): 67 | if type(title) == float or type(text) == float: 68 | return 1 69 | strs = list(distance.lcsubstrings(title, text)) 70 | if len(strs) > 0: 71 | return len(strs[0]) / len(title) 72 | else: 73 | return 0 74 | 75 | 76 | def get_entity_sentence(entity, text): 77 | """ 78 | 找出预测实体所在的文本 79 | :param entity: 80 | :param text: 81 | :return: 82 | """ 83 | index = text.find(entity) 84 | if index > 512: 85 | split_symbol = ['.', '。', '!', '!', '?', '?', ';', ';', ' ', '\t', '\n'] 86 | for i in range(50): 87 | if text[index - 20 - i - 1] in split_symbol: 88 | return ''.join(text[index - 20 - i: len(text)]) 89 | else: 90 | return None 91 | 92 | 93 | def process(filename, data_path, mode='train'): 94 | """ 95 | 数据预处理主函数 96 | :param filename: 97 | :param mode: 98 | :return: 99 | """ 100 | header = [] 101 | rows = [] 102 | with open(filename, 'r', encoding='utf-8-sig') as f: 
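        # Added annotation (not part of the original file), illustrating what process()
        # emits: for a row whose entity column is, say, "A银行;B金服;C小贷" and whose current
        # target entity is "B金服", the loop below marks it in the text as <B金服> and builds
        # the neighbouring-entity string "A银行;<B金服>;C小贷" (a window of roughly 10 entities
        # on each side), which the model scripts later encode as [CLS] text [SEP] entities [SEP].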
103 | f_csv = csv.reader(f) 104 | for i, row in enumerate(f_csv): 105 | if i == 0: 106 | if mode == 'train': 107 | header = [row[0], row[2], 'clean_entity'] + row[3:5] 108 | else: 109 | header = [row[0], row[2], 'clean_entity', row[3]] 110 | else: 111 | text_id, title, text = row[0:3] 112 | if row[3] != '': 113 | entities = row[3].split(';') 114 | else: 115 | entities = [''] 116 | if mode == 'train': 117 | if row[5] != '': 118 | key_entities = clean(row[5]).split(';') 119 | else: 120 | key_entities = [] 121 | if len(text) == 0 or type(text) == float: 122 | text = title 123 | text = clean(text) 124 | for index, entity in enumerate(entities): 125 | clean_entity = clean(entity) 126 | entity_in_text = get_entity_sentence(clean_entity, text) 127 | new_text = text 128 | if entity_in_text != None: 129 | new_text = entity_in_text 130 | if clean_entity not in text: 131 | new_text = clean_entity + '。' + text 132 | if entity == '' or entity != entity or type(entity) == float: 133 | continue 134 | if clean_entity == '' or clean_entity != clean_entity or type(clean_entity) == float: 135 | continue 136 | new_text = new_text.replace(clean_entity, '<' + clean_entity + '>') 137 | if mode == 'train': 138 | if clean_entity in key_entities: 139 | negative = 1 140 | else: 141 | negative = 0 142 | if mode == 'train': 143 | new_entities = entities.copy() 144 | new_entities[index] = '<' + clean_entity + '>' 145 | 146 | new_entities_sub = [] 147 | for i in range(index-10, index + 10): 148 | if i >= 0 and i < len(new_entities): 149 | new_entities_sub.append(new_entities[i]) 150 | new_entity = ';'.join(new_entities_sub) 151 | rows.append([text_id, new_text, new_entity, entity, negative]) 152 | else: 153 | new_entities = entities.copy() 154 | new_entities[index] = '<' + clean_entity + '>' 155 | new_entities_sub = [] 156 | for i in range(index-10, index + 10): 157 | if i >= 0 and i < len(new_entities): 158 | new_entities_sub.append(new_entities[i]) 159 | new_entity = ';'.join(new_entities_sub) 160 | rows.append([text_id, new_text, new_entity, entity]) 161 | 162 | with open(data_path, 'w', encoding='utf-8-sig', newline='') as f: 163 | f_csv = csv.writer(f) 164 | f_csv.writerow(header) 165 | for row in rows: 166 | f_csv.writerow(row) 167 | 168 | 169 | def merge_round_1_round_2(): 170 | """ 171 | 合并初赛和复赛的数据 172 | :return: 173 | """ 174 | train_round_1 = pd.read_csv('./datasets/Train_Data.csv') 175 | train_round_2 = pd.read_csv('./datasets/Round2_train.csv') 176 | merge_data = pd.concat([train_round_2, train_round_1], ignore_index=True) 177 | merge_data = merge_data.drop_duplicates(keep='first', subset=['text']) 178 | merge_data.to_csv('./datasets/round_1_2_train_data.csv', encoding='utf-8', index=False) 179 | 180 | 181 | if __name__ == '__main__': 182 | merge_round_1_round_2() 183 | 184 | if len(sys.argv) < 2: 185 | 186 | # 训练集的预处理 187 | process('./datasets/round_1_2_train_data.csv', './datasets/preprocess_round_1_2_train_data.csv', 'train') 188 | # 测试集的预处理 189 | process('./datasets/round2_test.csv', './datasets/preprocess_round2_test.csv', 'test') 190 | 191 | else: 192 | train_path = sys.argv[1] 193 | test_path = sys.argv[2] 194 | # 训练集的预处理 195 | process(train_path, './datasets/preprocess_round_1_2_train_data.csv', 'train') 196 | # 测试集的预处理 197 | process(test_path, './datasets/preprocess_round2_test.csv', 'test') 198 | 199 | -------------------------------------------------------------------------------- /shell/get_pretrain_data.sh: -------------------------------------------------------------------------------- 1 | 
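# Added annotation (not part of the original scripts): the command below writes
# ./datasets/tf_pretrain_data.tfrecord using --max_predictions_per_seq=20, while
# run_pretrain.sh reads ./datasets/pretrain_data.tfrecord with --max_predictions_per_seq=50;
# the two file names and the two max_predictions_per_seq values presumably need to be
# brought into agreement before further pretraining runs end to end.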
python ./Further_pretraining/create_pretraining_data.py --input_file=./datasets/pretrain_data.txt --output_file=./datasets/tf_pretrain_data.tfrecord --vocab_file=./pretrain_weight/bert/vocab.txt --do_lower_case=True --max_seq_length=128 --max_predictions_per_seq=20 --masked_lm_prob=0.15 --random_seed=42 --dupe_factor=5 -------------------------------------------------------------------------------- /shell/run_pretrain.sh: -------------------------------------------------------------------------------- 1 | python ./Further_pretraining/run_pretraining.py --input_file=./datasets/pretrain_data.tfrecord --output_dir=./pretrain_weight/pretraining_output --do_train=True --do_eval=True --bert_config_file=./pretrain_weight/bert/bert_config.json --init_checkpoint=./pretrain_weight/bert/bert_model.ckpt --train_batch_size=32 --eval_batch_size=32 --max_seq_length=128 --max_predictions_per_seq=50 --num_train_steps=100000 --num_warmup_steps=10000 --learning_rate=5e-5 -------------------------------------------------------------------------------- /transformers/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]: 5 | print( 6 | "This command line utility let you convert original (author released) model checkpoint to pytorch.\n" 7 | "It should be used as one of: \n" 8 | ">> transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n" 9 | ">> transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n" 10 | ">> transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n" 11 | ">> transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n" 12 | ">> transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n" 13 | ">> transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT") 14 | else: 15 | if sys.argv[1] == "bert": 16 | try: 17 | from .convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 18 | except ImportError: 19 | print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 20 | "In that case, it requires TensorFlow to be installed. 
Please see " 21 | "https://www.tensorflow.org/install/ for installation instructions.") 22 | raise 23 | 24 | if len(sys.argv) != 5: 25 | # pylint: disable=line-too-long 26 | print("Should be used as `transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 27 | else: 28 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 29 | TF_CONFIG = sys.argv.pop() 30 | TF_CHECKPOINT = sys.argv.pop() 31 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 32 | elif sys.argv[1] == "gpt": 33 | from .convert_openai_original_tf_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 34 | if len(sys.argv) < 4 or len(sys.argv) > 5: 35 | # pylint: disable=line-too-long 36 | print("Should be used as `transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`") 37 | else: 38 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 39 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 40 | if len(sys.argv) == 5: 41 | OPENAI_GPT_CONFIG = sys.argv[4] 42 | else: 43 | OPENAI_GPT_CONFIG = "" 44 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 45 | OPENAI_GPT_CONFIG, 46 | PYTORCH_DUMP_OUTPUT) 47 | elif sys.argv[1] == "transfo_xl": 48 | try: 49 | from .convert_transfo_xl_original_tf_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 50 | except ImportError: 51 | print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 52 | "In that case, it requires TensorFlow to be installed. Please see " 53 | "https://www.tensorflow.org/install/ for installation instructions.") 54 | raise 55 | if len(sys.argv) < 4 or len(sys.argv) > 5: 56 | # pylint: disable=line-too-long 57 | print("Should be used as `transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") 58 | else: 59 | if 'ckpt' in sys.argv[2].lower(): 60 | TF_CHECKPOINT = sys.argv[2] 61 | TF_DATASET_FILE = "" 62 | else: 63 | TF_DATASET_FILE = sys.argv[2] 64 | TF_CHECKPOINT = "" 65 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 66 | if len(sys.argv) == 5: 67 | TF_CONFIG = sys.argv[4] 68 | else: 69 | TF_CONFIG = "" 70 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 71 | elif sys.argv[1] == "gpt2": 72 | try: 73 | from .convert_gpt2_original_tf_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 74 | except ImportError: 75 | print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 76 | "In that case, it requires TensorFlow to be installed. Please see " 77 | "https://www.tensorflow.org/install/ for installation instructions.") 78 | raise 79 | 80 | if len(sys.argv) < 4 or len(sys.argv) > 5: 81 | # pylint: disable=line-too-long 82 | print("Should be used as `transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") 83 | else: 84 | TF_CHECKPOINT = sys.argv[2] 85 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 86 | if len(sys.argv) == 5: 87 | TF_CONFIG = sys.argv[4] 88 | else: 89 | TF_CONFIG = "" 90 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 91 | elif sys.argv[1] == "xlnet": 92 | try: 93 | from .convert_xlnet_original_tf_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch 94 | except ImportError: 95 | print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 96 | "In that case, it requires TensorFlow to be installed. 
Please see " 97 | "https://www.tensorflow.org/install/ for installation instructions.") 98 | raise 99 | 100 | if len(sys.argv) < 5 or len(sys.argv) > 6: 101 | # pylint: disable=line-too-long 102 | print("Should be used as `transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`") 103 | else: 104 | TF_CHECKPOINT = sys.argv[2] 105 | TF_CONFIG = sys.argv[3] 106 | PYTORCH_DUMP_OUTPUT = sys.argv[4] 107 | if len(sys.argv) == 6: 108 | FINETUNING_TASK = sys.argv[5] 109 | else: 110 | FINETUNING_TASK = None 111 | 112 | convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT, 113 | TF_CONFIG, 114 | PYTORCH_DUMP_OUTPUT, 115 | FINETUNING_TASK) 116 | elif sys.argv[1] == "xlm": 117 | from .convert_xlm_original_pytorch_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch 118 | 119 | if len(sys.argv) != 4: 120 | # pylint: disable=line-too-long 121 | print("Should be used as `transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`") 122 | else: 123 | XLM_CHECKPOINT_PATH = sys.argv[2] 124 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 125 | 126 | convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT) 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /transformers/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
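# Added annotation (not part of the original file): a typical loading sketch for the local
# weight folders under ./pretrain_weight/ (the exact calls used in the model_*.py scripts
# may differ):
#   from transformers import BertConfig, BertModel
#   config = BertConfig.from_pretrained('./pretrain_weight/bert/')
#   model = BertModel.from_pretrained('./pretrain_weight/bert/', config=config)
# from_pretrained() expects a file named config.json in that folder, which is why the
# README asks for bert_config.json to be renamed to config.json after the TF to PyTorch
# conversion.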
16 | """ BERT model configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 31 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 32 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 33 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 34 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 35 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 36 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 37 | 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 38 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 39 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 40 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 41 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 42 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 43 | 'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 44 | 'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", 45 | } 46 | 47 | 48 | class BertConfig(PretrainedConfig): 49 | r""" 50 | :class:`~transformers.BertConfig` is the configuration class to store the configuration of a 51 | `BertModel`. 52 | 53 | 54 | Arguments: 55 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 56 | hidden_size: Size of the encoder layers and the pooler layer. 57 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 58 | num_attention_heads: Number of attention heads for each attention layer in 59 | the Transformer encoder. 60 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 61 | layer in the Transformer encoder. 62 | hidden_act: The non-linear activation function (function or string) in the 63 | encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. 64 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 65 | layers in the embeddings, encoder, and pooler. 66 | attention_probs_dropout_prob: The dropout ratio for the attention 67 | probabilities. 68 | max_position_embeddings: The maximum sequence length that this model might 69 | ever be used with. 
Typically set this to something large just in case 70 | (e.g., 512 or 1024 or 2048). 71 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 72 | `BertModel`. 73 | initializer_range: The sttdev of the truncated_normal_initializer for 74 | initializing all weight matrices. 75 | layer_norm_eps: The epsilon used by LayerNorm. 76 | """ 77 | pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP 78 | 79 | def __init__(self, 80 | vocab_size_or_config_json_file=30522, 81 | hidden_size=768, 82 | num_hidden_layers=12, 83 | num_attention_heads=12, 84 | intermediate_size=3072, 85 | hidden_act="gelu", 86 | hidden_dropout_prob=0.1, 87 | attention_probs_dropout_prob=0.1, 88 | max_position_embeddings=512, 89 | type_vocab_size=2, 90 | initializer_range=0.02, 91 | layer_norm_eps=1e-12, 92 | **kwargs): 93 | super(BertConfig, self).__init__(**kwargs) 94 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 95 | and isinstance(vocab_size_or_config_json_file, unicode)): 96 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 97 | json_config = json.loads(reader.read()) 98 | for key, value in json_config.items(): 99 | self.__dict__[key] = value 100 | elif isinstance(vocab_size_or_config_json_file, int): 101 | self.vocab_size = vocab_size_or_config_json_file 102 | self.hidden_size = hidden_size 103 | self.num_hidden_layers = num_hidden_layers 104 | self.num_attention_heads = num_attention_heads 105 | self.hidden_act = hidden_act 106 | self.intermediate_size = intermediate_size 107 | self.hidden_dropout_prob = hidden_dropout_prob 108 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 109 | self.max_position_embeddings = max_position_embeddings 110 | self.type_vocab_size = type_vocab_size 111 | self.initializer_range = initializer_range 112 | self.layer_norm_eps = layer_norm_eps 113 | else: 114 | raise ValueError("First argument must be either a vocabulary size (int)" 115 | " or the path to a pretrained model config file (str)") 116 | -------------------------------------------------------------------------------- /transformers/configuration_camembert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
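A short usage sketch for the BertConfig class defined above. It can be built either from explicit hyper-parameters (integer vocabulary size) or from a local config.json whose keys are copied onto the instance; the numbers and path below are illustrative only.

from transformers import BertConfig

# From hyper-parameters: unspecified fields keep the defaults shown above.
config = BertConfig(vocab_size_or_config_json_file=21128,  # e.g. a Chinese BERT vocabulary; illustrative
                    hidden_size=768,
                    num_hidden_layers=12,
                    num_attention_heads=12)
assert config.max_position_embeddings == 512   # default

# From a JSON file (placeholder path): every key in the file becomes an attribute.
config_from_file = BertConfig("/path/to/bert_config.json")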
16 | """ CamemBERT configuration """ 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import logging 22 | 23 | from .configuration_roberta import RobertaConfig 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | 'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json", 29 | } 30 | 31 | 32 | class CamembertConfig(RobertaConfig): 33 | pretrained_config_archive_map = CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 34 | -------------------------------------------------------------------------------- /transformers/configuration_ctrl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Salesforce and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Salesforce CTRL configuration """ 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import json 20 | import logging 21 | import sys 22 | from io import open 23 | 24 | from .configuration_utils import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"} 29 | 30 | class CTRLConfig(PretrainedConfig): 31 | """Configuration class to store the configuration of a `CTRLModel`. 32 | 33 | Args: 34 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file. 35 | n_positions: Number of positional embeddings. 36 | n_ctx: Size of the causal mask (usually same as n_positions). 37 | dff: Size of the inner dimension of the FFN. 38 | n_embd: Dimensionality of the embeddings and hidden states. 39 | n_layer: Number of hidden layers in the Transformer encoder. 40 | n_head: Number of attention heads for each attention layer in 41 | the Transformer encoder. 42 | layer_norm_epsilon: epsilon to use in the layer norm layers 43 | resid_pdrop: The dropout probabilitiy for all fully connected 44 | layers in the embeddings, encoder, and pooler. 45 | attn_pdrop: The dropout ratio for the attention 46 | probabilities. 47 | embd_pdrop: The dropout ratio for the embeddings. 48 | initializer_range: The sttdev of the truncated_normal_initializer for 49 | initializing all weight matrices. 
50 | """ 51 | pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP 52 | 53 | def __init__( 54 | self, 55 | vocab_size_or_config_json_file=246534, 56 | n_positions=256, 57 | n_ctx=256, 58 | n_embd=1280, 59 | dff=8192, 60 | n_layer=48, 61 | n_head=16, 62 | resid_pdrop=0.1, 63 | embd_pdrop=0.1, 64 | attn_pdrop=0.1, 65 | layer_norm_epsilon=1e-6, 66 | initializer_range=0.02, 67 | 68 | num_labels=1, 69 | summary_type='cls_index', 70 | summary_use_proj=True, 71 | summary_activation=None, 72 | summary_proj_to_labels=True, 73 | summary_first_dropout=0.1, 74 | **kwargs 75 | ): 76 | """Constructs CTRLConfig. 77 | 78 | Args: 79 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file. 80 | n_positions: Number of positional embeddings. 81 | n_ctx: Size of the causal mask (usually same as n_positions). 82 | dff: Size of the inner dimension of the FFN. 83 | n_embd: Dimensionality of the embeddings and hidden states. 84 | n_layer: Number of hidden layers in the Transformer encoder. 85 | n_head: Number of attention heads for each attention layer in 86 | the Transformer encoder. 87 | layer_norm_epsilon: epsilon to use in the layer norm layers 88 | resid_pdrop: The dropout probabilitiy for all fully connected 89 | layers in the embeddings, encoder, and pooler. 90 | attn_pdrop: The dropout ratio for the attention 91 | probabilities. 92 | embd_pdrop: The dropout ratio for the embeddings. 93 | initializer_range: The sttdev of the truncated_normal_initializer for 94 | initializing all weight matrices. 95 | """ 96 | super(CTRLConfig, self).__init__(**kwargs) 97 | 98 | self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1 99 | self.n_ctx = n_ctx 100 | self.n_positions = n_positions 101 | self.n_embd = n_embd 102 | self.n_layer = n_layer 103 | self.n_head = n_head 104 | self.dff = dff 105 | self.resid_pdrop = resid_pdrop 106 | self.embd_pdrop = embd_pdrop 107 | self.attn_pdrop = attn_pdrop 108 | self.layer_norm_epsilon = layer_norm_epsilon 109 | self.initializer_range = initializer_range 110 | 111 | self.num_labels = num_labels 112 | self.summary_type = summary_type 113 | self.summary_use_proj = summary_use_proj 114 | self.summary_activation = summary_activation 115 | self.summary_first_dropout = summary_first_dropout 116 | self.summary_proj_to_labels = summary_proj_to_labels 117 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 118 | and isinstance(vocab_size_or_config_json_file, unicode)): 119 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 120 | json_config = json.loads(reader.read()) 121 | for key, value in json_config.items(): 122 | self.__dict__[key] = value 123 | elif not isinstance(vocab_size_or_config_json_file, int): 124 | raise ValueError( 125 | "First argument must be either a vocabulary size (int)" 126 | "or the path to a pretrained model config file (str)" 127 | ) 128 | 129 | @property 130 | def max_position_embeddings(self): 131 | return self.n_positions 132 | 133 | @property 134 | def hidden_size(self): 135 | return self.n_embd 136 | 137 | @property 138 | def num_attention_heads(self): 139 | return self.n_head 140 | 141 | @property 142 | def num_hidden_layers(self): 143 | return self.n_layer 144 | -------------------------------------------------------------------------------- /transformers/configuration_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 
| # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ DistilBERT model configuration """ 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | from io import open 23 | 24 | from .configuration_utils import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 29 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", 30 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json" 31 | } 32 | 33 | 34 | class DistilBertConfig(PretrainedConfig): 35 | pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 36 | 37 | def __init__(self, 38 | vocab_size_or_config_json_file=30522, 39 | max_position_embeddings=512, 40 | sinusoidal_pos_embds=False, 41 | n_layers=6, 42 | n_heads=12, 43 | dim=768, 44 | hidden_dim=4*768, 45 | dropout=0.1, 46 | attention_dropout=0.1, 47 | activation='gelu', 48 | initializer_range=0.02, 49 | tie_weights_=True, 50 | qa_dropout=0.1, 51 | seq_classif_dropout=0.2, 52 | **kwargs): 53 | super(DistilBertConfig, self).__init__(**kwargs) 54 | 55 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 56 | and isinstance(vocab_size_or_config_json_file, unicode)): 57 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 58 | json_config = json.loads(reader.read()) 59 | for key, value in json_config.items(): 60 | self.__dict__[key] = value 61 | elif isinstance(vocab_size_or_config_json_file, int): 62 | self.vocab_size = vocab_size_or_config_json_file 63 | self.max_position_embeddings = max_position_embeddings 64 | self.sinusoidal_pos_embds = sinusoidal_pos_embds 65 | self.n_layers = n_layers 66 | self.n_heads = n_heads 67 | self.dim = dim 68 | self.hidden_dim = hidden_dim 69 | self.dropout = dropout 70 | self.attention_dropout = attention_dropout 71 | self.activation = activation 72 | self.initializer_range = initializer_range 73 | self.tie_weights_ = tie_weights_ 74 | self.qa_dropout = qa_dropout 75 | self.seq_classif_dropout = seq_classif_dropout 76 | else: 77 | raise ValueError("First argument must be either a vocabulary size (int)" 78 | " or the path to a pretrained model config file (str)") 79 | @property 80 | def hidden_size(self): 81 | return self.dim 82 | 83 | @property 84 | def num_attention_heads(self): 85 | return self.n_heads 86 | 87 | @property 88 | def num_hidden_layers(self): 89 | return self.n_layers 90 | -------------------------------------------------------------------------------- /transformers/configuration_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The 
OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT-2 configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", 30 | "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json", 31 | "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json", 32 | "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json", 33 | "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",} 34 | 35 | class GPT2Config(PretrainedConfig): 36 | """Configuration class to store the configuration of a `GPT2Model`. 37 | 38 | Args: 39 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. 40 | n_positions: Number of positional embeddings. 41 | n_ctx: Size of the causal mask (usually same as n_positions). 42 | n_embd: Dimensionality of the embeddings and hidden states. 43 | n_layer: Number of hidden layers in the Transformer encoder. 44 | n_head: Number of attention heads for each attention layer in 45 | the Transformer encoder. 46 | layer_norm_epsilon: epsilon to use in the layer norm layers 47 | resid_pdrop: The dropout probabilitiy for all fully connected 48 | layers in the embeddings, encoder, and pooler. 49 | attn_pdrop: The dropout ratio for the attention 50 | probabilities. 51 | embd_pdrop: The dropout ratio for the embeddings. 52 | initializer_range: The sttdev of the truncated_normal_initializer for 53 | initializing all weight matrices. 54 | """ 55 | pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP 56 | 57 | def __init__( 58 | self, 59 | vocab_size_or_config_json_file=50257, 60 | n_positions=1024, 61 | n_ctx=1024, 62 | n_embd=768, 63 | n_layer=12, 64 | n_head=12, 65 | resid_pdrop=0.1, 66 | embd_pdrop=0.1, 67 | attn_pdrop=0.1, 68 | layer_norm_epsilon=1e-5, 69 | initializer_range=0.02, 70 | 71 | num_labels=1, 72 | summary_type='cls_index', 73 | summary_use_proj=True, 74 | summary_activation=None, 75 | summary_proj_to_labels=True, 76 | summary_first_dropout=0.1, 77 | **kwargs 78 | ): 79 | """Constructs GPT2Config. 80 | 81 | Args: 82 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. 83 | n_positions: Number of positional embeddings. 84 | n_ctx: Size of the causal mask (usually same as n_positions). 85 | n_embd: Dimensionality of the embeddings and hidden states. 86 | n_layer: Number of hidden layers in the Transformer encoder. 
87 | n_head: Number of attention heads for each attention layer in 88 | the Transformer encoder. 89 | layer_norm_epsilon: epsilon to use in the layer norm layers 90 | resid_pdrop: The dropout probabilitiy for all fully connected 91 | layers in the embeddings, encoder, and pooler. 92 | attn_pdrop: The dropout ratio for the attention 93 | probabilities. 94 | embd_pdrop: The dropout ratio for the embeddings. 95 | initializer_range: The sttdev of the truncated_normal_initializer for 96 | initializing all weight matrices. 97 | """ 98 | super(GPT2Config, self).__init__(**kwargs) 99 | 100 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 101 | and isinstance(vocab_size_or_config_json_file, unicode)): 102 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 103 | json_config = json.loads(reader.read()) 104 | for key, value in json_config.items(): 105 | self.__dict__[key] = value 106 | elif isinstance(vocab_size_or_config_json_file, int): 107 | self.vocab_size = vocab_size_or_config_json_file 108 | self.n_ctx = n_ctx 109 | self.n_positions = n_positions 110 | self.n_embd = n_embd 111 | self.n_layer = n_layer 112 | self.n_head = n_head 113 | self.resid_pdrop = resid_pdrop 114 | self.embd_pdrop = embd_pdrop 115 | self.attn_pdrop = attn_pdrop 116 | self.layer_norm_epsilon = layer_norm_epsilon 117 | self.initializer_range = initializer_range 118 | 119 | self.num_labels = num_labels 120 | self.summary_type = summary_type 121 | self.summary_use_proj = summary_use_proj 122 | self.summary_activation = summary_activation 123 | self.summary_first_dropout = summary_first_dropout 124 | self.summary_proj_to_labels = summary_proj_to_labels 125 | else: 126 | raise ValueError( 127 | "First argument must be either a vocabulary size (int)" 128 | "or the path to a pretrained model config file (str)" 129 | ) 130 | 131 | @property 132 | def max_position_embeddings(self): 133 | return self.n_positions 134 | 135 | @property 136 | def hidden_size(self): 137 | return self.n_embd 138 | 139 | @property 140 | def num_attention_heads(self): 141 | return self.n_head 142 | 143 | @property 144 | def num_hidden_layers(self): 145 | return self.n_layer 146 | -------------------------------------------------------------------------------- /transformers/configuration_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
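GPT2Config above keeps the original OpenAI field names (n_embd, n_layer, n_head, n_positions) and exposes the BERT-style names only as read-only properties, which is worth knowing when downstream code expects hidden_size and friends. A minimal check:

from transformers import GPT2Config

config = GPT2Config()   # defaults correspond to the base "gpt2" checkpoint
assert config.hidden_size == config.n_embd == 768
assert config.num_hidden_layers == config.n_layer == 12
assert config.num_attention_heads == config.n_head == 12
assert config.max_position_embeddings == config.n_positions == 1024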
16 | """ OpenAI GPT configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json" 31 | } 32 | 33 | class OpenAIGPTConfig(PretrainedConfig): 34 | """ 35 | Configuration class to store the configuration of a `OpenAIGPTModel`. 36 | 37 | Args: 38 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file. 39 | n_positions: Number of positional embeddings. 40 | n_ctx: Size of the causal mask (usually same as n_positions). 41 | n_embd: Dimensionality of the embeddings and hidden states. 42 | n_layer: Number of hidden layers in the Transformer encoder. 43 | n_head: Number of attention heads for each attention layer in 44 | the Transformer encoder. 45 | afn: The non-linear activation function (function or string) in the 46 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 47 | resid_pdrop: The dropout probabilitiy for all fully connected 48 | layers in the embeddings, encoder, and pooler. 49 | attn_pdrop: The dropout ratio for the attention 50 | probabilities. 51 | embd_pdrop: The dropout ratio for the embeddings. 52 | layer_norm_epsilon: epsilon to use in the layer norm layers 53 | initializer_range: The sttdev of the truncated_normal_initializer for 54 | initializing all weight matrices. 55 | predict_special_tokens: should we predict special tokens (when the model has a LM head) 56 | """ 57 | pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP 58 | 59 | def __init__( 60 | self, 61 | vocab_size_or_config_json_file=40478, 62 | n_positions=512, 63 | n_ctx=512, 64 | n_embd=768, 65 | n_layer=12, 66 | n_head=12, 67 | afn="gelu", 68 | resid_pdrop=0.1, 69 | embd_pdrop=0.1, 70 | attn_pdrop=0.1, 71 | layer_norm_epsilon=1e-5, 72 | initializer_range=0.02, 73 | predict_special_tokens=True, 74 | 75 | num_labels=1, 76 | summary_type='cls_index', 77 | summary_use_proj=True, 78 | summary_activation=None, 79 | summary_proj_to_labels=True, 80 | summary_first_dropout=0.1, 81 | **kwargs 82 | ): 83 | """Constructs OpenAIGPTConfig. 
84 | """ 85 | super(OpenAIGPTConfig, self).__init__(**kwargs) 86 | 87 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 88 | and isinstance(vocab_size_or_config_json_file, unicode)): 89 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 90 | json_config = json.loads(reader.read()) 91 | for key, value in json_config.items(): 92 | self.__dict__[key] = value 93 | elif isinstance(vocab_size_or_config_json_file, int): 94 | self.vocab_size = vocab_size_or_config_json_file 95 | self.n_ctx = n_ctx 96 | self.n_positions = n_positions 97 | self.n_embd = n_embd 98 | self.n_layer = n_layer 99 | self.n_head = n_head 100 | self.afn = afn 101 | self.resid_pdrop = resid_pdrop 102 | self.embd_pdrop = embd_pdrop 103 | self.attn_pdrop = attn_pdrop 104 | self.layer_norm_epsilon = layer_norm_epsilon 105 | self.initializer_range = initializer_range 106 | self.predict_special_tokens = predict_special_tokens 107 | 108 | self.num_labels = num_labels 109 | self.summary_type = summary_type 110 | self.summary_use_proj = summary_use_proj 111 | self.summary_activation = summary_activation 112 | self.summary_first_dropout = summary_first_dropout 113 | self.summary_proj_to_labels = summary_proj_to_labels 114 | else: 115 | raise ValueError( 116 | "First argument must be either a vocabulary size (int)" 117 | "or the path to a pretrained model config file (str)" 118 | ) 119 | 120 | @property 121 | def max_position_embeddings(self): 122 | return self.n_positions 123 | 124 | @property 125 | def hidden_size(self): 126 | return self.n_embd 127 | 128 | @property 129 | def num_attention_heads(self): 130 | return self.n_head 131 | 132 | @property 133 | def num_hidden_layers(self): 134 | return self.n_layer 135 | -------------------------------------------------------------------------------- /transformers/configuration_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ RoBERTa configuration """ 17 | 18 | from __future__ import (absolute_import, division, print_function, 19 | unicode_literals) 20 | 21 | import logging 22 | 23 | from .configuration_bert import BertConfig 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json", 29 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json", 30 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json", 31 | 'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json", 32 | 'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json", 33 | 'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json", 34 | } 35 | 36 | 37 | class RobertaConfig(BertConfig): 38 | pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP 39 | -------------------------------------------------------------------------------- /transformers/configuration_transfo_xl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Transformer XL configuration """ 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import json 21 | import logging 22 | import sys 23 | from io import open 24 | 25 | from .configuration_utils import PretrainedConfig 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json", 31 | } 32 | 33 | class TransfoXLConfig(PretrainedConfig): 34 | """Configuration class to store the configuration of a `TransfoXLModel`. 35 | 36 | Args: 37 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file. 38 | cutoffs: cutoffs for the adaptive softmax 39 | d_model: Dimensionality of the model's hidden states. 40 | d_embed: Dimensionality of the embeddings 41 | d_head: Dimensionality of the model's heads. 42 | div_val: divident value for adapative input and softmax 43 | pre_lnorm: apply LayerNorm to the input instead of the output 44 | d_inner: Inner dimension in FF 45 | n_layer: Number of hidden layers in the Transformer encoder. 46 | n_head: Number of attention heads for each attention layer in 47 | the Transformer encoder. 
48 | tgt_len: number of tokens to predict 49 | ext_len: length of the extended context 50 | mem_len: length of the retained previous heads 51 | same_length: use the same attn length for all tokens 52 | proj_share_all_but_first: True to share all but first projs, False not to share. 53 | attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al. 54 | clamp_len: use the same pos embeddings after clamp_len 55 | sample_softmax: number of samples in sampled softmax 56 | adaptive: use adaptive softmax 57 | tie_weight: tie the word embedding and softmax weights 58 | dropout: The dropout probabilitiy for all fully connected 59 | layers in the embeddings, encoder, and pooler. 60 | dropatt: The dropout ratio for the attention probabilities. 61 | untie_r: untie relative position biases 62 | embd_pdrop: The dropout ratio for the embeddings. 63 | init: parameter initializer to use 64 | init_range: parameters initialized by U(-init_range, init_range). 65 | proj_init_std: parameters initialized by N(0, init_std) 66 | init_std: parameters initialized by N(0, init_std) 67 | """ 68 | pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP 69 | 70 | def __init__(self, 71 | vocab_size_or_config_json_file=267735, 72 | cutoffs=[20000, 40000, 200000], 73 | d_model=1024, 74 | d_embed=1024, 75 | n_head=16, 76 | d_head=64, 77 | d_inner=4096, 78 | div_val=4, 79 | pre_lnorm=False, 80 | n_layer=18, 81 | tgt_len=128, 82 | ext_len=0, 83 | mem_len=1600, 84 | clamp_len=1000, 85 | same_length=True, 86 | proj_share_all_but_first=True, 87 | attn_type=0, 88 | sample_softmax=-1, 89 | adaptive=True, 90 | tie_weight=True, 91 | dropout=0.1, 92 | dropatt=0.0, 93 | untie_r=True, 94 | init="normal", 95 | init_range=0.01, 96 | proj_init_std=0.01, 97 | init_std=0.02, 98 | layer_norm_epsilon=1e-5, 99 | **kwargs): 100 | """Constructs TransfoXLConfig. 
101 | """ 102 | super(TransfoXLConfig, self).__init__(**kwargs) 103 | self.n_token = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1 104 | self.cutoffs = [] 105 | self.cutoffs.extend(cutoffs) 106 | self.tie_weight = tie_weight 107 | if proj_share_all_but_first: 108 | self.tie_projs = [False] + [True] * len(self.cutoffs) 109 | else: 110 | self.tie_projs = [False] + [False] * len(self.cutoffs) 111 | self.d_model = d_model 112 | self.d_embed = d_embed 113 | self.d_head = d_head 114 | self.d_inner = d_inner 115 | self.div_val = div_val 116 | self.pre_lnorm = pre_lnorm 117 | self.n_layer = n_layer 118 | self.n_head = n_head 119 | self.tgt_len = tgt_len 120 | self.ext_len = ext_len 121 | self.mem_len = mem_len 122 | self.same_length = same_length 123 | self.attn_type = attn_type 124 | self.clamp_len = clamp_len 125 | self.sample_softmax = sample_softmax 126 | self.adaptive = adaptive 127 | self.dropout = dropout 128 | self.dropatt = dropatt 129 | self.untie_r = untie_r 130 | self.init = init 131 | self.init_range = init_range 132 | self.proj_init_std = proj_init_std 133 | self.init_std = init_std 134 | self.layer_norm_epsilon = layer_norm_epsilon 135 | 136 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 137 | and isinstance(vocab_size_or_config_json_file, unicode)): 138 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 139 | json_config = json.loads(reader.read()) 140 | for key, value in json_config.items(): 141 | self.__dict__[key] = value 142 | elif not isinstance(vocab_size_or_config_json_file, int): 143 | raise ValueError("First argument must be either a vocabulary size (int)" 144 | " or the path to a pretrained model config file (str)") 145 | 146 | @property 147 | def max_position_embeddings(self): 148 | return self.tgt_len + self.ext_len + self.mem_len 149 | 150 | @property 151 | def vocab_size(self): 152 | return self.n_token 153 | 154 | @vocab_size.setter 155 | def vocab_size(self, value): 156 | self.n_token = value 157 | 158 | @property 159 | def hidden_size(self): 160 | return self.d_model 161 | 162 | @property 163 | def num_attention_heads(self): 164 | return self.n_head 165 | 166 | @property 167 | def num_hidden_layers(self): 168 | return self.n_layer 169 | -------------------------------------------------------------------------------- /transformers/configuration_xlnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ XLNet configuration """ 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import json 20 | import logging 21 | import sys 22 | from io import open 23 | 24 | from .configuration_utils import PretrainedConfig 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { 29 | 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json", 30 | 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json", 31 | } 32 | 33 | 34 | class XLNetConfig(PretrainedConfig): 35 | """Configuration class to store the configuration of a ``XLNetModel``. 36 | 37 | Args: 38 | vocab_size_or_config_json_file: Vocabulary size of ``inputs_ids`` in ``XLNetModel``. 39 | d_model: Size of the encoder layers and the pooler layer. 40 | n_layer: Number of hidden layers in the Transformer encoder. 41 | n_head: Number of attention heads for each attention layer in 42 | the Transformer encoder. 43 | d_inner: The size of the "intermediate" (i.e., feed-forward) 44 | layer in the Transformer encoder. 45 | ff_activation: The non-linear activation function (function or string) in the 46 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 47 | untie_r: untie relative position biases 48 | attn_type: 'bi' for XLNet, 'uni' for Transformer-XL 49 | 50 | dropout: The dropout probabilitiy for all fully connected 51 | layers in the embeddings, encoder, and pooler. 52 | initializer_range: The sttdev of the truncated_normal_initializer for 53 | initializing all weight matrices. 54 | layer_norm_eps: The epsilon used by LayerNorm. 55 | 56 | dropout: float, dropout rate. 57 | init: str, the initialization scheme, either "normal" or "uniform". 58 | init_range: float, initialize the parameters with a uniform distribution 59 | in [-init_range, init_range]. Only effective when init="uniform". 60 | init_std: float, initialize the parameters with a normal distribution 61 | with mean 0 and stddev init_std. Only effective when init="normal". 62 | mem_len: int, the number of tokens to cache. 63 | reuse_len: int, the number of tokens in the currect batch to be cached 64 | and reused in the future. 65 | bi_data: bool, whether to use bidirectional input pipeline. 66 | Usually set to True during pretraining and False during finetuning. 67 | clamp_len: int, clamp all relative distances larger than clamp_len. 68 | -1 means no clamping. 69 | same_length: bool, whether to use the same attention length for each token. 70 | finetuning_task: name of the glue task on which the model was fine-tuned if any 71 | """ 72 | pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP 73 | 74 | def __init__(self, 75 | vocab_size_or_config_json_file=32000, 76 | d_model=1024, 77 | n_layer=24, 78 | n_head=16, 79 | d_inner=4096, 80 | max_position_embeddings=512, 81 | ff_activation="gelu", 82 | untie_r=True, 83 | attn_type="bi", 84 | 85 | initializer_range=0.02, 86 | layer_norm_eps=1e-12, 87 | 88 | dropout=0.1, 89 | mem_len=None, 90 | reuse_len=None, 91 | bi_data=False, 92 | clamp_len=-1, 93 | same_length=False, 94 | 95 | finetuning_task=None, 96 | num_labels=2, 97 | summary_type='last', 98 | summary_use_proj=True, 99 | summary_activation='tanh', 100 | summary_last_dropout=0.1, 101 | start_n_top=5, 102 | end_n_top=5, 103 | **kwargs): 104 | """Constructs XLNetConfig. 
105 | """ 106 | super(XLNetConfig, self).__init__(**kwargs) 107 | 108 | if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 109 | and isinstance(vocab_size_or_config_json_file, unicode)): 110 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 111 | json_config = json.loads(reader.read()) 112 | for key, value in json_config.items(): 113 | setattr(config, key, value) 114 | elif isinstance(vocab_size_or_config_json_file, int): 115 | self.n_token = vocab_size_or_config_json_file 116 | self.d_model = d_model 117 | self.n_layer = n_layer 118 | self.n_head = n_head 119 | assert d_model % n_head == 0 120 | self.d_head = d_model // n_head 121 | self.ff_activation = ff_activation 122 | self.d_inner = d_inner 123 | self.untie_r = untie_r 124 | self.attn_type = attn_type 125 | 126 | self.initializer_range = initializer_range 127 | self.layer_norm_eps = layer_norm_eps 128 | 129 | self.dropout = dropout 130 | self.mem_len = mem_len 131 | self.reuse_len = reuse_len 132 | self.bi_data = bi_data 133 | self.clamp_len = clamp_len 134 | self.same_length = same_length 135 | 136 | self.finetuning_task = finetuning_task 137 | self.num_labels = num_labels 138 | self.summary_type = summary_type 139 | self.summary_use_proj = summary_use_proj 140 | self.summary_activation = summary_activation 141 | self.summary_last_dropout = summary_last_dropout 142 | self.start_n_top = start_n_top 143 | self.end_n_top = end_n_top 144 | else: 145 | raise ValueError("First argument must be either a vocabulary size (int)" 146 | " or the path to a pretrained model config file (str)") 147 | 148 | @property 149 | def max_position_embeddings(self): 150 | return -1 151 | 152 | @property 153 | def vocab_size(self): 154 | return self.n_token 155 | 156 | @vocab_size.setter 157 | def vocab_size(self, value): 158 | self.n_token = value 159 | 160 | @property 161 | def hidden_size(self): 162 | return self.d_model 163 | 164 | @property 165 | def num_attention_heads(self): 166 | return self.n_head 167 | 168 | @property 169 | def num_hidden_layers(self): 170 | return self.n_layer 171 | -------------------------------------------------------------------------------- /transformers/convert_bert_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import torch 23 | 24 | from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 25 | 26 | import logging 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | ## Required parameters 46 | parser.add_argument("--tf_checkpoint_path", 47 | default = None, 48 | type = str, 49 | required = True, 50 | help = "Path to the TensorFlow checkpoint path.") 51 | parser.add_argument("--bert_config_file", 52 | default = None, 53 | type = str, 54 | required = True, 55 | help = "The config json file corresponding to the pre-trained BERT model. \n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.bert_config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /transformers/convert_bert_pytorch_checkpoint_to_original_tf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" 17 | 18 | import os 19 | import argparse 20 | import torch 21 | import numpy as np 22 | import tensorflow as tf 23 | from transformers import BertModel 24 | 25 | 26 | def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str): 27 | 28 | """ 29 | :param model:BertModel Pytorch model instance to be converted 30 | :param ckpt_dir: Tensorflow model directory 31 | :param model_name: model name 32 | :return: 33 | 34 | Currently supported HF models: 35 | Y BertModel 36 | N BertForMaskedLM 37 | N BertForPreTraining 38 | N BertForMultipleChoice 39 | N BertForNextSentencePrediction 40 | N BertForSequenceClassification 41 | N BertForQuestionAnswering 42 | """ 43 | 44 | tensors_to_transpose = ( 45 | "dense.weight", 46 | "attention.self.query", 47 | "attention.self.key", 48 | "attention.self.value" 49 | ) 50 | 51 | var_map = ( 52 | ('layer.', 'layer_'), 53 | ('word_embeddings.weight', 'word_embeddings'), 54 | ('position_embeddings.weight', 'position_embeddings'), 55 | ('token_type_embeddings.weight', 'token_type_embeddings'), 56 | ('.', '/'), 57 | ('LayerNorm/weight', 'LayerNorm/gamma'), 58 | ('LayerNorm/bias', 'LayerNorm/beta'), 59 | ('weight', 'kernel') 60 | ) 61 | 62 | if not os.path.isdir(ckpt_dir): 63 | os.makedirs(ckpt_dir) 64 | 65 | state_dict = model.state_dict() 66 | 67 | def to_tf_var_name(name:str): 68 | for patt, repl in iter(var_map): 69 | name = name.replace(patt, repl) 70 | return 'bert/{}'.format(name) 71 | 72 | def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session): 73 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype) 74 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) 75 | session.run(tf.variables_initializer([tf_var])) 76 | session.run(tf_var) 77 | return tf_var 78 | 79 | tf.reset_default_graph() 80 | with tf.Session() as session: 81 | for var_name in state_dict: 82 | tf_name = to_tf_var_name(var_name) 83 | torch_tensor = state_dict[var_name].numpy() 84 | if any([x in var_name for x in tensors_to_transpose]): 85 | torch_tensor = torch_tensor.T 86 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) 87 | tf.keras.backend.set_value(tf_var, torch_tensor) 88 | tf_weight = session.run(tf_var) 89 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor))) 90 | 91 | saver = tf.train.Saver(tf.trainable_variables()) 92 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) 93 | 94 | 95 | def main(raw_args=None): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--model_name", 98 | type=str, 99 | required=True, 100 | help="model name e.g. 
bert-base-uncased") 101 | parser.add_argument("--cache_dir", 102 | type=str, 103 | default=None, 104 | required=False, 105 | help="Directory containing pytorch model") 106 | parser.add_argument("--pytorch_model_path", 107 | type=str, 108 | required=True, 109 | help="/path/to/.bin") 110 | parser.add_argument("--tf_cache_dir", 111 | type=str, 112 | required=True, 113 | help="Directory in which to save tensorflow model") 114 | args = parser.parse_args(raw_args) 115 | 116 | model = BertModel.from_pretrained( 117 | pretrained_model_name_or_path=args.model_name, 118 | state_dict=torch.load(args.pytorch_model_path), 119 | cache_dir=args.cache_dir 120 | ) 121 | 122 | convert_pytorch_checkpoint_to_tf( 123 | model=model, 124 | ckpt_dir=args.tf_cache_dir, 125 | model_name=args.model_name 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /transformers/convert_gpt2_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from transformers import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if gpt2_config_file == "": 36 | config = GPT2Config() 37 | else: 38 | config = GPT2Config.from_json_file(gpt2_config_file) 39 | model = GPT2Model(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--gpt2_checkpoint_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--gpt2_config_file", 68 | default = "", 
69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 74 | args.gpt2_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /transformers/convert_openai_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from transformers import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if openai_config_file == "": 36 | config = OpenAIGPTConfig() 37 | else: 38 | config = OpenAIGPTConfig.from_json_file(openai_config_file) 39 | model = OpenAIGPTModel(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--openai_checkpoint_folder_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--openai_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 74 | args.openai_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /transformers/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import transformers.tokenization_transfo_xl as data_utils 27 | 28 | from transformers import CONFIG_NAME, WEIGHTS_NAME 29 | from transformers import (TransfoXLConfig, TransfoXLLMHeadModel, 30 | load_tf_weights_in_transfo_xl) 31 | from transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES) 32 | 33 | if sys.version_info[0] == 2: 34 | import cPickle as pickle 35 | else: 36 | import pickle 37 | 38 | import logging 39 | logging.basicConfig(level=logging.INFO) 40 | 41 | # We do this to be able to load python 2 datasets pickles 42 | # See e.g. 
https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 43 | data_utils.Vocab = data_utils.TransfoXLTokenizer 44 | data_utils.Corpus = data_utils.TransfoXLCorpus 45 | sys.modules['data_utils'] = data_utils 46 | sys.modules['vocabulary'] = data_utils 47 | 48 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 49 | transfo_xl_config_file, 50 | pytorch_dump_folder_path, 51 | transfo_xl_dataset_file): 52 | if transfo_xl_dataset_file: 53 | # Convert a pre-processed corpus (see original TensorFlow repo) 54 | with open(transfo_xl_dataset_file, "rb") as fp: 55 | corpus = pickle.load(fp, encoding="latin1") 56 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 57 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file'] 58 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 59 | corpus_vocab_dict = corpus.vocab.__dict__ 60 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 61 | 62 | corpus_dict_no_vocab = corpus.__dict__ 63 | corpus_dict_no_vocab.pop('vocab', None) 64 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 65 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 66 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 67 | 68 | if tf_checkpoint_path: 69 | # Convert a pre-trained TensorFlow model 70 | config_path = os.path.abspath(transfo_xl_config_file) 71 | tf_path = os.path.abspath(tf_checkpoint_path) 72 | 73 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 74 | # Initialise PyTorch model 75 | if transfo_xl_config_file == "": 76 | config = TransfoXLConfig() 77 | else: 78 | config = TransfoXLConfig.from_json_file(transfo_xl_config_file) 79 | print("Building PyTorch model from configuration: {}".format(str(config))) 80 | model = TransfoXLLMHeadModel(config) 81 | 82 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 83 | # Save pytorch-model 84 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 85 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 86 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 87 | torch.save(model.state_dict(), pytorch_weights_dump_path) 88 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 89 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 90 | f.write(config.to_json_string()) 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--pytorch_dump_folder_path", 96 | default = None, 97 | type = str, 98 | required = True, 99 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 100 | parser.add_argument("--tf_checkpoint_path", 101 | default = "", 102 | type = str, 103 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 104 | parser.add_argument("--transfo_xl_config_file", 105 | default = "", 106 | type = str, 107 | help = "An optional config json file corresponding to the pre-trained BERT model. 
\n" 108 | "This specifies the model architecture.") 109 | parser.add_argument("--transfo_xl_dataset_file", 110 | default = "", 111 | type = str, 112 | help = "An optional dataset file to be converted in a vocabulary.") 113 | args = parser.parse_args() 114 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 115 | args.transfo_xl_config_file, 116 | args.pytorch_dump_folder_path, 117 | args.transfo_xl_dataset_file) 118 | -------------------------------------------------------------------------------- /transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import json 21 | from io import open 22 | 23 | import torch 24 | import numpy 25 | 26 | from transformers import CONFIG_NAME, WEIGHTS_NAME 27 | from transformers.tokenization_xlm import VOCAB_FILES_NAMES 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): 33 | # Load checkpoint 34 | chkpt = torch.load(xlm_checkpoint_path, map_location='cpu') 35 | 36 | state_dict = chkpt['model'] 37 | 38 | # We have the base model one level deeper than the original XLM repository 39 | two_levels_state_dict = {} 40 | for k, v in state_dict.items(): 41 | if 'pred_layer' in k: 42 | two_levels_state_dict[k] = v 43 | else: 44 | two_levels_state_dict['transformer.' 
+ k] = v 45 | 46 | config = chkpt['params'] 47 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) 48 | 49 | vocab = chkpt['dico_word2id'] 50 | vocab = dict((s + '' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items()) 51 | 52 | # Save pytorch-model 53 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 54 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 55 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file'] 56 | 57 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 58 | torch.save(two_levels_state_dict, pytorch_weights_dump_path) 59 | 60 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 61 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 62 | f.write(json.dumps(config, indent=2) + "\n") 63 | 64 | print("Save vocab file to {}".format(pytorch_config_dump_path)) 65 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f: 66 | f.write(json.dumps(vocab, indent=2) + "\n") 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | ## Required parameters 72 | parser.add_argument("--xlm_checkpoint_path", 73 | default = None, 74 | type = str, 75 | required = True, 76 | help = "Path the official PyTorch dump.") 77 | parser.add_argument("--pytorch_dump_folder_path", 78 | default = None, 79 | type = str, 80 | required = True, 81 | help = "Path to the output PyTorch model.") 82 | args = parser.parse_args() 83 | convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path) 84 | -------------------------------------------------------------------------------- /transformers/convert_xlnet_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
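# Illustrative invocation only -- a hedged sketch of how this conversion script is typically
# run from the command line. The checkpoint/config paths below are hypothetical placeholders
# and are not taken from this repository; the flags match the argparse options defined at the
# bottom of this file, and the optional --finetuning_task value is one of the GLUE task names
# listed in GLUE_TASKS_NUM_LABELS below.
#
#   python convert_xlnet_original_tf_checkpoint_to_pytorch.py \
#       --tf_checkpoint_path ./xlnet_model_dir/xlnet_model.ckpt \
#       --xlnet_config_file ./xlnet_model_dir/xlnet_config.json \
#       --pytorch_dump_folder_path ./xlnet_pytorch_dump \
#       --finetuning_task sst-2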
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import argparse 23 | import torch 24 | 25 | from transformers import (CONFIG_NAME, WEIGHTS_NAME, 26 | XLNetConfig, 27 | XLNetLMHeadModel, XLNetForQuestionAnswering, 28 | XLNetForSequenceClassification, 29 | load_tf_weights_in_xlnet) 30 | 31 | GLUE_TASKS_NUM_LABELS = { 32 | "cola": 2, 33 | "mnli": 3, 34 | "mrpc": 2, 35 | "sst-2": 2, 36 | "sts-b": 1, 37 | "qqp": 2, 38 | "qnli": 2, 39 | "rte": 2, 40 | "wnli": 2, 41 | } 42 | 43 | import logging 44 | logging.basicConfig(level=logging.INFO) 45 | 46 | def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None): 47 | # Initialise PyTorch model 48 | config = XLNetConfig.from_json_file(bert_config_file) 49 | 50 | finetuning_task = finetuning_task.lower() if finetuning_task is not None else "" 51 | if finetuning_task in GLUE_TASKS_NUM_LABELS: 52 | print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config))) 53 | config.finetuning_task = finetuning_task 54 | config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] 55 | model = XLNetForSequenceClassification(config) 56 | elif 'squad' in finetuning_task: 57 | config.finetuning_task = finetuning_task 58 | model = XLNetForQuestionAnswering(config) 59 | else: 60 | model = XLNetLMHeadModel(config) 61 | 62 | # Load weights from tf checkpoint 63 | load_tf_weights_in_xlnet(model, config, tf_checkpoint_path) 64 | 65 | # Save pytorch-model 66 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 67 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 68 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 69 | torch.save(model.state_dict(), pytorch_weights_dump_path) 70 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 71 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 72 | f.write(config.to_json_string()) 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | ## Required parameters 78 | parser.add_argument("--tf_checkpoint_path", 79 | default = None, 80 | type = str, 81 | required = True, 82 | help = "Path to the TensorFlow checkpoint path.") 83 | parser.add_argument("--xlnet_config_file", 84 | default = None, 85 | type = str, 86 | required = True, 87 | help = "The config json file corresponding to the pre-trained XLNet model. 
\n" 88 | "This specifies the model architecture.") 89 | parser.add_argument("--pytorch_dump_folder_path", 90 | default = None, 91 | type = str, 92 | required = True, 93 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 94 | parser.add_argument("--finetuning_task", 95 | default = None, 96 | type = str, 97 | help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned") 98 | args = parser.parse_args() 99 | print(args) 100 | 101 | convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path, 102 | args.xlnet_config_file, 103 | args.pytorch_dump_folder_path, 104 | args.finetuning_task) 105 | -------------------------------------------------------------------------------- /transformers/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .processors import InputExample, InputFeatures, DataProcessor 2 | from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features 3 | 4 | from .metrics import is_sklearn_available 5 | if is_sklearn_available(): 6 | from .metrics import glue_compute_metrics 7 | -------------------------------------------------------------------------------- /transformers/data/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import csv 18 | import sys 19 | import logging 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | try: 24 | from scipy.stats import pearsonr, spearmanr 25 | from sklearn.metrics import matthews_corrcoef, f1_score 26 | _has_sklearn = True 27 | except (AttributeError, ImportError) as e: 28 | logger.warning("To use data.metrics please install scikit-learn. 
See https://scikit-learn.org/stable/index.html") 29 | _has_sklearn = False 30 | 31 | def is_sklearn_available(): 32 | return _has_sklearn 33 | 34 | if _has_sklearn: 35 | 36 | def simple_accuracy(preds, labels): 37 | return (preds == labels).mean() 38 | 39 | 40 | def acc_and_f1(preds, labels): 41 | acc = simple_accuracy(preds, labels) 42 | f1 = f1_score(y_true=labels, y_pred=preds) 43 | return { 44 | "acc": acc, 45 | "f1": f1, 46 | "acc_and_f1": (acc + f1) / 2, 47 | } 48 | 49 | 50 | def pearson_and_spearman(preds, labels): 51 | pearson_corr = pearsonr(preds, labels)[0] 52 | spearman_corr = spearmanr(preds, labels)[0] 53 | return { 54 | "pearson": pearson_corr, 55 | "spearmanr": spearman_corr, 56 | "corr": (pearson_corr + spearman_corr) / 2, 57 | } 58 | 59 | 60 | def glue_compute_metrics(task_name, preds, labels): 61 | assert len(preds) == len(labels) 62 | if task_name == "cola": 63 | return {"mcc": matthews_corrcoef(labels, preds)} 64 | elif task_name == "sst-2": 65 | return {"acc": simple_accuracy(preds, labels)} 66 | elif task_name == "mrpc": 67 | return acc_and_f1(preds, labels) 68 | elif task_name == "sts-b": 69 | return pearson_and_spearman(preds, labels) 70 | elif task_name == "qqp": 71 | return acc_and_f1(preds, labels) 72 | elif task_name == "mnli": 73 | return {"acc": simple_accuracy(preds, labels)} 74 | elif task_name == "mnli-mm": 75 | return {"acc": simple_accuracy(preds, labels)} 76 | elif task_name == "qnli": 77 | return {"acc": simple_accuracy(preds, labels)} 78 | elif task_name == "rte": 79 | return {"acc": simple_accuracy(preds, labels)} 80 | elif task_name == "wnli": 81 | return {"acc": simple_accuracy(preds, labels)} 82 | else: 83 | raise KeyError(task_name) 84 | -------------------------------------------------------------------------------- /transformers/data/processors/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import InputExample, InputFeatures, DataProcessor 2 | from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features 3 | 4 | -------------------------------------------------------------------------------- /transformers/data/processors/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import csv 18 | import sys 19 | import copy 20 | import json 21 | 22 | class InputExample(object): 23 | """ 24 | A single training/test example for simple sequence classification. 25 | 26 | Args: 27 | guid: Unique id for the example. 28 | text_a: string. The untokenized text of the first sequence. For single 29 | sequence tasks, only this sequence must be specified. 30 | text_b: (Optional) string. The untokenized text of the second sequence. 
31 | Only must be specified for sequence pair tasks. 32 | label: (Optional) string. The label of the example. This should be 33 | specified for train and dev examples, but not for test examples. 34 | """ 35 | def __init__(self, guid, text_a, text_b=None, label=None): 36 | self.guid = guid 37 | self.text_a = text_a 38 | self.text_b = text_b 39 | self.label = label 40 | 41 | def __repr__(self): 42 | return str(self.to_json_string()) 43 | 44 | def to_dict(self): 45 | """Serializes this instance to a Python dictionary.""" 46 | output = copy.deepcopy(self.__dict__) 47 | return output 48 | 49 | def to_json_string(self): 50 | """Serializes this instance to a JSON string.""" 51 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 52 | 53 | 54 | class InputFeatures(object): 55 | """ 56 | A single set of features of data. 57 | 58 | Args: 59 | input_ids: Indices of input sequence tokens in the vocabulary. 60 | attention_mask: Mask to avoid performing attention on padding token indices. 61 | Mask values selected in ``[0, 1]``: 62 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens. 63 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 64 | label: Label corresponding to the input 65 | """ 66 | 67 | def __init__(self, input_ids, attention_mask, token_type_ids, label): 68 | self.input_ids = input_ids 69 | self.attention_mask = attention_mask 70 | self.token_type_ids = token_type_ids 71 | self.label = label 72 | 73 | def __repr__(self): 74 | return str(self.to_json_string()) 75 | 76 | def to_dict(self): 77 | """Serializes this instance to a Python dictionary.""" 78 | output = copy.deepcopy(self.__dict__) 79 | return output 80 | 81 | def to_json_string(self): 82 | """Serializes this instance to a JSON string.""" 83 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 84 | 85 | 86 | class DataProcessor(object): 87 | """Base class for data converters for sequence classification data sets.""" 88 | 89 | def get_example_from_tensor_dict(self, tensor_dict): 90 | """Gets an example from a dict with tensorflow tensors 91 | 92 | Args: 93 | tensor_dict: Keys and values should match the corresponding Glue 94 | tensorflow_dataset examples. 95 | """ 96 | raise NotImplementedError() 97 | 98 | def get_train_examples(self, data_dir): 99 | """Gets a collection of `InputExample`s for the train set.""" 100 | raise NotImplementedError() 101 | 102 | def get_dev_examples(self, data_dir): 103 | """Gets a collection of `InputExample`s for the dev set.""" 104 | raise NotImplementedError() 105 | 106 | def get_labels(self): 107 | """Gets the list of labels for this data set.""" 108 | raise NotImplementedError() 109 | 110 | def tfds_map(self, example): 111 | """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. 
112 | This method converts examples to the correct format.""" 113 | if len(self.get_labels()) > 1: 114 | example.label = self.get_labels()[int(example.label)] 115 | return example 116 | 117 | @classmethod 118 | def _read_tsv(cls, input_file, quotechar=None): 119 | """Reads a tab separated value file.""" 120 | with open(input_file, "r", encoding="utf-8-sig") as f: 121 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 122 | lines = [] 123 | for line in reader: 124 | if sys.version_info[0] == 2: 125 | line = list(unicode(cell, 'utf-8') for cell in line) 126 | lines.append(line) 127 | return lines 128 | -------------------------------------------------------------------------------- /transformers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chevalier1024/CCF-BDCI-ABSA/5c534e443dd1d3ee8932c8369ebd80d2ea6bacec/transformers/tests/__init__.py -------------------------------------------------------------------------------- /transformers/tests/configuration_common_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import copy 20 | import os 21 | import shutil 22 | import json 23 | import random 24 | import uuid 25 | 26 | import unittest 27 | import logging 28 | 29 | 30 | class ConfigTester(object): 31 | def __init__(self, parent, config_class=None, **kwargs): 32 | self.parent = parent 33 | self.config_class = config_class 34 | self.inputs_dict = kwargs 35 | 36 | def create_and_test_config_common_properties(self): 37 | config = self.config_class(**self.inputs_dict) 38 | self.parent.assertTrue(hasattr(config, 'vocab_size')) 39 | self.parent.assertTrue(hasattr(config, 'hidden_size')) 40 | self.parent.assertTrue(hasattr(config, 'num_attention_heads')) 41 | self.parent.assertTrue(hasattr(config, 'num_hidden_layers')) 42 | 43 | def create_and_test_config_to_json_string(self): 44 | config = self.config_class(**self.inputs_dict) 45 | obj = json.loads(config.to_json_string()) 46 | for key, value in self.inputs_dict.items(): 47 | self.parent.assertEqual(obj[key], value) 48 | 49 | def create_and_test_config_to_json_file(self): 50 | config_first = self.config_class(**self.inputs_dict) 51 | json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json") 52 | config_first.to_json_file(json_file_path) 53 | config_second = self.config_class.from_json_file(json_file_path) 54 | os.remove(json_file_path) 55 | self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) 56 | 57 | def run_common_tests(self): 58 | self.create_and_test_config_common_properties() 59 | self.create_and_test_config_to_json_string() 60 | self.create_and_test_config_to_json_file() 61 | 62 | if __name__ == "__main__": 63 | unittest.main() -------------------------------------------------------------------------------- /transformers/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # content of conftest.py 2 | 3 | import pytest 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--runslow", action="store_true", default=False, help="run slow tests" 9 | ) 10 | parser.addoption( 11 | "--use_cuda", action="store_true", default=False, help="run tests on gpu" 12 | ) 13 | 14 | 15 | def pytest_configure(config): 16 | config.addinivalue_line("markers", "slow: mark test as slow to run") 17 | 18 | 19 | def pytest_collection_modifyitems(config, items): 20 | if config.getoption("--runslow"): 21 | # --runslow given in cli: do not skip slow tests 22 | return 23 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 24 | for item in items: 25 | if "slow" in item.keywords: 26 | item.add_marker(skip_slow) 27 | 28 | @pytest.fixture 29 | def use_cuda(request): 30 | """ Run test on gpu """ 31 | return request.config.getoption("--use_cuda") 32 | -------------------------------------------------------------------------------- /transformers/tests/fixtures/input.txt: -------------------------------------------------------------------------------- 1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /transformers/tests/fixtures/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 
3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 
28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /transformers/tests/fixtures/test_sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chevalier1024/CCF-BDCI-ABSA/5c534e443dd1d3ee8932c8369ebd80d2ea6bacec/transformers/tests/fixtures/test_sentencepiece.model -------------------------------------------------------------------------------- /transformers/tests/modeling_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from transformers import is_torch_available 25 | 26 | if is_torch_available(): 27 | from transformers import (AutoConfig, BertConfig, 28 | AutoModel, BertModel, 29 | AutoModelWithLMHead, BertForMaskedLM, 30 | AutoModelForSequenceClassification, BertForSequenceClassification, 31 | AutoModelForQuestionAnswering, BertForQuestionAnswering) 32 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 33 | 34 | from .modeling_common_test import (CommonTestCases, ids_tensor) 35 | from .configuration_common_test import ConfigTester 36 | else: 37 | pytestmark = pytest.mark.skip("Require Torch") 38 | 39 | 40 | class AutoModelTest(unittest.TestCase): 41 | @pytest.mark.slow 42 | def test_model_from_pretrained(self): 43 | logging.basicConfig(level=logging.INFO) 44 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 45 | config = AutoConfig.from_pretrained(model_name) 46 | self.assertIsNotNone(config) 47 | self.assertIsInstance(config, BertConfig) 48 | 49 | model = AutoModel.from_pretrained(model_name) 50 | model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True) 51 | self.assertIsNotNone(model) 52 | self.assertIsInstance(model, BertModel) 53 | for value in loading_info.values(): 54 | self.assertEqual(len(value), 0) 55 | 56 | @pytest.mark.slow 57 | def test_lmhead_model_from_pretrained(self): 58 | logging.basicConfig(level=logging.INFO) 59 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 60 | config = AutoConfig.from_pretrained(model_name) 61 | self.assertIsNotNone(config) 62 | self.assertIsInstance(config, BertConfig) 63 | 64 | model = AutoModelWithLMHead.from_pretrained(model_name) 65 | model, loading_info = AutoModelWithLMHead.from_pretrained(model_name, output_loading_info=True) 66 | self.assertIsNotNone(model) 67 | self.assertIsInstance(model, BertForMaskedLM) 68 | 69 | @pytest.mark.slow 70 | def test_sequence_classification_model_from_pretrained(self): 71 | logging.basicConfig(level=logging.INFO) 72 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 73 | config = AutoConfig.from_pretrained(model_name) 74 | self.assertIsNotNone(config) 75 | self.assertIsInstance(config, BertConfig) 76 | 77 | model = AutoModelForSequenceClassification.from_pretrained(model_name) 78 | model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True) 79 | self.assertIsNotNone(model) 80 | self.assertIsInstance(model, BertForSequenceClassification) 81 | 82 | @pytest.mark.slow 83 | def test_question_answering_model_from_pretrained(self): 84 | logging.basicConfig(level=logging.INFO) 85 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 86 | config = AutoConfig.from_pretrained(model_name) 87 | self.assertIsNotNone(config) 88 | self.assertIsInstance(config, BertConfig) 89 | 90 | model = AutoModelForQuestionAnswering.from_pretrained(model_name) 91 | model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True) 92 | self.assertIsNotNone(model) 93 | self.assertIsInstance(model, BertForQuestionAnswering) 94 | 95 | 96 | if __name__ == "__main__": 97 | unittest.main() 98 | -------------------------------------------------------------------------------- 
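The AutoModel test above exercises the shortcut-name factories (AutoConfig, AutoModel, AutoModelForSequenceClassification, ...). A minimal usage sketch follows; the shortcut name, label count, and sample sentence are illustrative assumptions and are not values taken from this repository's training scripts.

import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification

model_name = "bert-base-chinese"  # assumed shortcut name; any BERT-style checkpoint works
config = AutoConfig.from_pretrained(model_name, num_labels=3)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)

# Encode one sentence and read the classification logits (no fine-tuning has happened yet).
input_ids = torch.tensor([tokenizer.encode(u"服务态度很差", add_special_tokens=True)])
with torch.no_grad():
    logits = model(input_ids)[0]  # shape: (1, num_labels)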
/transformers/tests/modeling_encoder_decoder_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Hugging Face Inc. Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import logging 17 | import unittest 18 | import pytest 19 | 20 | from transformers import is_torch_available 21 | 22 | if is_torch_available(): 23 | from transformers import BertModel, BertForMaskedLM, Model2Model 24 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 25 | else: 26 | pytestmark = pytest.mark.skip("Require Torch") 27 | 28 | 29 | class EncoderDecoderModelTest(unittest.TestCase): 30 | @pytest.mark.slow 31 | def test_model2model_from_pretrained(self): 32 | logging.basicConfig(level=logging.INFO) 33 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 34 | model = Model2Model.from_pretrained(model_name) 35 | self.assertIsInstance(model.encoder, BertModel) 36 | self.assertIsInstance(model.decoder, BertForMaskedLM) 37 | self.assertEqual(model.decoder.config.is_decoder, True) 38 | self.assertEqual(model.encoder.config.is_decoder, False) 39 | 40 | def test_model2model_from_pretrained_not_bert(self): 41 | logging.basicConfig(level=logging.INFO) 42 | with self.assertRaises(ValueError): 43 | _ = Model2Model.from_pretrained('roberta') 44 | 45 | with self.assertRaises(ValueError): 46 | _ = Model2Model.from_pretrained('distilbert') 47 | 48 | with self.assertRaises(ValueError): 49 | _ = Model2Model.from_pretrained('does-not-exist') 50 | 51 | 52 | if __name__ == "__main__": 53 | unittest.main() 54 | -------------------------------------------------------------------------------- /transformers/tests/modeling_tf_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from transformers import is_tf_available 25 | 26 | if is_tf_available(): 27 | from transformers import (AutoConfig, BertConfig, 28 | TFAutoModel, TFBertModel, 29 | TFAutoModelWithLMHead, TFBertForMaskedLM, 30 | TFAutoModelForSequenceClassification, TFBertForSequenceClassification, 31 | TFAutoModelForQuestionAnswering, TFBertForQuestionAnswering) 32 | from transformers.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP 33 | 34 | from .modeling_common_test import (CommonTestCases, ids_tensor) 35 | from .configuration_common_test import ConfigTester 36 | else: 37 | pytestmark = pytest.mark.skip("Require TensorFlow") 38 | 39 | 40 | class TFAutoModelTest(unittest.TestCase): 41 | def test_model_from_pretrained(self): 42 | import h5py 43 | self.assertTrue(h5py.version.hdf5_version.startswith("1.10")) 44 | 45 | logging.basicConfig(level=logging.INFO) 46 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 47 | for model_name in ['bert-base-uncased']: 48 | config = AutoConfig.from_pretrained(model_name, force_download=True) 49 | self.assertIsNotNone(config) 50 | self.assertIsInstance(config, BertConfig) 51 | 52 | model = TFAutoModel.from_pretrained(model_name, force_download=True) 53 | self.assertIsNotNone(model) 54 | self.assertIsInstance(model, TFBertModel) 55 | 56 | def test_lmhead_model_from_pretrained(self): 57 | logging.basicConfig(level=logging.INFO) 58 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 59 | for model_name in ['bert-base-uncased']: 60 | config = AutoConfig.from_pretrained(model_name, force_download=True) 61 | self.assertIsNotNone(config) 62 | self.assertIsInstance(config, BertConfig) 63 | 64 | model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True) 65 | self.assertIsNotNone(model) 66 | self.assertIsInstance(model, TFBertForMaskedLM) 67 | 68 | def test_sequence_classification_model_from_pretrained(self): 69 | logging.basicConfig(level=logging.INFO) 70 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 71 | for model_name in ['bert-base-uncased']: 72 | config = AutoConfig.from_pretrained(model_name, force_download=True) 73 | self.assertIsNotNone(config) 74 | self.assertIsInstance(config, BertConfig) 75 | 76 | model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True) 77 | self.assertIsNotNone(model) 78 | self.assertIsInstance(model, TFBertForSequenceClassification) 79 | 80 | def test_question_answering_model_from_pretrained(self): 81 | logging.basicConfig(level=logging.INFO) 82 | # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 83 | for model_name in ['bert-base-uncased']: 84 | config = AutoConfig.from_pretrained(model_name, force_download=True) 85 | self.assertIsNotNone(config) 86 | self.assertIsInstance(config, BertConfig) 87 | 88 | model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True) 89 | self.assertIsNotNone(model) 90 | self.assertIsInstance(model, TFBertForQuestionAnswering) 91 | 92 | 93 | if __name__ == "__main__": 94 | unittest.main() 95 | -------------------------------------------------------------------------------- /transformers/tests/optimization_test.py: -------------------------------------------------------------------------------- 1 | # 
coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import os 21 | import pytest 22 | 23 | from transformers import is_torch_available 24 | 25 | if is_torch_available(): 26 | import torch 27 | 28 | from transformers import (AdamW, 29 | get_constant_schedule, 30 | get_constant_schedule_with_warmup, 31 | get_cosine_schedule_with_warmup, 32 | get_cosine_with_hard_restarts_schedule_with_warmup, 33 | get_linear_schedule_with_warmup) 34 | else: 35 | pytestmark = pytest.mark.skip("Require Torch") 36 | 37 | from .tokenization_tests_commons import TemporaryDirectory 38 | 39 | 40 | def unwrap_schedule(scheduler, num_steps=10): 41 | lrs = [] 42 | for _ in range(num_steps): 43 | scheduler.step() 44 | lrs.append(scheduler.get_lr()) 45 | return lrs 46 | 47 | def unwrap_and_save_reload_schedule(scheduler, num_steps=10): 48 | lrs = [] 49 | for step in range(num_steps): 50 | scheduler.step() 51 | lrs.append(scheduler.get_lr()) 52 | if step == num_steps // 2: 53 | with TemporaryDirectory() as tmpdirname: 54 | file_name = os.path.join(tmpdirname, 'schedule.bin') 55 | torch.save(scheduler.state_dict(), file_name) 56 | 57 | state_dict = torch.load(file_name) 58 | scheduler.load_state_dict(state_dict) 59 | return lrs 60 | 61 | class OptimizationTest(unittest.TestCase): 62 | 63 | def assertListAlmostEqual(self, list1, list2, tol): 64 | self.assertEqual(len(list1), len(list2)) 65 | for a, b in zip(list1, list2): 66 | self.assertAlmostEqual(a, b, delta=tol) 67 | 68 | def test_adam_w(self): 69 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 70 | target = torch.tensor([0.4, 0.2, -0.5]) 71 | criterion = torch.nn.MSELoss() 72 | # No warmup, constant schedule, no gradient clipping 73 | optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0) 74 | for _ in range(100): 75 | loss = criterion(w, target) 76 | loss.backward() 77 | optimizer.step() 78 | w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves. 79 | w.grad.zero_() 80 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 81 | 82 | 83 | class ScheduleInitTest(unittest.TestCase): 84 | m = torch.nn.Linear(50, 50) if is_torch_available() else None 85 | optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None 86 | num_steps = 10 87 | 88 | def assertListAlmostEqual(self, list1, list2, tol): 89 | self.assertEqual(len(list1), len(list2)) 90 | for a, b in zip(list1, list2): 91 | self.assertAlmostEqual(a, b, delta=tol) 92 | 93 | def test_constant_scheduler(self): 94 | scheduler = get_constant_schedule(self.optimizer) 95 | lrs = unwrap_schedule(scheduler, self.num_steps) 96 | expected_learning_rates = [10.] 
* self.num_steps 97 | self.assertEqual(len(lrs[0]), 1) 98 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 99 | 100 | scheduler = get_constant_schedule(self.optimizer) 101 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 102 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 103 | 104 | def test_warmup_constant_scheduler(self): 105 | scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4) 106 | lrs = unwrap_schedule(scheduler, self.num_steps) 107 | expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] 108 | self.assertEqual(len(lrs[0]), 1) 109 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 110 | 111 | scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4) 112 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 113 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 114 | 115 | def test_warmup_linear_scheduler(self): 116 | scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 117 | lrs = unwrap_schedule(scheduler, self.num_steps) 118 | expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] 119 | self.assertEqual(len(lrs[0]), 1) 120 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 121 | 122 | scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 123 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 124 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 125 | 126 | def test_warmup_cosine_scheduler(self): 127 | scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 128 | lrs = unwrap_schedule(scheduler, self.num_steps) 129 | expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] 130 | self.assertEqual(len(lrs[0]), 1) 131 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 132 | 133 | scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) 134 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 135 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 136 | 137 | def test_warmup_cosine_hard_restart_scheduler(self): 138 | scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10) 139 | lrs = unwrap_schedule(scheduler, self.num_steps) 140 | expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] 141 | self.assertEqual(len(lrs[0]), 1) 142 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 143 | 144 | scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10) 145 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 146 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 147 | 148 | 149 | if __name__ == "__main__": 150 | unittest.main() 151 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer 25 | from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP 26 | 27 | 28 | class AutoTokenizerTest(unittest.TestCase): 29 | @pytest.mark.slow 30 | def test_tokenizer_from_pretrained(self): 31 | logging.basicConfig(level=logging.INFO) 32 | for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: 33 | tokenizer = AutoTokenizer.from_pretrained(model_name) 34 | self.assertIsNotNone(tokenizer) 35 | self.assertIsInstance(tokenizer, BertTokenizer) 36 | self.assertGreater(len(tokenizer), 0) 37 | 38 | for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: 39 | tokenizer = AutoTokenizer.from_pretrained(model_name) 40 | self.assertIsNotNone(tokenizer) 41 | self.assertIsInstance(tokenizer, GPT2Tokenizer) 42 | self.assertGreater(len(tokenizer), 0) 43 | 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_bert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import pytest 20 | from io import open 21 | 22 | from transformers.tokenization_bert import (BasicTokenizer, 23 | BertTokenizer, 24 | WordpieceTokenizer, 25 | _is_control, _is_punctuation, 26 | _is_whitespace, VOCAB_FILES_NAMES) 27 | 28 | from .tokenization_tests_commons import CommonTestCases 29 | 30 | class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): 31 | 32 | tokenizer_class = BertTokenizer 33 | 34 | def setUp(self): 35 | super(BertTokenizationTest, self).setUp() 36 | 37 | vocab_tokens = [ 38 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 39 | "##ing", ",", "low", "lowest", 40 | ] 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 43 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 44 | 45 | def get_tokenizer(self, **kwargs): 46 | return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 47 | 48 | def get_input_output_texts(self): 49 | input_text = u"UNwant\u00E9d,running" 50 | output_text = u"unwanted, running" 51 | return input_text, output_text 52 | 53 | def test_full_tokenizer(self): 54 | tokenizer = self.tokenizer_class(self.vocab_file) 55 | 56 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 57 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 58 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 59 | 60 | def test_chinese(self): 61 | tokenizer = BasicTokenizer() 62 | 63 | self.assertListEqual( 64 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 65 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 66 | 67 | def test_basic_tokenizer_lower(self): 68 | tokenizer = BasicTokenizer(do_lower_case=True) 69 | 70 | self.assertListEqual( 71 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 72 | ["hello", "!", "how", "are", "you", "?"]) 73 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 74 | 75 | def test_basic_tokenizer_no_lower(self): 76 | tokenizer = BasicTokenizer(do_lower_case=False) 77 | 78 | self.assertListEqual( 79 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 80 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 81 | 82 | def test_wordpiece_tokenizer(self): 83 | vocab_tokens = [ 84 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 85 | "##ing" 86 | ] 87 | 88 | vocab = {} 89 | for (i, token) in enumerate(vocab_tokens): 90 | vocab[token] = i 91 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") 92 | 93 | self.assertListEqual(tokenizer.tokenize(""), []) 94 | 95 | self.assertListEqual( 96 | tokenizer.tokenize("unwanted running"), 97 | ["un", "##want", "##ed", "runn", "##ing"]) 98 | 99 | self.assertListEqual( 100 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 101 | 102 | def test_is_whitespace(self): 103 | self.assertTrue(_is_whitespace(u" ")) 104 | self.assertTrue(_is_whitespace(u"\t")) 105 | self.assertTrue(_is_whitespace(u"\r")) 106 | self.assertTrue(_is_whitespace(u"\n")) 107 | self.assertTrue(_is_whitespace(u"\u00A0")) 108 | 109 | self.assertFalse(_is_whitespace(u"A")) 110 | self.assertFalse(_is_whitespace(u"-")) 111 | 112 | def test_is_control(self): 113 | self.assertTrue(_is_control(u"\u0005")) 114 | 115 | self.assertFalse(_is_control(u"A")) 116 | self.assertFalse(_is_control(u" ")) 117 | self.assertFalse(_is_control(u"\t")) 118 | self.assertFalse(_is_control(u"\r")) 119 | 120 | def test_is_punctuation(self): 121 | self.assertTrue(_is_punctuation(u"-")) 122 | self.assertTrue(_is_punctuation(u"$")) 123 | self.assertTrue(_is_punctuation(u"`")) 124 | self.assertTrue(_is_punctuation(u".")) 125 | 126 | self.assertFalse(_is_punctuation(u"A")) 127 | self.assertFalse(_is_punctuation(u" ")) 128 | 129 | @pytest.mark.slow 130 | def test_sequence_builders(self): 131 | tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") 132 | 133 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 134 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 135 | 136 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 137 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 138 | 139 | assert encoded_sentence == [101] + text + [102] 140 | assert encoded_pair == [101] + text + [102] + text_2 + [102] 141 | 142 | if __name__ == '__main__': 143 | unittest.main() 144 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_ctrl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Salesforce and HuggingFace Inc. team. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import absolute_import, division, print_function, unicode_literals 15 | 16 | import os 17 | import unittest 18 | import json 19 | from io import open 20 | 21 | from transformers.tokenization_ctrl import CTRLTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester): 26 | 27 | tokenizer_class = CTRLTokenizer 28 | 29 | def setUp(self): 30 | super(CTRLTokenizationTest, self).setUp() 31 | 32 | # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt 33 | vocab = ['adapt', 're@@', 'a@@', 'apt', 'c@@', 't', ''] 34 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 35 | merges = ["#version: 0.2", 'a p', 'ap t', 'r e', 'a d', 'ad apt', ''] 36 | self.special_tokens_map = {"unk_token": ""} 37 | 38 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 39 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 40 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 41 | fp.write(json.dumps(vocab_tokens) + "\n") 42 | with open(self.merges_file, "w", encoding="utf-8") as fp: 43 | fp.write("\n".join(merges)) 44 | 45 | def get_tokenizer(self, **kwargs): 46 | kwargs.update(self.special_tokens_map) 47 | return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs) 48 | 49 | def get_input_output_texts(self): 50 | input_text = u"adapt react readapt apt" 51 | output_text = u"adapt react readapt apt" 52 | return input_text, output_text 53 | 54 | def test_full_tokenizer(self): 55 | tokenizer = CTRLTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 56 | text = "adapt react readapt apt" 57 | bpe_tokens = 'adapt re@@ a@@ c@@ t re@@ adapt apt'.split() 58 | tokens = tokenizer.tokenize(text) 59 | self.assertListEqual(tokens, bpe_tokens) 60 | 61 | input_tokens = tokens + [tokenizer.unk_token] 62 | 63 | input_bpe_tokens = [0, 1, 2, 4, 5, 1, 0, 3, 6] 64 | self.assertListEqual( 65 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 66 | 67 | 68 | if __name__ == '__main__': 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_distilbert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import pytest 20 | from io import open 21 | 22 | from transformers.tokenization_distilbert import (DistilBertTokenizer) 23 | 24 | from .tokenization_tests_commons import CommonTestCases 25 | from .tokenization_bert_test import BertTokenizationTest 26 | 27 | class DistilBertTokenizationTest(BertTokenizationTest): 28 | 29 | tokenizer_class = DistilBertTokenizer 30 | 31 | def get_tokenizer(self, **kwargs): 32 | return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs) 33 | 34 | @pytest.mark.slow 35 | def test_sequence_builders(self): 36 | tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") 37 | 38 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 39 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 40 | 41 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 42 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 43 | 44 | assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] 45 | assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + \ 46 | text_2 + [tokenizer.sep_token_id] 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | from io import open 21 | 22 | from transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES 23 | 24 | from .tokenization_tests_commons import CommonTestCases 25 | 26 | class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = GPT2Tokenizer 29 | 30 | def setUp(self): 31 | super(GPT2TokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "\u0120", "\u0120l", "\u0120n", 36 | "\u0120lo", "\u0120low", "er", 37 | "\u0120lowest", "\u0120newer", "\u0120wider", ""] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] 40 | self.special_tokens_map = {"unk_token": ""} 41 | 42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 43 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 44 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 45 | fp.write(json.dumps(vocab_tokens) + "\n") 46 | with open(self.merges_file, "w", encoding="utf-8") as fp: 47 | fp.write("\n".join(merges)) 48 | 49 | def get_tokenizer(self, **kwargs): 50 | kwargs.update(self.special_tokens_map) 51 | return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs) 52 | 53 | def get_input_output_texts(self): 54 | input_text = u"lower newer" 55 | output_text = u"lower newer" 56 | return input_text, output_text 57 | 58 | def test_full_tokenizer(self): 59 | tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 60 | text = "lower newer" 61 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] 62 | tokens = tokenizer.tokenize(text, add_prefix_space=True) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [tokenizer.unk_token] 66 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | 26 | class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = OpenAIGPTTokenizer 29 | 30 | def setUp(self): 31 | super(OpenAIGPTTokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "w", "r", "t", 36 | "lo", "low", "er", 37 | "low", "lowest", "newer", "wider", ""] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "l o", "lo w", "e r", ""] 40 | 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 43 | with open(self.vocab_file, "w") as fp: 44 | fp.write(json.dumps(vocab_tokens)) 45 | with open(self.merges_file, "w") as fp: 46 | fp.write("\n".join(merges)) 47 | 48 | def get_tokenizer(self, **kwargs): 49 | return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs) 50 | 51 | def get_input_output_texts(self): 52 | input_text = u"lower newer" 53 | output_text = u"lower newer" 54 | return input_text, output_text 55 | 56 | 57 | def test_full_tokenizer(self): 58 | tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file) 59 | 60 | text = "lower" 61 | bpe_tokens = ["low", "er"] 62 | tokens = tokenizer.tokenize(text) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [""] 66 | input_bpe_tokens = [14, 15, 20] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_roberta_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import json 19 | import unittest 20 | import pytest 21 | from io import open 22 | 23 | from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES 24 | from .tokenization_tests_commons import CommonTestCases 25 | 26 | 27 | class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): 28 | tokenizer_class = RobertaTokenizer 29 | 30 | def setUp(self): 31 | super(RobertaTokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "\u0120", "\u0120l", "\u0120n", 36 | "\u0120lo", "\u0120low", "er", 37 | "\u0120lowest", "\u0120newer", "\u0120wider", ""] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] 40 | self.special_tokens_map = {"unk_token": ""} 41 | 42 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 43 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 44 | with open(self.vocab_file, "w", encoding="utf-8") as fp: 45 | fp.write(json.dumps(vocab_tokens) + "\n") 46 | with open(self.merges_file, "w", encoding="utf-8") as fp: 47 | fp.write("\n".join(merges)) 48 | 49 | def get_tokenizer(self, **kwargs): 50 | kwargs.update(self.special_tokens_map) 51 | return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs) 52 | 53 | def get_input_output_texts(self): 54 | input_text = u"lower newer" 55 | output_text = u"lower newer" 56 | return input_text, output_text 57 | 58 | def test_full_tokenizer(self): 59 | tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) 60 | text = "lower newer" 61 | bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] 62 | tokens = tokenizer.tokenize(text, add_prefix_space=True) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [tokenizer.unk_token] 66 | input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | def roberta_dict_integration_testing(self): 71 | tokenizer = self.get_tokenizer() 72 | 73 | self.assertListEqual( 74 | tokenizer.encode('Hello world!', add_special_tokens=False), 75 | [0, 31414, 232, 328, 2] 76 | ) 77 | self.assertListEqual( 78 | tokenizer.encode('Hello world! cécé herlolip 418', add_special_tokens=False), 79 | [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] 80 | ) 81 | 82 | @pytest.mark.slow 83 | def test_sequence_builders(self): 84 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 85 | 86 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 87 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 88 | 89 | encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) 90 | encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) 91 | 92 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 93 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 94 | 95 | assert encoded_sentence == encoded_text_from_decode 96 | assert encoded_pair == encoded_pair_from_decode 97 | 98 | 99 | if __name__ == '__main__': 100 | unittest.main() 101 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import pytest 20 | from io import open 21 | 22 | from transformers import is_torch_available 23 | 24 | if is_torch_available(): 25 | import torch 26 | from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES 27 | else: 28 | pytestmark = pytest.mark.skip("Require Torch") # TODO: untangle Transfo-XL tokenizer from torch.load and torch.save 29 | 30 | from .tokenization_tests_commons import CommonTestCases 31 | 32 | class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester): 33 | 34 | tokenizer_class = TransfoXLTokenizer if is_torch_available() else None 35 | 36 | def setUp(self): 37 | super(TransfoXLTokenizationTest, self).setUp() 38 | 39 | vocab_tokens = [ 40 | "", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", 41 | "running", ",", "low", "l", 42 | ] 43 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 44 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 45 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 46 | 47 | def get_tokenizer(self, **kwargs): 48 | kwargs['lower_case'] = True 49 | return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs) 50 | 51 | def get_input_output_texts(self): 52 | input_text = u" UNwanted , running" 53 | output_text = u" unwanted, running" 54 | return input_text, output_text 55 | 56 | def test_full_tokenizer(self): 57 | tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True) 58 | 59 | tokens = tokenizer.tokenize(u" UNwanted , running") 60 | self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) 61 | 62 | self.assertListEqual( 63 | tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) 64 | 65 | def test_full_tokenizer_lower(self): 66 | tokenizer = TransfoXLTokenizer(lower_case=True) 67 | 68 | self.assertListEqual( 69 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), 70 | ["hello", "!", "how", "are", "you", "?"]) 71 | 72 | def test_full_tokenizer_no_lower(self): 73 | tokenizer = TransfoXLTokenizer(lower_case=False) 74 | 75 | self.assertListEqual( 76 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), 77 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 78 | 79 | 80 | if __name__ == '__main__': 81 | unittest.main() 82 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 HuggingFace Inc.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import six 21 | import pytest 22 | 23 | from transformers import PreTrainedTokenizer 24 | from transformers.tokenization_gpt2 import GPT2Tokenizer 25 | 26 | class TokenizerUtilsTest(unittest.TestCase): 27 | @pytest.mark.slow 28 | def check_tokenizer_from_pretrained(self, tokenizer_class): 29 | s3_models = list(tokenizer_class.max_model_input_sizes.keys()) 30 | for model_name in s3_models[:1]: 31 | tokenizer = tokenizer_class.from_pretrained(model_name) 32 | self.assertIsNotNone(tokenizer) 33 | self.assertIsInstance(tokenizer, tokenizer_class) 34 | self.assertIsInstance(tokenizer, PreTrainedTokenizer) 35 | 36 | for special_tok in tokenizer.all_special_tokens: 37 | if six.PY2: 38 | self.assertIsInstance(special_tok, unicode) 39 | else: 40 | self.assertIsInstance(special_tok, str) 41 | special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) 42 | self.assertIsInstance(special_tok_id, int) 43 | 44 | def test_pretrained_tokenizers(self): 45 | self.check_tokenizer_from_pretrained(GPT2Tokenizer) 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_xlm_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | import pytest 21 | 22 | from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES 23 | 24 | from .tokenization_tests_commons import CommonTestCases 25 | 26 | class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): 27 | 28 | tokenizer_class = XLMTokenizer 29 | 30 | def setUp(self): 31 | super(XLMTokenizationTest, self).setUp() 32 | 33 | # Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt 34 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 35 | "w", "r", "t", 36 | "lo", "low", "er", 37 | "low", "lowest", "newer", "wider", ""] 38 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 39 | merges = ["l o 123", "lo w 1456", "e r 1789", ""] 40 | 41 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 42 | self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) 43 | with open(self.vocab_file, "w") as fp: 44 | fp.write(json.dumps(vocab_tokens)) 45 | with open(self.merges_file, "w") as fp: 46 | fp.write("\n".join(merges)) 47 | 48 | def get_tokenizer(self, **kwargs): 49 | return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs) 50 | 51 | def get_input_output_texts(self): 52 | input_text = u"lower newer" 53 | output_text = u"lower newer" 54 | return input_text, output_text 55 | 56 | def test_full_tokenizer(self): 57 | """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """ 58 | tokenizer = XLMTokenizer(self.vocab_file, self.merges_file) 59 | 60 | text = "lower" 61 | bpe_tokens = ["low", "er"] 62 | tokens = tokenizer.tokenize(text) 63 | self.assertListEqual(tokens, bpe_tokens) 64 | 65 | input_tokens = tokens + [""] 66 | input_bpe_tokens = [14, 15, 20] 67 | self.assertListEqual( 68 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 69 | 70 | @pytest.mark.slow 71 | def test_sequence_builders(self): 72 | tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048") 73 | 74 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 75 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 76 | 77 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 78 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 79 | 80 | assert encoded_sentence == [1] + text + [1] 81 | assert encoded_pair == [1] + text + [1] + text_2 + [1] 82 | 83 | if __name__ == '__main__': 84 | unittest.main() 85 | -------------------------------------------------------------------------------- /transformers/tests/tokenization_xlnet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import pytest 20 | 21 | from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) 22 | 23 | from .tokenization_tests_commons import CommonTestCases 24 | 25 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 26 | 'fixtures/test_sentencepiece.model') 27 | 28 | class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): 29 | 30 | tokenizer_class = XLNetTokenizer 31 | 32 | def setUp(self): 33 | super(XLNetTokenizationTest, self).setUp() 34 | 35 | # We have a SentencePiece fixture for testing 36 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 37 | tokenizer.save_pretrained(self.tmpdirname) 38 | 39 | def get_tokenizer(self, **kwargs): 40 | return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs) 41 | 42 | def get_input_output_texts(self): 43 | input_text = u"This is a test" 44 | output_text = u"This is a test" 45 | return input_text, output_text 46 | 47 | 48 | def test_full_tokenizer(self): 49 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 50 | 51 | tokens = tokenizer.tokenize(u'This is a test') 52 | self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) 53 | 54 | self.assertListEqual( 55 | tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) 56 | 57 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 58 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 59 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 60 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 61 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) 62 | ids = tokenizer.convert_tokens_to_ids(tokens) 63 | self.assertListEqual( 64 | ids, [8, 21, 84, 55, 24, 19, 7, 0, 65 | 602, 347, 347, 347, 3, 12, 66, 66 | 46, 72, 80, 6, 0, 4]) 67 | 68 | back_tokens = tokenizer.convert_ids_to_tokens(ids) 69 | self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 70 | u'or', u'n', SPIECE_UNDERLINE + u'in', 71 | SPIECE_UNDERLINE + u'', u'', u'2', u'0', u'0', u'0', u',', 72 | SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 73 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', 74 | u'', u'.']) 75 | 76 | def test_tokenizer_lower(self): 77 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) 78 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 79 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 80 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 81 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 82 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 83 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"]) 84 | 85 | def test_tokenizer_no_lower(self): 86 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False) 87 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 88 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or', 89 | u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 90 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 91 
| SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) 92 | 93 | @pytest.mark.slow 94 | def test_sequence_builders(self): 95 | tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased") 96 | 97 | text = tokenizer.encode("sequence builders", add_special_tokens=False) 98 | text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) 99 | 100 | encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) 101 | encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) 102 | 103 | assert encoded_sentence == text + [4, 3] 104 | assert encoded_pair == text + [4] + text_2 + [4, 3] 105 | 106 | 107 | if __name__ == '__main__': 108 | unittest.main() 109 | -------------------------------------------------------------------------------- /transformers/tokenization_camembert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License 15 | """ Tokenization classes for Camembert model.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import logging 20 | import os 21 | from shutil import copyfile 22 | 23 | import sentencepiece as spm 24 | from transformers.tokenization_utils import PreTrainedTokenizer 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'} 29 | 30 | PRETRAINED_VOCAB_FILES_MAP = { 31 | 'vocab_file': 32 | { 33 | 'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-sentencepiece.bpe.model", 34 | } 35 | } 36 | 37 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 38 | 'camembert-base': None, 39 | } 40 | 41 | class CamembertTokenizer(PreTrainedTokenizer): 42 | """ 43 | Adapted from RobertaTokenizer and XLNetTokenizer 44 | SentencePiece based tokenizer. 
Peculiarities: 45 | 46 | - requires `SentencePiece <https://github.com/google/sentencepiece>`_ 47 | """ 48 | vocab_files_names = VOCAB_FILES_NAMES 49 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 50 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 51 | 52 | def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>", 53 | cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', 54 | additional_special_tokens=['<s>NOTUSED', '</s>NOTUSED'], **kwargs): 55 | super(CamembertTokenizer, self).__init__(max_len=512, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, 56 | sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, 57 | mask_token=mask_token, additional_special_tokens=additional_special_tokens, 58 | **kwargs) 59 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 60 | self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens 61 | self.sp_model = spm.SentencePieceProcessor() 62 | self.sp_model.Load(str(vocab_file)) 63 | self.vocab_file = vocab_file 64 | # HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual 65 | # sentencepiece vocabulary (this is the case for <s> and </s>) 66 | self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3} 67 | self.fairseq_offset = len(self.fairseq_tokens_to_ids) 68 | self.fairseq_tokens_to_ids['<mask>'] = len(self.sp_model) + len(self.fairseq_tokens_to_ids) 69 | self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} 70 | 71 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 72 | """ 73 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 74 | by concatenating and adding special tokens. 75 | A RoBERTa sequence has the following format: 76 | single sequence: <s> X </s> 77 | pair of sequences: <s> A </s></s> B </s> 78 | """ 79 | if token_ids_1 is None: 80 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 81 | cls = [self.cls_token_id] 82 | sep = [self.sep_token_id] 83 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep 84 | 85 | def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): 86 | """ 87 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 88 | special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. 89 | 90 | Args: 91 | token_ids_0: list of ids (must not contain special tokens) 92 | token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids 93 | for sequence pairs 94 | already_has_special_tokens: (default False) Set to True if the token list is already formatted with 95 | special tokens for the model 96 | 97 | Returns: 98 | A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
99 | """ 100 | if already_has_special_tokens: 101 | if token_ids_1 is not None: 102 | raise ValueError("You should not supply a second sequence if the provided sequence of " 103 | "ids is already formated with special tokens for the model.") 104 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 105 | 106 | if token_ids_1 is None: 107 | return [1] + ([0] * len(token_ids_0)) + [1] 108 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] 109 | 110 | def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): 111 | """ 112 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 113 | A RoBERTa sequence pair mask has the following format: 114 | 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 115 | | first sequence | second sequence 116 | 117 | if token_ids_1 is None, only returns the first portion of the mask (0's). 118 | """ 119 | sep = [self.sep_token_id] 120 | cls = [self.cls_token_id] 121 | 122 | if token_ids_1 is None: 123 | return len(cls + token_ids_0 + sep) * [0] 124 | return len(cls + token_ids_0 + sep + sep) * [0] + len(token_ids_1 + sep) * [1] 125 | 126 | @property 127 | def vocab_size(self): 128 | return self.fairseq_offset + len(self.sp_model) 129 | 130 | def _tokenize(self, text): 131 | return self.sp_model.EncodeAsPieces(text) 132 | 133 | def _convert_token_to_id(self, token): 134 | """ Converts a token (str/unicode) in an id using the vocab. """ 135 | if token in self.fairseq_tokens_to_ids: 136 | return self.fairseq_tokens_to_ids[token] 137 | return self.fairseq_offset + self.sp_model.PieceToId(token) 138 | 139 | def _convert_id_to_token(self, index): 140 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 141 | if index in self.fairseq_ids_to_tokens: 142 | return self.fairseq_ids_to_tokens[index] 143 | return self.sp_model.IdToPiece(index - self.fairseq_offset) 144 | 145 | def save_vocabulary(self, save_directory): 146 | """ Save the sentencepiece vocabulary (copy original file) and special tokens file 147 | to a directory. 148 | """ 149 | if not os.path.isdir(save_directory): 150 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 151 | return 152 | out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 153 | 154 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): 155 | copyfile(self.vocab_file, out_vocab_file) 156 | 157 | return (out_vocab_file,) 158 | -------------------------------------------------------------------------------- /transformers/tokenization_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes for DistilBERT.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .tokenization_bert import BertTokenizer 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | 'vocab_file': 33 | { 34 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | } 37 | } 38 | 39 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 40 | 'distilbert-base-uncased': 512, 41 | 'distilbert-base-uncased-distilled-squad': 512, 42 | } 43 | 44 | 45 | class DistilBertTokenizer(BertTokenizer): 46 | r""" 47 | Constructs a DistilBertTokenizer. 48 | :class:`~transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece 49 | 50 | Args: 51 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 52 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 53 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 54 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 55 | minimum of this value (if specified) and the underlying BERT model's sequence length. 56 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 57 | do_wordpiece_only=False 58 | """ 59 | 60 | vocab_files_names = VOCAB_FILES_NAMES 61 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 62 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 63 | -------------------------------------------------------------------------------- /transformers/tokenization_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for RoBERTa.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | from .tokenization_gpt2 import GPT2Tokenizer 27 | 28 | try: 29 | from functools import lru_cache 30 | except ImportError: 31 | # Just a dummy decorator to get the checks to run on python2 32 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 
33 | def lru_cache(): 34 | return lambda func: func 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | VOCAB_FILES_NAMES = { 39 | 'vocab_file': 'vocab.json', 40 | 'merges_file': 'merges.txt', 41 | } 42 | 43 | PRETRAINED_VOCAB_FILES_MAP = { 44 | 'vocab_file': 45 | { 46 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", 47 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json", 48 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json", 49 | 'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-vocab.json", 50 | 'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", 51 | 'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json", 52 | }, 53 | 'merges_file': 54 | { 55 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", 56 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt", 57 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt", 58 | 'distilroberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-merges.txt", 59 | 'roberta-base-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", 60 | 'roberta-large-openai-detector': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt", 61 | }, 62 | } 63 | 64 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 65 | 'roberta-base': 512, 66 | 'roberta-large': 512, 67 | 'roberta-large-mnli': 512, 68 | 'distilroberta-base': 512, 69 | 'roberta-base-openai-detector': 512, 70 | 'roberta-large-openai-detector': 512, 71 | } 72 | 73 | 74 | class RobertaTokenizer(GPT2Tokenizer): 75 | """ 76 | RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. Peculiarities: 77 | - Byte-level Byte-Pair-Encoding 78 | - Requires a space to start the input string => the encoding methods should be called with the 79 | ``add_prefix_space`` flag set to ``True``. 80 | Otherwise, this tokenizer ``encode`` and ``decode`` method will not conserve 81 | the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"` 82 | """ 83 | vocab_files_names = VOCAB_FILES_NAMES 84 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 85 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 86 | 87 | def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>", 88 | cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', **kwargs): 89 | super(RobertaTokenizer, self).__init__(vocab_file=vocab_file, merges_file=merges_file, errors=errors, 90 | bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, 91 | sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, 92 | mask_token=mask_token, **kwargs) 93 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 94 | self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens 95 | 96 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 97 | """ 98 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 99 | by concatenating and adding special tokens.
100 | A RoBERTa sequence has the following format: 101 | single sequence: <s> X </s> 102 | pair of sequences: <s> A </s></s> B </s> 103 | """ 104 | if token_ids_1 is None: 105 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 106 | cls = [self.cls_token_id] 107 | sep = [self.sep_token_id] 108 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep 109 | 110 | def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): 111 | """ 112 | Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding 113 | special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. 114 | 115 | Args: 116 | token_ids_0: list of ids (must not contain special tokens) 117 | token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids 118 | for sequence pairs 119 | already_has_special_tokens: (default False) Set to True if the token list is already formatted with 120 | special tokens for the model 121 | 122 | Returns: 123 | A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 124 | """ 125 | if already_has_special_tokens: 126 | if token_ids_1 is not None: 127 | raise ValueError("You should not supply a second sequence if the provided sequence of " 128 | "ids is already formatted with special tokens for the model.") 129 | return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) 130 | 131 | if token_ids_1 is None: 132 | return [1] + ([0] * len(token_ids_0)) + [1] 133 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] 134 | 135 | def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): 136 | """ 137 | Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 138 | A RoBERTa sequence pair mask has the following format: 139 | 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 140 | | first sequence | second sequence 141 | 142 | if token_ids_1 is None, only returns the first portion of the mask (0's). 143 | """ 144 | sep = [self.sep_token_id] 145 | cls = [self.cls_token_id] 146 | 147 | if token_ids_1 is None: 148 | return len(cls + token_ids_0 + sep) * [0] 149 | return len(cls + token_ids_0 + sep + sep) * [0] + len(token_ids_1 + sep) * [1] 150 | -------------------------------------------------------------------------------- /【2019 CCF BDCI】-负面判定-登峰造极-答辩PPT-最终版.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chevalier1024/CCF-BDCI-ABSA/5c534e443dd1d3ee8932c8369ebd80d2ea6bacec/【2019 CCF BDCI】-负面判定-登峰造极-答辩PPT-最终版.pptx --------------------------------------------------------------------------------
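
The RobertaTokenizer above (and the CamembertTokenizer, which reuses the same methods) wraps every sequence in <s> ... </s> special tokens and builds an all-0/all-1 segment mask. The following is a minimal, self-contained sketch of that layout, not the library's implementation: it only mirrors the logic of build_inputs_with_special_tokens and create_token_type_ids_from_sequences shown above, with no real vocabulary loaded. The ids 0 for <s> and 2 for </s> are assumed to match the roberta-base vocabulary used in the RoBERTa integration test earlier in this section; the content ids in the usage example are partly made up.

CLS_ID = 0  # assumed id of <s> in the roberta-base vocabulary
SEP_ID = 2  # assumed id of </s> in the roberta-base vocabulary

def build_inputs_with_special_tokens(token_ids_0, token_ids_1=None):
    # Single sequence: <s> X </s>
    if token_ids_1 is None:
        return [CLS_ID] + token_ids_0 + [SEP_ID]
    # Pair of sequences: <s> A </s></s> B </s>
    return [CLS_ID] + token_ids_0 + [SEP_ID, SEP_ID] + token_ids_1 + [SEP_ID]

def create_token_type_ids_from_sequences(token_ids_0, token_ids_1=None):
    # Mirrors the method above: first sequence plus its special tokens -> 0s,
    # second sequence plus its trailing </s> -> 1s.
    if token_ids_1 is None:
        return [0] * (len(token_ids_0) + 2)
    return [0] * (len(token_ids_0) + 3) + [1] * (len(token_ids_1) + 1)

if __name__ == '__main__':
    a = [31414, 232]   # ids that appear in the RoBERTa integration test above ("Hello world")
    b = [9064, 745]    # made-up ids standing in for a second tokenized sequence
    print(build_inputs_with_special_tokens(a, b))       # [0, 31414, 232, 2, 2, 9064, 745, 2]
    print(create_token_type_ids_from_sequences(a, b))   # [0, 0, 0, 0, 0, 1, 1, 1]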