├── figure ├── loss.png ├── clue-roformerv2-base-classification.jpg ├── clue-roformerv2-large-classification.jpg ├── clue-chinesebert-large-classification.jpg └── clue-roformerv2-large-classification-paddle.jpg ├── requirements.txt ├── examples ├── clue │ └── classification │ │ ├── predict.sh │ │ ├── train.sh │ │ ├── accuracy.py │ │ ├── run_clue_predict_no_trainer.py │ │ ├── clue_11.py │ │ ├── clue_10.py │ │ └── run_clue_no_trainer.py └── dummpy │ ├── run_chnsenti.sh │ ├── test_mlm_v1.py │ ├── test_mlm_v2.py │ ├── test_sim.py │ └── task_text_classification_chnsenti.py ├── setup.py ├── test ├── compare_tokenizer.py └── compare_model.py ├── .gitignore ├── src └── roformer │ ├── tokenization_utils.py │ ├── convert_roformer_original_tf_checkpoint_to_pytorch.py │ ├── __init__.py │ ├── configuration_roformer.py │ ├── tokenization_roformer_fast.py │ └── tokenization_roformer.py ├── LICENSE └── README.md /figure/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunnYu/RoFormer_pytorch/HEAD/figure/loss.png -------------------------------------------------------------------------------- /figure/clue-roformerv2-base-classification.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunnYu/RoFormer_pytorch/HEAD/figure/clue-roformerv2-base-classification.jpg -------------------------------------------------------------------------------- /figure/clue-roformerv2-large-classification.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunnYu/RoFormer_pytorch/HEAD/figure/clue-roformerv2-large-classification.jpg -------------------------------------------------------------------------------- /figure/clue-chinesebert-large-classification.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunnYu/RoFormer_pytorch/HEAD/figure/clue-chinesebert-large-classification.jpg -------------------------------------------------------------------------------- /figure/clue-roformerv2-large-classification-paddle.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunnYu/RoFormer_pytorch/HEAD/figure/clue-roformerv2-large-classification-paddle.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/lonePatient/TorchBlocks.git 2 | transformers>=4.13.0 3 | bert4keras 4 | rjieba 5 | jieba 6 | scikit-learn 7 | numpy 8 | pandas 9 | matplotlib -------------------------------------------------------------------------------- /examples/clue/classification/predict.sh: -------------------------------------------------------------------------------- 1 | ALL_TASKS="cmnli iflytek tnews afqmc ocnli cluewsc2020 csl" 2 | 3 | for TASK_NAME in $ALL_TASKS 4 | do 5 | python run_clue_predict_no_trainer.py \ 6 | --model_name_or_path "outputs/$TASK_NAME/epoch_best" \ 7 | --task_name $TASK_NAME 8 | done -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="roformer", 5 | package_dir={"": "src"}, 6 | packages=find_packages("src"), 7 | version="0.4.3", 8 | 
license="Apache 2.0", 9 | description="roformer_pytorch", 10 | author="Jun Yu", 11 | author_email="573009727@qq.com", 12 | url="https://github.com/JunnYu/RoFormer_pytorch", 13 | keywords=["roformer", "pytorch", "tf2.0"], 14 | install_requires=["transformers>=4.13.0", "rjieba"], 15 | ) 16 | -------------------------------------------------------------------------------- /examples/clue/classification/train.sh: -------------------------------------------------------------------------------- 1 | #export TRANSFORMERS_CACHE=/mnt/f/hf/models 2 | 3 | ALL_TASKS="iflytek tnews afqmc ocnli cluewsc2020 csl cmnli" 4 | for TASK_NAME in $ALL_TASKS 5 | do 6 | if [ $TASK_NAME == "cluewsc2020" ] ;then 7 | EPOCHS=30 8 | else 9 | EPOCHS=10 10 | fi 11 | python run_clue_no_trainer.py \ 12 | --model_name_or_path "junnyu/roformer_v2_chinese_char_base" \ 13 | --task_name $TASK_NAME \ 14 | --max_length 128 \ 15 | --per_device_train_batch_size 32 \ 16 | --per_device_eval_batch_size 64 \ 17 | --learning_rate 2e-5 \ 18 | --num_train_epochs $EPOCHS \ 19 | --weight_decay 0.01 \ 20 | --num_warmup_steps_or_radios 0.1 \ 21 | --seed 42 \ 22 | --with_tracking \ 23 | --output_dir ./outputs/$TASK_NAME/ 24 | done -------------------------------------------------------------------------------- /examples/dummpy/run_chnsenti.sh: -------------------------------------------------------------------------------- 1 | export MODEL_DIR=junnyu/chinese_roformer_base 2 | export DATA_DIR=../dataset 3 | export OUTPUR_DIR=../outputs 4 | export TASK_NAME=chnsenti 5 | 6 | #-----------training----------------- 7 | python task_text_classification_chnsenti.py \ 8 | --model_type=roformer \ 9 | --model_path=$MODEL_DIR \ 10 | --task_name=$TASK_NAME \ 11 | --do_train \ 12 | --do_eval \ 13 | --eval_all_checkpoints \ 14 | --gpu=0 \ 15 | --monitor=eval_acc \ 16 | --data_dir=$DATA_DIR/${TASK_NAME}/ \ 17 | --train_max_seq_length=128 \ 18 | --eval_max_seq_length=128 \ 19 | --per_gpu_train_batch_size=16 \ 20 | --per_gpu_eval_batch_size=32 \ 21 | --learning_rate=3e-5 \ 22 | --num_train_epochs=10.0 \ 23 | --logging_steps=-1 \ 24 | --save_steps=-1 \ 25 | --output_dir=$OUTPUR_DIR/${TASK_NAME}_output/ \ 26 | --overwrite_output_dir \ 27 | --seed=42 28 | -------------------------------------------------------------------------------- /test/compare_tokenizer.py: -------------------------------------------------------------------------------- 1 | import jieba 2 | from bert4keras.tokenizers import Tokenizer 3 | 4 | from roformer import RoFormerTokenizer, RoFormerTokenizerFast 5 | 6 | dict_path = "E:/BaiduNetdiskDownload/chinese_roformer_L-12_H-768_A-12" 7 | text = "12312格ab局A B cdA,.567 861351 684!今天萨达天 气非常好王企。文保鹅按时发放了的撒这些seqetvgsa国内拉手的喀什。、]P[,./()*7656&【;,‘" 8 | # text = "这里基本保留了唐宋遗留下来的坊巷格局和大量明清古建筑,其中各级文保单位29处,被誉为“里坊制度的活化石”“明清建筑博物馆”!" 
9 | bert4keras_tokenizer = Tokenizer( 10 | dict_path + "/vocab.txt", 11 | do_lower_case=True, 12 | pre_tokenize=lambda s: jieba.cut(s, HMM=False), 13 | ) 14 | roformer_tokenizer = RoFormerTokenizer.from_pretrained(dict_path) 15 | roformer_tokenizer_fast = RoFormerTokenizerFast.from_pretrained(dict_path) 16 | bert4keras_tokenizer_input_ids = bert4keras_tokenizer.encode(text)[0] 17 | roformer_tokenizer_input_ids = roformer_tokenizer.encode(text) 18 | roformer_fast_tokenizer_input_ids = roformer_tokenizer_fast.encode(text) 19 | print(bert4keras_tokenizer_input_ids == roformer_tokenizer_input_ids) 20 | print(bert4keras_tokenizer_input_ids == roformer_fast_tokenizer_input_ids) 21 | -------------------------------------------------------------------------------- /examples/dummpy/test_mlm_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tensorflow as tf 3 | from transformers import RoFormerForMaskedLM, RoFormerTokenizer, TFRoFormerForMaskedLM 4 | 5 | text = "今天[MASK]很好,我[MASK]去公园玩。" 6 | tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base") 7 | pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base") 8 | tf_model = TFRoFormerForMaskedLM.from_pretrained( 9 | "junnyu/roformer_chinese_base", from_pt=True 10 | ) 11 | pt_inputs = tokenizer(text, return_tensors="pt") 12 | tf_inputs = tokenizer(text, return_tensors="tf") 13 | # pytorch 14 | with torch.no_grad(): 15 | pt_outputs = pt_model(**pt_inputs).logits[0] 16 | pt_outputs_sentence = "pytorch: " 17 | for i, id in enumerate(tokenizer.encode(text)): 18 | if id == tokenizer.mask_token_id: 19 | tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1]) 20 | pt_outputs_sentence += "[" + "||".join(tokens) + "]" 21 | else: 22 | pt_outputs_sentence += "".join( 23 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) 24 | ) 25 | print(pt_outputs_sentence) 26 | # tf 27 | tf_outputs = tf_model(**tf_inputs, training=False).logits[0] 28 | tf_outputs_sentence = "tf: " 29 | for i, id in enumerate(tokenizer.encode(text)): 30 | if id == tokenizer.mask_token_id: 31 | tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1]) 32 | tf_outputs_sentence += "[" + "||".join(tokens) + "]" 33 | else: 34 | tf_outputs_sentence += "".join( 35 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) 36 | ) 37 | print(tf_outputs_sentence) 38 | # pytorch: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。 39 | # tf: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。 40 | -------------------------------------------------------------------------------- /examples/dummpy/test_mlm_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tensorflow as tf 3 | from transformers import BertTokenizer 4 | from roformer import RoFormerForMaskedLM, TFRoFormerForMaskedLM 5 | 6 | text = "今天[MASK]很好,我[MASK]去公园玩。" 7 | tokenizer = BertTokenizer.from_pretrained("junnyu/roformer_v2_chinese_char_base") 8 | pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_v2_chinese_char_base") 9 | tf_model = TFRoFormerForMaskedLM.from_pretrained( 10 | "junnyu/roformer_v2_chinese_char_base", from_pt=True 11 | ) 12 | pt_inputs = tokenizer(text, return_tensors="pt") 13 | tf_inputs = tokenizer(text, return_tensors="tf") 14 | # pytorch 15 | with torch.no_grad(): 16 | pt_outputs = pt_model(**pt_inputs).logits[0] 17 | pt_outputs_sentence = "pytorch: " 18 | for i, id in 
enumerate(tokenizer.encode(text)): 19 | if id == tokenizer.mask_token_id: 20 | tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1]) 21 | pt_outputs_sentence += "[" + "||".join(tokens) + "]" 22 | else: 23 | pt_outputs_sentence += "".join( 24 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) 25 | ) 26 | print(pt_outputs_sentence) 27 | # tf 28 | tf_outputs = tf_model(**tf_inputs, training=False).logits[0] 29 | tf_outputs_sentence = "tf: " 30 | for i, id in enumerate(tokenizer.encode(text)): 31 | if id == tokenizer.mask_token_id: 32 | tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1]) 33 | tf_outputs_sentence += "[" + "||".join(tokens) + "]" 34 | else: 35 | tf_outputs_sentence += "".join( 36 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) 37 | ) 38 | print(tf_outputs_sentence) 39 | # small 40 | # pytorch: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。 41 | # tf: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。 42 | # base 43 | # pytorch: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。 44 | # tf: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。 45 | # large 46 | # pytorch: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。 47 | # tf: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。 48 | -------------------------------------------------------------------------------- /examples/dummpy/test_sim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from roformer import RoFormerForCausalLM, RoFormerConfig 4 | from transformers import BertTokenizer 5 | 6 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 7 | pretrained_model = "junnyu/roformer_chinese_sim_char_base" 8 | tokenizer = BertTokenizer.from_pretrained(pretrained_model) 9 | config = RoFormerConfig.from_pretrained(pretrained_model) 10 | config.is_decoder = True 11 | config.eos_token_id = tokenizer.sep_token_id 12 | config.pooler_activation = "linear" 13 | model = RoFormerForCausalLM.from_pretrained(pretrained_model, config=config) 14 | model.to(device) 15 | model.eval() 16 | 17 | def gen_synonyms(text, n=100, k=20): 18 | ''''含义: 产生sent的n个相似句,然后返回最相似的k个。 19 | 做法:用seq2seq生成,并用encoder算相似度并排序。 20 | ''' 21 | # 寻找所有相似的句子 22 | r = [] 23 | inputs1 = tokenizer(text, return_tensors="pt") 24 | for _ in range(n): 25 | inputs1.to(device) 26 | output = tokenizer.batch_decode(model.generate(**inputs1, top_p=0.95, do_sample=True, max_length=128), skip_special_tokens=True)[0].replace(" ","").replace(text, "") # 去除空格,去除原始text文本。 27 | r.append(output) 28 | 29 | # 对相似的句子进行排序 30 | r = [i for i in set(r) if i != text and len(i) > 0] 31 | r = [text] + r 32 | inputs2 = tokenizer(r, padding=True, return_tensors="pt") 33 | with torch.no_grad(): 34 | inputs2.to(device) 35 | outputs = model(**inputs2) 36 | Z = outputs.pooler_output.cpu().numpy() 37 | Z /= (Z**2).sum(axis=1, keepdims=True)**0.5 38 | argsort = np.dot(Z[1:], -Z[0]).argsort() 39 | 40 | return [r[i + 1] for i in argsort[:k]] 41 | 42 | out = gen_synonyms("广州和深圳哪个好?") 43 | print(out) 44 | # ['深圳和广州哪个好?', 45 | # '广州和深圳哪个好', 46 | # '深圳和广州哪个好', 47 | # '深圳和广州哪个比较好。', 48 | # '深圳和广州哪个最好?', 49 | # '深圳和广州哪个比较好', 50 | # '广州和深圳那个比较好', 51 | # '深圳和广州哪个更好?', 52 | # '深圳与广州哪个好', 53 | # '深圳和广州,哪个比较好', 54 | # '广州与深圳比较哪个好', 55 | # '深圳和广州哪里比较好', 56 | # '深圳还是广州比较好?', 57 | # '广州和深圳哪个地方好一些?', 58 | # '广州好还是深圳好?', 59 | # '广州好还是深圳好呢?', 60 | # '广州与深圳哪个地方好点?', 61 | # '深圳好还是广州好', 62 | # '广州好还是深圳好', 63 | # '广州和深圳哪个城市好?'] 64 | -------------------------------------------------------------------------------- 
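A minimal sketch of the encoder-side scoring that gen_synonyms above relies on: under the same setup as test_sim.py (the junnyu/roformer_chinese_sim_char_base checkpoint with is_decoder, eos_token_id and a linear pooler), the L2-normalized pooler_output of two sentences gives a cosine similarity directly. This is an illustrative sketch, not part of the original examples.

import torch
from transformers import BertTokenizer
from roformer import RoFormerConfig, RoFormerForCausalLM

name = "junnyu/roformer_chinese_sim_char_base"
tokenizer = BertTokenizer.from_pretrained(name)
config = RoFormerConfig.from_pretrained(name)
config.is_decoder = True
config.eos_token_id = tokenizer.sep_token_id
config.pooler_activation = "linear"
model = RoFormerForCausalLM.from_pretrained(name, config=config).eval()

# Encode two sentences and compare their L2-normalized pooler outputs.
inputs = tokenizer(["广州和深圳哪个好?", "深圳和广州哪个好?"], padding=True, return_tensors="pt")
with torch.no_grad():
    z = model(**inputs).pooler_output
z = z / z.norm(dim=-1, keepdim=True)
print(float(z[0] @ z[1]))  # cosine similarity; closer to 1 means more similar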
/test/compare_model.py: -------------------------------------------------------------------------------- 1 | import jieba 2 | import tensorflow as tf 3 | import torch 4 | from bert4keras.models import build_transformer_model 5 | from bert4keras.tokenizers import Tokenizer 6 | 7 | from roformer import RoFormerModel, RoFormerTokenizerFast, TFRoFormerModel 8 | 9 | jieba.initialize() 10 | config_path = ( 11 | "E:/BaiduNetdiskDownload/chinese_roformer_L-12_H-768_A-12/bert_config.json" 12 | ) 13 | checkpoint_path = ( 14 | "E:/BaiduNetdiskDownload/chinese_roformer_L-12_H-768_A-12/bert_model.ckpt" 15 | ) 16 | dict_path = "E:/BaiduNetdiskDownload/chinese_roformer_L-12_H-768_A-12/vocab.txt" 17 | # converted_ckpt_path = "pretrained_models/chinese_roformer_base" 18 | converted_ckpt_path = "junnyu/roformer_chinese_base" # https://huggingface.co/junnyu/roformer_chinese_base 19 | tokenizer = Tokenizer( 20 | dict_path, do_lower_case=True, pre_tokenize=lambda s: jieba.cut(s, HMM=False) 21 | ) 22 | text = "这里基本保留了唐宋遗留下来的坊巷格局和大量明清古建筑,其中各级文保单位29处,被誉为“里坊制度的活化石”“明清建筑博物馆”!" 23 | 24 | # bert4keras 25 | inputs = tokenizer.encode(text) 26 | tf_inputs = [ 27 | tf.convert_to_tensor(inputs[0])[None], 28 | tf.convert_to_tensor(inputs[1])[None], 29 | ] 30 | model = build_transformer_model( 31 | config_path=config_path, checkpoint_path=checkpoint_path, model="roformer" 32 | ) 33 | bert4keras_outputs = torch.tensor(model(tf_inputs, training=False).numpy()) 34 | 35 | # pt 36 | roformer_tokenizer = RoFormerTokenizerFast.from_pretrained(converted_ckpt_path) 37 | pt_model = RoFormerModel.from_pretrained(converted_ckpt_path) 38 | pt_inputs = roformer_tokenizer(text, return_tensors="pt") 39 | with torch.no_grad(): 40 | pt_outputs = pt_model(**pt_inputs).last_hidden_state 41 | 42 | # tf 43 | tf_model = TFRoFormerModel.from_pretrained(converted_ckpt_path, from_pt=True) 44 | tf_inputs = roformer_tokenizer(text, return_tensors="tf") 45 | tf_outputs = torch.from_numpy( 46 | tf_model(**tf_inputs, training=False).last_hidden_state.numpy() 47 | ) 48 | 49 | print("bert4keras vs pytorch") 50 | print("mean diff :", (bert4keras_outputs - pt_outputs).abs().mean()) 51 | print("max diff :", (bert4keras_outputs - pt_outputs).abs().max()) 52 | print("bert4keras vs tf2.0") 53 | print("mean diff :", (bert4keras_outputs - tf_outputs).abs().mean()) 54 | print("max diff :", (bert4keras_outputs - tf_outputs).abs().max()) 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | beifen/ 6 | # C extensions 7 | *.so 8 | model/old.py 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | test.py 131 | win.bat 132 | **/pytorch_model.bin 133 | 134 | dataset/chnsenti/*.tsv -------------------------------------------------------------------------------- /src/roformer/tokenization_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization utils for RoFormer.""" 16 | 17 | from typing import List 18 | 19 | from tokenizers import NormalizedString, PreTokenizedString, normalizers 20 | 21 | 22 | class JiebaPreTokenizer: 23 | def __init__(self, vocab) -> None: 24 | self.vocab = vocab 25 | self.normalizers = normalizers.BertNormalizer( 26 | clean_text=False, 27 | handle_chinese_chars=True, 28 | strip_accents=False, 29 | lowercase=False, 30 | ) 31 | try: 32 | import rjieba 33 | except ImportError: 34 | raise ImportError( 35 | "You need to install rjieba to use RoFormerTokenizer. " 36 | "See https://pypi.org/project/rjieba/ for installation." 
37 | ) 38 | self.jieba = rjieba 39 | 40 | def jieba_split( 41 | self, i: int, normalized_string: NormalizedString 42 | ) -> List[NormalizedString]: 43 | splits = [] 44 | 45 | # this code slice normalized_string is too slow (6s) but test_alignement_methods can pass 46 | for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False): 47 | if token in self.vocab: 48 | splits.append(normalized_string[start:end]) 49 | else: 50 | token_list = self.normalizers.normalize_str(token).split() 51 | for token in token_list: 52 | if token: 53 | end = start + len(token) 54 | splits.append(normalized_string[start:end]) 55 | start = end 56 | 57 | # this code test_alignement_methods can't pass but fast (300ms) 58 | # for token in self.jieba.cut(str(normalized_string), False): 59 | # if token in self.vocab: 60 | # splits.append(NormalizedString(token)) 61 | # else: 62 | # token_list = self.normalizers.normalize_str(token).split() 63 | # for token in token_list: 64 | # if token: 65 | # splits.append(NormalizedString(token)) 66 | 67 | return splits 68 | 69 | def pre_tokenize(self, pretok: PreTokenizedString): 70 | pretok.split(self.jieba_split) 71 | -------------------------------------------------------------------------------- /src/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert RoFormer checkpoint.""" 16 | 17 | import argparse 18 | 19 | import torch 20 | from transformers.utils import logging 21 | 22 | from roformer import RoFormerConfig, RoFormerForMaskedLM, RoFormerForCausalLM, load_tf_weights_in_roformer 23 | 24 | logging.set_verbosity_info() 25 | 26 | 27 | def convert_tf_checkpoint_to_pytorch( 28 | tf_checkpoint_path, bert_config_file, pytorch_dump_path, roformer_sim=False 29 | ): 30 | # Initialise PyTorch model 31 | config = RoFormerConfig.from_json_file(bert_config_file) 32 | print(f"Building PyTorch model from configuration: {config}") 33 | 34 | if roformer_sim: 35 | # 如果转换roformer-sim的话,需要使用RoFormerForCausalLM,这个带有pooler的权重 36 | config.is_decoder = True 37 | config.eos_token_id = 102 38 | config.pooler_activation = "linear" 39 | model = RoFormerForCausalLM(config) 40 | else: 41 | model = RoFormerForMaskedLM(config) 42 | 43 | # Load weights from tf checkpoint 44 | load_tf_weights_in_roformer(model, config, tf_checkpoint_path) 45 | 46 | # ignore 不保存roformer.encoder.embed_positions.weight 47 | _keys_to_ignore_on_save = ["roformer.encoder.embed_positions.weight"] 48 | state_dict = model.state_dict() 49 | for ignore_key in _keys_to_ignore_on_save: 50 | if ignore_key in state_dict.keys(): 51 | del state_dict[ignore_key] 52 | 53 | # Save pytorch-model 54 | print(f"Save PyTorch model to {pytorch_dump_path}") 55 | torch.save( 56 | state_dict, pytorch_dump_path, _use_new_zipfile_serialization=False 57 | ) 58 | 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser() 62 | # Required parameters 63 | parser.add_argument( 64 | "--tf_checkpoint_path", 65 | default=None, 66 | type=str, 67 | required=True, 68 | help="Path to the TensorFlow checkpoint path.", 69 | ) 70 | parser.add_argument( 71 | "--bert_config_file", 72 | default=None, 73 | type=str, 74 | required=True, 75 | help="The config json file corresponding to the pre-trained BERT model. \n" 76 | "This specifies the model architecture.", 77 | ) 78 | parser.add_argument( 79 | "--pytorch_dump_path", 80 | default=None, 81 | type=str, 82 | required=True, 83 | help="Path to the output PyTorch model.", 84 | ) 85 | parser.add_argument( 86 | "--roformer_sim", 87 | action="store_true", 88 | help="Whether or not roformer-sim.", 89 | ) 90 | args = parser.parse_args() 91 | convert_tf_checkpoint_to_pytorch( 92 | args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path, args.roformer_sim 93 | ) 94 | -------------------------------------------------------------------------------- /examples/clue/classification/accuracy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Accuracy metric.""" 15 | 16 | import datasets 17 | from sklearn.metrics import accuracy_score 18 | 19 | _DESCRIPTION = """ 20 | Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with: 21 | Accuracy = (TP + TN) / (TP + TN + FP + FN) 22 | TP: True positive 23 | TN: True negative 24 | FP: False positive 25 | FN: False negative 26 | """ 27 | 28 | _KWARGS_DESCRIPTION = """ 29 | Args: 30 | predictions: Predicted labels, as returned by a model. 31 | references: Ground truth labels. 32 | normalize: If False, return the number of correctly classified samples. 33 | Otherwise, return the fraction of correctly classified samples. 34 | sample_weight: Sample weights. 35 | Returns: 36 | accuracy: Accuracy score. 37 | Examples: 38 | >>> accuracy_metric = datasets.load_metric("accuracy") 39 | >>> results = accuracy_metric.compute(references=[0, 1], predictions=[0, 1]) 40 | >>> print(results) 41 | {'accuracy': 1.0} 42 | """ 43 | 44 | _CITATION = """\ 45 | @article{scikit-learn, 46 | title={Scikit-learn: Machine Learning in {P}ython}, 47 | author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. 48 | and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. 49 | and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and 50 | Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.}, 51 | journal={Journal of Machine Learning Research}, 52 | volume={12}, 53 | pages={2825--2830}, 54 | year={2011} 55 | } 56 | """ 57 | 58 | 59 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 60 | class Accuracy(datasets.Metric): 61 | def _info(self): 62 | return datasets.MetricInfo( 63 | description=_DESCRIPTION, 64 | citation=_CITATION, 65 | inputs_description=_KWARGS_DESCRIPTION, 66 | features=datasets.Features( 67 | { 68 | "predictions": datasets.Sequence(datasets.Value("int32")), 69 | "references": datasets.Sequence(datasets.Value("int32")), 70 | } 71 | if self.config_name == "multilabel" 72 | else { 73 | "predictions": datasets.Value("int32"), 74 | "references": datasets.Value("int32"), 75 | } 76 | ), 77 | reference_urls=[ 78 | "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html" 79 | ], 80 | ) 81 | 82 | def _compute(self, predictions, references, normalize=True, sample_weight=None): 83 | return { 84 | "accuracy": float( 85 | accuracy_score( 86 | references, 87 | predictions, 88 | normalize=normalize, 89 | sample_weight=sample_weight, 90 | ) 91 | ) 92 | } 93 | -------------------------------------------------------------------------------- /src/roformer/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 The HuggingFace Team. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | from typing import TYPE_CHECKING 19 | 20 | from transformers.file_utils import ( 21 | _LazyModule, 22 | is_tf_available, 23 | is_tokenizers_available, 24 | is_torch_available, 25 | ) 26 | 27 | 28 | _import_structure = { 29 | "configuration_roformer": [ 30 | "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", 31 | "RoFormerConfig", 32 | ], 33 | "tokenization_roformer": ["RoFormerTokenizer"], 34 | } 35 | 36 | if is_tokenizers_available(): 37 | _import_structure["tokenization_roformer_fast"] = ["RoFormerTokenizerFast"] 38 | 39 | if is_torch_available(): 40 | _import_structure["modeling_roformer"] = [ 41 | "ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", 42 | "RoFormerForCausalLM", 43 | "RoFormerForMaskedLM", 44 | "RoFormerForMultipleChoice", 45 | "RoFormerForQuestionAnswering", 46 | "RoFormerForSequenceClassification", 47 | "RoFormerForTokenClassification", 48 | "RoFormerLayer", 49 | "RoFormerModel", 50 | "RoFormerPreTrainedModel", 51 | "load_tf_weights_in_roformer", 52 | ] 53 | 54 | 55 | if is_tf_available(): 56 | _import_structure["modeling_tf_roformer"] = [ 57 | "TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", 58 | "TFRoFormerForCausalLM", 59 | "TFRoFormerForMaskedLM", 60 | "TFRoFormerForMultipleChoice", 61 | "TFRoFormerForQuestionAnswering", 62 | "TFRoFormerForSequenceClassification", 63 | "TFRoFormerForTokenClassification", 64 | "TFRoFormerLayer", 65 | "TFRoFormerModel", 66 | "TFRoFormerPreTrainedModel", 67 | ] 68 | 69 | 70 | if TYPE_CHECKING: 71 | from .configuration_roformer import ( 72 | ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, 73 | RoFormerConfig, 74 | ) 75 | from .tokenization_roformer import RoFormerTokenizer 76 | 77 | if is_tokenizers_available(): 78 | from .tokenization_roformer_fast import RoFormerTokenizerFast 79 | 80 | if is_torch_available(): 81 | from .modeling_roformer import ( 82 | ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, 83 | RoFormerForCausalLM, 84 | RoFormerForMaskedLM, 85 | RoFormerForMultipleChoice, 86 | RoFormerForQuestionAnswering, 87 | RoFormerForSequenceClassification, 88 | RoFormerForTokenClassification, 89 | RoFormerLayer, 90 | RoFormerModel, 91 | RoFormerPreTrainedModel, 92 | load_tf_weights_in_roformer, 93 | ) 94 | 95 | if is_tf_available(): 96 | from .modeling_tf_roformer import ( 97 | TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, 98 | TFRoFormerForCausalLM, 99 | TFRoFormerForMaskedLM, 100 | TFRoFormerForMultipleChoice, 101 | TFRoFormerForQuestionAnswering, 102 | TFRoFormerForSequenceClassification, 103 | TFRoFormerForTokenClassification, 104 | TFRoFormerLayer, 105 | TFRoFormerModel, 106 | TFRoFormerPreTrainedModel, 107 | ) 108 | 109 | 110 | else: 111 | import sys 112 | 113 | sys.modules[__name__] = _LazyModule( 114 | __name__, globals()["__file__"], _import_structure 115 | ) 116 | -------------------------------------------------------------------------------- /examples/dummpy/task_text_classification_chnsenti.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | 4 | from torchblocks.callback import TrainLogger 5 | from torchblocks.metrics import Accuracy 6 | from torchblocks.processor import InputExample, TextClassifierProcessor 7 | from torchblocks.trainer import TextClassifierTrainer 8 | from torchblocks.utils import ( 9 | build_argparse, 10 | dict_to_text, 11 | get_checkpoints, 12 | prepare_device, 13 | seed_everything, 14 | ) 15 | from transformers import WEIGHTS_NAME 16 | 17 | from roformer import ( 
18 | RoFormerConfig, 19 | RoFormerForSequenceClassification, 20 | RoFormerTokenizer, 21 | ) 22 | 23 | MODEL_CLASSES = { 24 | "roformer": (RoFormerConfig, RoFormerForSequenceClassification, RoFormerTokenizer) 25 | } 26 | 27 | 28 | class ChnSentiProcessor(TextClassifierProcessor): 29 | def get_labels(self): 30 | """See base class.""" 31 | return ["0", "1"] 32 | 33 | def read_data(self, input_file): 34 | """Reads a json list file.""" 35 | with open(input_file, "r", encoding="utf-8-sig") as f: 36 | reader = csv.reader(f, delimiter="\t", quotechar=None) 37 | lines = [] 38 | for line in reader: 39 | lines.append(line) 40 | return lines 41 | 42 | def create_examples(self, lines, set_type): 43 | """Creates examples for the training and dev sets.""" 44 | examples = [] 45 | for (i, line) in enumerate(lines): 46 | if i == 0: 47 | continue 48 | guid = "%s-%s" % (set_type, i) 49 | text_a = line[1] 50 | text_b = None 51 | label = line[0] 52 | examples.append( 53 | InputExample(guid=guid, texts=[text_a, text_b], label=label) 54 | ) 55 | return examples 56 | 57 | 58 | def main(): 59 | args = build_argparse().parse_args() 60 | if args.model_name is None: 61 | args.model_name = args.model_path.split("/")[-1] 62 | args.output_dir = args.output_dir + "{}".format(args.model_name) 63 | os.makedirs(args.output_dir, exist_ok=True) 64 | 65 | # output dir 66 | prefix = "_".join([args.model_name, args.task_name]) 67 | logger = TrainLogger(log_dir=args.output_dir, prefix=prefix) 68 | 69 | # device 70 | logger.info("initializing device") 71 | args.device, args.n_gpu = prepare_device(args.gpu, args.local_rank) 72 | seed_everything(args.seed) 73 | args.model_type = args.model_type.lower() 74 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 75 | 76 | # data processor 77 | logger.info("initializing data processor") 78 | tokenizer = tokenizer_class.from_pretrained( 79 | args.model_path, do_lower_case=args.do_lower_case 80 | ) 81 | processor = ChnSentiProcessor( 82 | data_dir=args.data_dir, tokenizer=tokenizer, prefix=prefix 83 | ) 84 | label_list = processor.get_labels() 85 | num_labels = len(label_list) 86 | args.num_labels = num_labels 87 | 88 | # model 89 | logger.info("initializing model and config") 90 | config = config_class.from_pretrained( 91 | args.model_path, 92 | num_labels=num_labels, 93 | cache_dir=args.cache_dir if args.cache_dir else None, 94 | ) 95 | model = model_class.from_pretrained(args.model_path, config=config) 96 | model.to(args.device) 97 | 98 | # trainer 99 | logger.info("initializing traniner") 100 | trainer = TextClassifierTrainer( 101 | logger=logger, 102 | args=args, 103 | collate_fn=processor.collate_fn, 104 | input_keys=processor.get_input_keys(), 105 | metrics=[Accuracy()], 106 | ) 107 | # do train 108 | if args.do_train: 109 | train_dataset = processor.create_dataset( 110 | args.train_max_seq_length, "train.tsv", "train" 111 | ) 112 | eval_dataset = processor.create_dataset( 113 | args.eval_max_seq_length, "dev.tsv", "dev" 114 | ) 115 | trainer.train(model, train_dataset=train_dataset, eval_dataset=eval_dataset) 116 | # do eval 117 | if args.do_eval and args.local_rank in [-1, 0]: 118 | results = {} 119 | eval_dataset = processor.create_dataset( 120 | args.eval_max_seq_length, "test.tsv", "test" 121 | ) 122 | checkpoints = [args.output_dir] 123 | if args.eval_all_checkpoints or args.checkpoint_number > 0: 124 | checkpoints = get_checkpoints( 125 | args.output_dir, args.checkpoint_number, WEIGHTS_NAME 126 | ) 127 | logger.info("Evaluate the following checkpoints: 
%s", checkpoints) 128 | for checkpoint in checkpoints: 129 | global_step = checkpoint.split("/")[-1].split("-")[-1] 130 | model = model_class.from_pretrained(checkpoint, config=config) 131 | model.to(args.device) 132 | trainer.evaluate( 133 | model, eval_dataset, save_preds=True, prefix=str(global_step) 134 | ) 135 | if global_step: 136 | result = { 137 | "{}_{}".format(global_step, k): v 138 | for k, v in trainer.records["result"].items() 139 | } 140 | results.update(result) 141 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 142 | dict_to_text(output_eval_file, results) 143 | 144 | 145 | if __name__ == "__main__": 146 | main() 147 | -------------------------------------------------------------------------------- /examples/clue/classification/run_clue_predict_no_trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import torch 6 | from accelerate import Accelerator 7 | from datasets import load_dataset 8 | from torch.utils.data import DataLoader 9 | from tqdm.auto import tqdm 10 | from transformers import AutoTokenizer, DataCollatorWithPadding 11 | 12 | from roformer import RoFormerForSequenceClassification 13 | 14 | task_to_keys = { 15 | "iflytek": ("sentence", None), 16 | "tnews": ("sentence", None), 17 | "afqmc": ("sentence1", "sentence2"), 18 | "cmnli": ("sentence1", "sentence2"), 19 | "ocnli": ("sentence1", "sentence2"), 20 | "cluewsc2020": ("text", None), 21 | "csl": ("keyword", "abst"), 22 | } 23 | # 11 24 | task_to_outputfile11 = { 25 | "iflytek": "iflytek_predict.json", 26 | "tnews": "tnews11_predict.json", 27 | "afqmc": "afqmc_predict.json", 28 | "cmnli": "cmnli_predict.json", 29 | "ocnli": "ocnli_50k_predict.json", 30 | "cluewsc2020": "cluewsc11_predict.json", 31 | "csl": "csl_predict.json", 32 | } 33 | # 1.0 34 | task_to_outputfile10 = { 35 | "tnews": "tnews10_predict.json", 36 | "cluewsc2020": "cluewsc10_predict.json", 37 | } 38 | 39 | 40 | def parse_args(): 41 | parser = argparse.ArgumentParser( 42 | description="Predict a transformers model on a text classification task" 43 | ) 44 | parser.add_argument( 45 | "--task_name", 46 | type=str, 47 | default=None, 48 | help="The name of the CLUE task to train on.", 49 | choices=list(task_to_keys.keys()), 50 | ) 51 | parser.add_argument( 52 | "--max_length", 53 | type=int, 54 | default=128, 55 | help=( 56 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 57 | " sequences shorter will be padded if `--pad_to_max_lengh` is passed." 58 | ), 59 | ) 60 | parser.add_argument( 61 | "--pad_to_max_length", 62 | action="store_true", 63 | help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", 64 | ) 65 | parser.add_argument( 66 | "--model_name_or_path", 67 | type=str, 68 | help="Path to pretrained model or model identifier from huggingface.co/models.", 69 | required=True, 70 | ) 71 | parser.add_argument( 72 | "--per_device_eval_batch_size", 73 | type=int, 74 | default=32, 75 | help="Batch size (per device) for the evaluation dataloader.", 76 | ) 77 | parser.add_argument( 78 | "--version", 79 | type=int, 80 | default=1.0, 81 | help="Batch size (per device) for the evaluation dataloader.", 82 | ) 83 | args = parser.parse_args() 84 | 85 | # Sanity checks 86 | if args.task_name is None: 87 | raise ValueError("task_name should not be none.") 88 | 89 | return args 90 | 91 | 92 | def predict(): 93 | args = parse_args() 94 | accelerator = Accelerator() 95 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 96 | model = RoFormerForSequenceClassification.from_pretrained(args.model_name_or_path) 97 | model.eval() 98 | 99 | all_eval_datasets = [] 100 | all_file_names = [] 101 | if args.task_name in task_to_outputfile10.keys(): 102 | all_eval_datasets.append( 103 | load_dataset( 104 | "clue_10.py", args.task_name, cache_dir="./clue_caches_10", split="test" 105 | ) 106 | ) 107 | all_file_names.append(task_to_outputfile10[args.task_name]) 108 | 109 | if args.task_name in task_to_outputfile11.keys(): 110 | all_eval_datasets.append( 111 | load_dataset( 112 | "clue_11.py", args.task_name, cache_dir="./clue_caches", split="test" 113 | ) 114 | ) 115 | all_file_names.append(task_to_outputfile11[args.task_name]) 116 | 117 | for raw_test_dataset, file in zip(all_eval_datasets, all_file_names): 118 | os.makedirs("results", exist_ok=True) 119 | out_file = f"results/{file}" 120 | sentence1_key, sentence2_key = task_to_keys[args.task_name] 121 | int2str = raw_test_dataset.features["label"].int2str 122 | padding = "max_length" if args.pad_to_max_length else False 123 | 124 | def preprocess_test_function(examples): 125 | # Tokenize the texts 126 | if sentence1_key == "keyword": 127 | k1 = [";".join(l) for l in examples[sentence1_key]] 128 | else: 129 | k1 = examples[sentence1_key] 130 | texts = (k1,) if sentence2_key is None else (k1, examples[sentence2_key]) 131 | result = tokenizer( 132 | *texts, 133 | padding=padding, 134 | max_length=args.max_length, 135 | truncation=True, 136 | return_token_type_ids=False, 137 | ) 138 | return result 139 | 140 | with accelerator.main_process_first(): 141 | processed_test_dataset = raw_test_dataset.map( 142 | preprocess_test_function, 143 | batched=True, 144 | remove_columns=raw_test_dataset.column_names, 145 | desc="Running tokenizer on test dataset", 146 | ) 147 | data_collator = DataCollatorWithPadding( 148 | tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None) 149 | ) 150 | test_dataloader = DataLoader( 151 | processed_test_dataset, 152 | collate_fn=data_collator, 153 | batch_size=args.per_device_eval_batch_size, 154 | ) 155 | model, test_dataloader = accelerator.prepare(model, test_dataloader) 156 | 157 | samples_seen = 0 158 | all_predictions = [] 159 | 160 | with torch.no_grad(): 161 | for step, batch in enumerate(tqdm(test_dataloader)): 162 | outputs = model(**batch) 163 | predictions = outputs.logits.argmax(dim=-1).cpu().numpy() 164 | predictions = accelerator.gather(predictions) 165 | # If we are in a multiprocess environment, the last batch has duplicates 166 | if accelerator.num_processes > 1: 167 | if step == len(test_dataloader): 168 | predictions = predictions[ 169 | : 
len(test_dataloader.dataset) - samples_seen 170 | ] 171 | else: 172 | samples_seen += predictions.shape[0] 173 | all_predictions.extend(int2str(predictions)) 174 | 175 | with open(out_file, "w") as fw: 176 | for idx, pred in zip(raw_test_dataset["idx"], all_predictions): 177 | l = json.dumps({"id": str(idx), "label": pred}) 178 | fw.write(l + "\n") 179 | fw.close() 180 | 181 | 182 | if __name__ == "__main__": 183 | predict() 184 | -------------------------------------------------------------------------------- /src/roformer/configuration_roformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ RoFormer model configuration """ 16 | 17 | from transformers.configuration_utils import PretrainedConfig 18 | from transformers.utils import logging 19 | 20 | logger = logging.get_logger(__name__) 21 | 22 | ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { 23 | "junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json", 24 | "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json", 25 | "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json", 26 | "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json", 27 | "junnyu/roformer_chinese_sim_char_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_small/resolve/main/config.json", 28 | "junnyu/roformer_chinese_sim_char_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_base/resolve/main/config.json", 29 | "junnyu/roformer_chinese_sim_char_ft_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_base/resolve/main/config.json", 30 | "junnyu/roformer_chinese_sim_char_ft_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small/resolve/main/config.json", 31 | "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json", 32 | "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json", 33 | "junnyu/roformer_base_wwm_cluecorpussmall": "https://huggingface.co/junnyu/roformer_base_wwm_cluecorpussmall/resolve/main/config.json", 34 | "junnyu/roformer_v2_chinese_char_small": "https://huggingface.co/junnyu/roformer_v2_chinese_char_small/resolve/main/config.json", 35 | "junnyu/roformer_v2_chinese_char_base": "https://huggingface.co/junnyu/roformer_v2_chinese_char_base/resolve/main/config.json", 36 | "junnyu/roformer_v2_chinese_char_large": "https://huggingface.co/junnyu/roformer_v2_chinese_char_large/resolve/main/config.json", 37 | # See all RoFormer models at https://huggingface.co/models?filter=roformer 38 | } 39 | 40 | 41 | class 
RoFormerConfig(PretrainedConfig): 42 | r""" 43 | This is the configuration class to store the configuration of a :class:`~transformers.RoFormerModel`. It is used to 44 | instantiate an RoFormer model according to the specified arguments, defining the model architecture. Instantiating 45 | a configuration with the defaults will yield a similar configuration to that of the RoFormer 46 | `junnyu/roformer_chinese_base `__ architecture. 47 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model 48 | outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. 49 | Args: 50 | vocab_size (:obj:`int`, `optional`, defaults to 50000): 51 | Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by 52 | the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or 53 | :class:`~transformers.TFRoFormerModel`. 54 | embedding_size (:obj:`int`, `optional`, defaults to None): 55 | Dimensionality of the encoder layers and the pooler layer. Defaults to the :obj:`hidden_size` if not 56 | provided. 57 | hidden_size (:obj:`int`, `optional`, defaults to 768): 58 | Dimension of the encoder layers and the pooler layer. 59 | num_hidden_layers (:obj:`int`, `optional`, defaults to 12): 60 | Number of hidden layers in the Transformer encoder. 61 | num_attention_heads (:obj:`int`, `optional`, defaults to 12): 62 | Number of attention heads for each attention layer in the Transformer encoder. 63 | intermediate_size (:obj:`int`, `optional`, defaults to 3072): 64 | Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 65 | hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): 66 | The non-linear activation function (function or string) in the encoder and pooler. If string, 67 | :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. 68 | hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): 69 | The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. 70 | attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): 71 | The dropout ratio for the attention probabilities. 72 | max_position_embeddings (:obj:`int`, `optional`, defaults to 1536): 73 | The maximum sequence length that this model might ever be used with. Typically set this to something large 74 | just in case (e.g., 512 or 1024 or 1536). 75 | type_vocab_size (:obj:`int`, `optional`, defaults to 2): 76 | The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.RoFormerModel` 77 | or :class:`~transformers.TFRoFormerModel`. 78 | initializer_range (:obj:`float`, `optional`, defaults to 0.02): 79 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 80 | layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): 81 | The epsilon used by the layer normalization layers. 82 | use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): 83 | Whether or not the model should return the last key/values attentions (not used by all models). Only 84 | relevant if ``config.is_decoder=True``. 85 | rotary_value (:obj:`bool`, `optional`, defaults to :obj:`False`): 86 | Whether or not apply rotary position embeddings on value layer. 
87 | Example:: 88 | >>> from transformers import RoFormerModel, RoFormerConfig 89 | >>> # Initializing a RoFormer junnyu/roformer_chinese_base style configuration 90 | >>> configuration = RoFormerConfig() 91 | >>> # Initializing a model from the junnyu/roformer_chinese_base style configuration 92 | >>> model = RoFormerModel(configuration) 93 | >>> # Accessing the model configuration 94 | >>> configuration = model.config 95 | """ 96 | model_type = "roformer" 97 | 98 | def __init__( 99 | self, 100 | vocab_size=50000, 101 | embedding_size=None, 102 | hidden_size=768, 103 | num_hidden_layers=12, 104 | num_attention_heads=12, 105 | intermediate_size=3072, 106 | hidden_act="gelu", 107 | hidden_dropout_prob=0.1, 108 | attention_probs_dropout_prob=0.1, 109 | max_position_embeddings=1536, 110 | type_vocab_size=2, 111 | initializer_range=0.02, 112 | layer_norm_eps=1e-12, 113 | pad_token_id=0, 114 | rotary_value=False, 115 | use_cache=True, 116 | use_bias=True, 117 | norm_type="layer_norm", 118 | pooler_activation="tanh", 119 | **kwargs 120 | ): 121 | super().__init__(pad_token_id=pad_token_id, **kwargs) 122 | 123 | self.vocab_size = vocab_size 124 | self.embedding_size = hidden_size if embedding_size is None else embedding_size 125 | self.hidden_size = hidden_size 126 | self.num_hidden_layers = num_hidden_layers 127 | self.num_attention_heads = num_attention_heads 128 | self.hidden_act = hidden_act 129 | self.intermediate_size = intermediate_size 130 | self.hidden_dropout_prob = hidden_dropout_prob 131 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 132 | self.max_position_embeddings = max_position_embeddings 133 | self.type_vocab_size = type_vocab_size 134 | self.initializer_range = initializer_range 135 | self.layer_norm_eps = layer_norm_eps 136 | self.rotary_value = rotary_value 137 | self.use_cache = use_cache 138 | self.use_bias = use_bias 139 | self.norm_type = norm_type 140 | self.pooler_activation = pooler_activation 141 | -------------------------------------------------------------------------------- /src/roformer/tokenization_roformer_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes for RoFormer.""" 16 | import json 17 | from typing import List, Optional, Tuple 18 | 19 | from tokenizers import normalizers 20 | from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer 21 | 22 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast 23 | from transformers.utils import logging 24 | from .tokenization_roformer import RoFormerTokenizer 25 | from .tokenization_utils import JiebaPreTokenizer 26 | 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 31 | 32 | PRETRAINED_VOCAB_FILES_MAP = { 33 | "vocab_file": { 34 | "junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt", 35 | "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt", 36 | "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt", 37 | "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt", 38 | "junnyu/roformer_chinese_sim_char_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_small/resolve/main/vocab.txt", 39 | "junnyu/roformer_chinese_sim_char_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_base/resolve/main/vocab.txt", 40 | "junnyu/roformer_chinese_sim_char_ft_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small/resolve/main/vocab.txt", 41 | "junnyu/roformer_chinese_sim_char_ft_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_base/resolve/main/vocab.txt", 42 | "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt", 43 | "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt", 44 | "junnyu/roformer_base_wwm_cluecorpussmall": "https://huggingface.co/junnyu/roformer_base_wwm_cluecorpussmall/resolve/main/vocab.txt", 45 | "junnyu/roformer_v2_chinese_char_small": "https://huggingface.co/junnyu/roformer_v2_chinese_char_small/resolve/main/vocab.txt", 46 | "junnyu/roformer_v2_chinese_char_base": "https://huggingface.co/junnyu/roformer_v2_chinese_char_base/resolve/main/vocab.txt", 47 | "junnyu/roformer_v2_chinese_char_large": "https://huggingface.co/junnyu/roformer_v2_chinese_char_large/resolve/main/vocab.txt", 48 | # See all RoFormer models at https://huggingface.co/models?filter=roformer 49 | } 50 | } 51 | 52 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 53 | "junnyu/roformer_chinese_small": 1536, 54 | "junnyu/roformer_chinese_base": 1536, 55 | "junnyu/roformer_chinese_char_small": 512, 56 | "junnyu/roformer_chinese_char_base": 512, 57 | "junnyu/roformer_chinese_sim_char_small": 512, 58 | "junnyu/roformer_chinese_sim_char_base": 512, 59 | "junnyu/roformer_chinese_sim_char_ft_small": 512, 60 | "junnyu/roformer_chinese_sim_char_ft_base": 512, 61 | "junnyu/roformer_small_discriminator": 128, 62 | "junnyu/roformer_small_generator": 128, 63 | "junnyu/roformer_base_wwm_cluecorpussmall": 512, 64 | "junnyu/roformer_v2_chinese_char_small": 512, 65 | "junnyu/roformer_v2_chinese_char_base": 512, 66 | "junnyu/roformer_v2_chinese_char_large": 512, 67 | } 68 | 69 | 70 | PRETRAINED_INIT_CONFIGURATION = { 71 | "junnyu/roformer_chinese_small": {"do_lower_case": True}, 72 | "junnyu/roformer_chinese_base": {"do_lower_case": True}, 73 | "junnyu/roformer_chinese_char_small": {"do_lower_case": True}, 74 | 
"junnyu/roformer_chinese_char_base": {"do_lower_case": True}, 75 | "junnyu/roformer_chinese_sim_char_small": {"do_lower_case": True}, 76 | "junnyu/roformer_chinese_sim_char_base": {"do_lower_case": True}, 77 | "junnyu/roformer_chinese_sim_char_ft_small": {"do_lower_case": True}, 78 | "junnyu/roformer_chinese_sim_char_ft_base": {"do_lower_case": True}, 79 | "junnyu/roformer_small_discriminator": {"do_lower_case": True}, 80 | "junnyu/roformer_small_generator": {"do_lower_case": True}, 81 | "junnyu/roformer_base_wwm_cluecorpussmall": {"do_lower_case": True}, 82 | "junnyu/roformer_v2_chinese_char_small": {"do_lower_case": True}, 83 | "junnyu/roformer_v2_chinese_char_base": {"do_lower_case": True}, 84 | "junnyu/roformer_v2_chinese_char_large": {"do_lower_case": True}, 85 | } 86 | 87 | 88 | class RoFormerTokenizerFast(PreTrainedTokenizerFast): 89 | r""" 90 | Construct a "fast" RoFormer tokenizer (backed by HuggingFace's `tokenizers` library). 91 | :class:`~transformers.RoFormerTokenizerFast` is almost identical to :class:`~transformers.BertTokenizerFast` and 92 | runs end-to-end tokenization: punctuation splitting and wordpiece. There are some difference between them when 93 | tokenizing Chinese. 94 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main 95 | methods. Users should refer to this superclass for more information regarding those methods. 96 | Example:: 97 | >>> from transformers import RoFormerTokenizerFast 98 | >>> tokenizer = RoFormerTokenizerFast.from_pretrained('junnyu/roformer_chinese_base') 99 | >>> tokenizer.tokenize("今天天气非常好。") 100 | # ['今', '天', '天', '气', '非常', '好', '。'] 101 | """ 102 | 103 | vocab_files_names = VOCAB_FILES_NAMES 104 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 105 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 106 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 107 | slow_tokenizer_class = RoFormerTokenizer 108 | 109 | def __init__( 110 | self, 111 | vocab_file=None, 112 | tokenizer_file=None, 113 | do_lower_case=True, 114 | unk_token="[UNK]", 115 | sep_token="[SEP]", 116 | pad_token="[PAD]", 117 | cls_token="[CLS]", 118 | mask_token="[MASK]", 119 | tokenize_chinese_chars=True, 120 | strip_accents=None, 121 | **kwargs, 122 | ): 123 | super().__init__( 124 | vocab_file, 125 | tokenizer_file=tokenizer_file, 126 | do_lower_case=do_lower_case, 127 | unk_token=unk_token, 128 | sep_token=sep_token, 129 | pad_token=pad_token, 130 | cls_token=cls_token, 131 | mask_token=mask_token, 132 | tokenize_chinese_chars=tokenize_chinese_chars, 133 | strip_accents=strip_accents, 134 | **kwargs, 135 | ) 136 | 137 | pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) 138 | if ( 139 | pre_tok_state.get("lowercase", do_lower_case) != do_lower_case 140 | or pre_tok_state.get("strip_accents", strip_accents) != strip_accents 141 | ): 142 | pre_tok_class = getattr(normalizers, pre_tok_state.pop("type")) 143 | pre_tok_state["lowercase"] = do_lower_case 144 | pre_tok_state["strip_accents"] = strip_accents 145 | self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state) 146 | 147 | self.do_lower_case = do_lower_case 148 | 149 | def __getstate__(self): 150 | state = self.__dict__.copy() 151 | state["_tokenizer"].pre_tokenizer = BertPreTokenizer() 152 | return state 153 | 154 | def __setstate__(self, d): 155 | self.__dict__ = d 156 | vocab = self.__dict__["_tokenizer"].get_vocab() 157 | self.__dict__["_tokenizer"].pre_tokenizer = PreTokenizer.custom( 158 | 
JiebaPreTokenizer(vocab) 159 | ) 160 | 161 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 162 | """ 163 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 164 | adding special tokens. A RoFormer sequence has the following format: 165 | - single sequence: ``[CLS] X [SEP]`` 166 | - pair of sequences: ``[CLS] A [SEP] B [SEP]`` 167 | Args: 168 | token_ids_0 (:obj:`List[int]`): 169 | List of IDs to which the special tokens will be added. 170 | token_ids_1 (:obj:`List[int]`, `optional`): 171 | Optional second list of IDs for sequence pairs. 172 | Returns: 173 | :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 174 | """ 175 | output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 176 | 177 | if token_ids_1: 178 | output += token_ids_1 + [self.sep_token_id] 179 | 180 | return output 181 | 182 | def create_token_type_ids_from_sequences( 183 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 184 | ) -> List[int]: 185 | """ 186 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. A RoFormer 187 | sequence pair mask has the following format: 188 | :: 189 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 190 | | first sequence | second sequence | 191 | If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). 192 | Args: 193 | token_ids_0 (:obj:`List[int]`): 194 | List of IDs. 195 | token_ids_1 (:obj:`List[int]`, `optional`): 196 | Optional second list of IDs for sequence pairs. 197 | Returns: 198 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given 199 | sequence(s). 200 | """ 201 | sep = [self.sep_token_id] 202 | cls = [self.cls_token_id] 203 | if token_ids_1 is None: 204 | return len(cls + token_ids_0 + sep) * [0] 205 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 206 | 207 | def save_vocabulary( 208 | self, save_directory: str, filename_prefix: Optional[str] = None 209 | ) -> Tuple[str]: 210 | files = self._tokenizer.model.save(save_directory, name=filename_prefix) 211 | return tuple(files) 212 | 213 | def save_pretrained( 214 | self, 215 | save_directory, 216 | legacy_format=None, 217 | filename_prefix=None, 218 | push_to_hub=False, 219 | **kwargs, 220 | ): 221 | self.backend_tokenizer.pre_tokenizer = BertPreTokenizer() 222 | return super().save_pretrained( 223 | save_directory, legacy_format, filename_prefix, push_to_hub, **kwargs 224 | ) 225 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch RoFormer & RoFormer-V2 2 | RoFormer模型和RoFormer-V2模型 3 | 4 | ## 更新 5 | - **2022/05/18** 6 | 7 | 添加paddle版本[RoFormerV2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/paddlenlp/transformers/roformerv2)在分类任务上的训练结果。 8 | 9 | - **2022/05/11** 10 | 11 | 感谢苏神提醒,添加了一个注释,其中RoFormerV2*表示未经多任务学习的RoFormerV2模型。 12 | 13 | - **2022/05/01** 14 | 15 | 添加`clue分类任务`的代码和dev集结果,代码在`examples/clue`文件夹,缺少啥依赖安装啥,比如需要这个`pip install -U accelerate`。 16 | - **2022/04/30** 17 | 18 | 有个细节需要注意一下,苏神在微调时无论输入是`text`还是`text pair`类型时,`token_type_id`都置为了0。 19 | 20 | 如果想要使用与苏神保持一致,那么可以在`tokenizer`时候设置`return_token_type_ids=False`,这样模型会在内部处理。 21 | 22 | 否则对于`text pair`类型时,会返回与`0,1`两种类型的`token_type_id` 23 | - **2022/04/02** 24 | 25 | (1)修改RoFormerForCausalLM,支持`roformer-sim`并提供相关的例子,请见`examples/test_sim.py`。 26 | 27 | (2)修改`apply_rotary`实现方式,看起来更简单。 28 | ```python 29 | def apply_rotary(x, sinusoidal_pos=None): 30 | if sinusoidal_pos is None: 31 | return x 32 | sin, cos = sinusoidal_pos 33 | # x.shape [batch, seq_len, 2] 34 | x1, x2 = x[..., 0::2], x[..., 1::2] 35 | # [cos_nθ, -sin_nθ] [x1] 36 | # [sin_nθ, cos_nθ] [x2] 37 | # => [x1 * cos_nθ - x2 * sin_nθ, x1 * sin_nθ + x2 * cos_nθ] 38 | # 苏神的rotary,使用了下面的计算方法。 39 | # return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2, -1) 40 | # 考虑到矩阵乘法torch.einsum("bhmd,bhnd->bhmn", q, k),因此可以直接在最后一个维度拼接(无需奇偶交错) 41 | return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1) 42 | ``` 43 | - **2022/03/21** 添加`roformer-v2`的权重, 注:必须使用本仓库的代码,不能使用transformers仓库的代码!!! 
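A minimal sketch of the 2022/04/30 tip above (the checkpoint name is only an example, reused from the MLM tests later in this README): for `text pair` inputs, pass `return_token_type_ids=False` so that no `0/1` segment ids are returned and the model handles the missing `token_type_ids` internally, matching the fine-tuning setup described in that note.

```python
from transformers import BertTokenizer

# Example checkpoint; any RoFormer / RoFormerV2 vocab works the same way.
tokenizer = BertTokenizer.from_pretrained("junnyu/roformer_v2_chinese_char_base")

inputs = tokenizer(
    "广州和深圳哪个好?",          # text
    "深圳和广州哪个好?",          # text pair
    return_token_type_ids=False,  # do not return 0/1 segment ids; the model falls back to all zeros
    return_tensors="pt",
)
print(inputs.keys())  # only input_ids and attention_mask
```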
44 | 45 | 46 | 47 | ## 安装 48 | 49 | ```bash 50 | # v2版本 51 | pip install roformer>=0.4.3 52 | # v1版本(代码已经加入到huggingface仓库,请使用新版本的transformers) 53 | pip install -U transformers 54 | ``` 55 | 56 | 57 | 58 | ## 评测对比 59 | ### CLUE-dev榜单分类任务结果,base+large版本。 60 | 61 | | | iflytek | tnews | afqmc | cmnli | ocnli | wsc | csl | avg | 62 | | :-----: | :-----: | :---: | :---: | :---: | :---: | :---: | :---: | ----- | 63 | | BERT | 60.06 | 56.80 | 72.41 | 79.56 | 73.93 | 78.62 | 83.93 | 72.19 | 64 | | RoBERTa | 60.64 | 58.06 | 74.05 | 81.24 | 76.00 | 87.50 | 84.50 | 74.57 | 65 | | RoFormer | 60.91 | 57.54 | 73.52 | 80.92 | 76.07 | 86.84 | 84.63 | 74.35 | 66 | | RoFormerV2* | 60.87 | 56.54 | 72.75 | 80.34 | 75.36 | 80.92 | 84.67 | 73.06 | 67 | | GAU-α | 61.41 | 57.76 | 74.17 | 81.82 | 75.86 | 79.93 | 85.67 | 73.8 | 68 | | RoFormer-pytorch(本仓库代码) | 60.60 | 57.51 | 74.44 | 80.79 | 75.67 | 86.84 | 84.77 | 74.37 | 69 | | RoFormerV2-pytorch(本仓库代码) | 62.87 | 59.03 | 76.20 | 80.85 | 79.73 | 87.82 | **91.87** | 76.91 | 70 | | GAU-α-pytorch(Adafactor) | 61.18 | 57.52 | 73.42 | 80.91 | 75.69 | 80.59 | 85.5 | 73.54 | 71 | | GAU-α-pytorch(AdamW wd0.01 warmup0.1) | 60.68 | 57.95 | 73.08 | 81.02 | 75.36 | 81.25 | 83.93 | 73.32 | 72 | | RoFormerV2-large-pytorch(本仓库代码) | 61.75 | 59.21 | 76.14 | 82.35 | 81.73 | 91.45 | 91.5 | 77.73 | 73 | | Chinesebert-large-pytorch | 61.25 | 58.67 | 74.70 | 82.65 | 79.63 | 87.83 | 84.97 | 75.67 | 74 | | RoFormerV2-base-paddle | 63.76 | 59.53 | 77.06 | 81.58 | 81.56 | 87.83 | 86.73 | 76.87 | 75 | | RoFormerV2-large-paddle | **64.02** | **60.08** | **77.92** | **82.87** | **83.9** | **92.43** | 86.87 | **78.30** | 76 | 77 | ### CLUE-1.0-test榜单分类任务结果,base+large版本。 78 | 79 | | | iflytek | tnews | afqmc | cmnli | ocnli | wsc | csl | avg | 80 | | :-----: | :-----: | :---: | :---: | :---: | :---: | :---: | :---: | ----- | 81 | | RoFormer-pytorch(本仓库代码) | 59.54 | 57.34 | 74.46 | 80.23 | 73.67 | 80.69 | 84.57 | 72.93 | 82 | | RoFormerV2-pytorch(本仓库代码) | 63.15 | 58.24 | 75.42 | 80.59 | 74.17 | 83.79 | 83.73 | 74.16 | 83 | | GAU-α-pytorch(Adafactor) | 61.38 | 57.08 | 74.05 | 80.37 | 73.53 | 74.83 | **85.6** | 72.41 | 84 | | GAU-α-pytorch(AdamW wd0.01 warmup0.1) | 60.54 | 57.67 | 72.44 | 80.32 | 72.97 | 76.55 | 84.13 | 72.09 | 85 | | RoFormerV2-large-pytorch(本仓库代码) | 61.85 | 59.13 | 76.38 | 80.97 | 76.23 | **85.86** | 84.33 | 74.96 | 86 | | Chinesebert-large-pytorch | 61.54 | 58.57 | 74.8 | 81.94 | **76.93** | 79.66 | 85.1 | 74.08 | 87 | | RoFormerV2-large-paddle | **64.23** | **59.99** | **76.85** | **81.97** | 76.57 | 84.48 | 83.37 | **75.35** | 88 | 89 | ### 注: 90 | - 其中RoFormerV2*表示的是未进行多任务学习的RoFormerV2模型,该模型苏神并未开源,感谢苏神的提醒。 91 | - 其中不带有pytorch后缀结果都是从[GAU-alpha](https://github.com/ZhuiyiTechnology/GAU-alpha)仓库复制过来的。 92 | - 其中带有pytorch后缀的结果都是自己训练得出的。 93 | - 苏神代码中拿了cls标签后直接进行了分类,而本仓库使用了如下的分类头,多了2个dropout,1个dense,1个relu激活。 94 | - paddle版本的代码进行了grid search! 95 | 96 | ```python 97 | class RoFormerClassificationHead(nn.Module): 98 | def __init__(self, config): 99 | super().__init__() 100 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 101 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 102 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 103 | 104 | self.config = config 105 | 106 | def forward(self, features, **kwargs): 107 | x = features[:, 0, :] # take token (equiv. 
to [CLS]) 108 | x = self.dropout(x) 109 | x = self.dense(x) 110 | x = ACT2FN[self.config.hidden_act](x) # relu here 111 | x = self.dropout(x) 112 | x = self.out_proj(x) 113 | return x 114 | ``` 115 | 116 | ### Tips: 117 | 118 | - Experiment environment: **RTX 3090** 119 | 120 | ### Leaderboard screenshots 121 |
<!-- leaderboard screenshot images -->
133 | 134 | ## Roformer-sim测试例子 135 | 136 | ```python 137 | import torch 138 | import numpy as np 139 | from roformer import RoFormerForCausalLM, RoFormerConfig 140 | from transformers import BertTokenizer 141 | 142 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 143 | # 可选以下几个。 144 | # junnyu/roformer_chinese_sim_char_small, junnyu/roformer_chinese_sim_char_base 145 | # junnyu/roformer_chinese_sim_char_ft_small, roformer_chinese_sim_char_ft_base 146 | pretrained_model = "junnyu/roformer_chinese_sim_char_base" 147 | tokenizer = BertTokenizer.from_pretrained(pretrained_model) 148 | config = RoFormerConfig.from_pretrained(pretrained_model) 149 | config.is_decoder = True 150 | config.eos_token_id = tokenizer.sep_token_id 151 | config.pooler_activation = "linear" 152 | model = RoFormerForCausalLM.from_pretrained(pretrained_model, config=config) 153 | model.to(device) 154 | model.eval() 155 | 156 | def gen_synonyms(text, n=100, k=20): 157 | ''''含义: 产生sent的n个相似句,然后返回最相似的k个。 158 | 做法:用seq2seq生成,并用encoder算相似度并排序。 159 | ''' 160 | # 寻找所有相似的句子 161 | r = [] 162 | inputs1 = tokenizer(text, return_tensors="pt") 163 | for _ in range(n): 164 | inputs1.to(device) 165 | output = tokenizer.batch_decode(model.generate(**inputs1, top_p=0.95, do_sample=True, max_length=128), skip_special_tokens=True)[0].replace(" ","").replace(text, "") # 去除空格,去除原始text文本。 166 | r.append(output) 167 | 168 | # 对相似的句子进行排序 169 | r = [i for i in set(r) if i != text and len(i) > 0] 170 | r = [text] + r 171 | inputs2 = tokenizer(r, padding=True, return_tensors="pt") 172 | with torch.no_grad(): 173 | inputs2.to(device) 174 | outputs = model(**inputs2) 175 | Z = outputs.pooler_output.cpu().numpy() 176 | Z /= (Z**2).sum(axis=1, keepdims=True)**0.5 177 | argsort = np.dot(Z[1:], -Z[0]).argsort() 178 | 179 | return [r[i + 1] for i in argsort[:k]] 180 | 181 | out = gen_synonyms("广州和深圳哪个好?") 182 | print(out) 183 | # ['深圳和广州哪个好?', 184 | # '广州和深圳哪个好', 185 | # '深圳和广州哪个好', 186 | # '深圳和广州哪个比较好。', 187 | # '深圳和广州哪个最好?', 188 | # '深圳和广州哪个比较好', 189 | # '广州和深圳那个比较好', 190 | # '深圳和广州哪个更好?', 191 | # '深圳与广州哪个好', 192 | # '深圳和广州,哪个比较好', 193 | # '广州与深圳比较哪个好', 194 | # '深圳和广州哪里比较好', 195 | # '深圳还是广州比较好?', 196 | # '广州和深圳哪个地方好一些?', 197 | # '广州好还是深圳好?', 198 | # '广州好还是深圳好呢?', 199 | # '广州与深圳哪个地方好点?', 200 | # '深圳好还是广州好', 201 | # '广州好还是深圳好', 202 | # '广州和深圳哪个城市好?'] 203 | ``` 204 | 205 | 206 | 207 | ## 模型权重对照表 208 | 209 | ### 中文模型 roformer-v2 210 | | huggingface.co | bert4keras | 211 | | ---------------------------------- | ------------------------------------------------ | 212 | | [roformer_v2_chinese_char_small](https://huggingface.co/junnyu/roformer_v2_chinese_char_small) | [chinese_roformer-v2-char_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1huUrC9P60Afggo8AfiUcmA) (download code:ttn4) | 213 | | [roformer_v2_chinese_char_base](https://huggingface.co/junnyu/roformer_v2_chinese_char_base) | [chinese_roformer-v2-char_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1qcnN4LVKVe0-mnHlkN3-6Q) (download code:pfoh) | 214 | | [roformer_v2_chinese_char_large](https://huggingface.co/junnyu/roformer_v2_chinese_char_large) | [chinese_roformer-v2-char_L-24_H-1024_A-16.zip](https://pan.baidu.com/s/1QiJWSZrGxn8vek-8myvL6w) (download code:npfv) | 215 | 216 | 217 | ### 中文模型 roformer-v1 218 | | huggingface.co | bert4keras | 219 | | ---------------------------------- | ------------------------------------------------ | 220 | | [roformer_chinese_base](https://huggingface.co/junnyu/roformer_chinese_base) | 
[chinese_roformer_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1fiss862YsGCwf2HvU_Jm-g) (download code:xy9x) | 221 | | [roformer_chinese_small](https://huggingface.co/junnyu/roformer_chinese_small) | [chinese_roformer_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1iIXgZHHCgrYGXVRRSSCVPg) (download code:gy97) | 222 | | [roformer_chinese_char_base](https://huggingface.co/junnyu/roformer_chinese_char_base) | [chinese_roformer-char_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1Q1pq8F4Fsl6bTipUAkqeDQ) (download code:bt94) | 223 | | [roformer_chinese_char_small](https://huggingface.co/junnyu/roformer_chinese_char_small) | [chinese_roformer-char_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1cc281-M0Rsjlwws5phqzbQ) (download code:a44c) | 224 | | [roformer_chinese_sim_char_base](https://huggingface.co/junnyu/roformer_chinese_sim_char_base) | [chinese_roformer-sim-char_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1f1FB288nv1a6jYjsNCordg) (download code:2cgz) | 225 | | [roformer_chinese_sim_char_small](https://huggingface.co/junnyu/roformer_chinese_sim_char_small) | [chinese_roformer-sim-char_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1r0eJ7shGwQ0RzV9BTFFW4g) (download code:h68q) | 226 | | [roformer_chinese_sim_char_ft_base](https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_base) | [chinese_roformer-sim-char-ft_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1Igh3tSvSu_ahDZmGaOlVoA) (download code:w15n) | 227 | | [roformer_chinese_sim_char_ft_small](https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small) | [chinese_roformer-sim-char-ft_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1G36x7YQF1b6nzW0OzyJS_Q) (download code:gty5) | 228 | 229 | 230 | 231 | 232 | ### 英文模型(使用electra的训练方法在openwebtext上训练的small模型(rotary value = True)) 233 | | huggingface.co | 234 | | ---------------------------------- | 235 | |[roformer_small_generator](https://huggingface.co/junnyu/roformer_small_generator)| 236 | |[roformer_small_discriminator](https://huggingface.co/junnyu/roformer_small_discriminator)| 237 | 238 | 239 | 240 | ## Roformer-v2 MLM测试 241 | 242 | ```python 243 | import torch 244 | import tensorflow as tf 245 | from transformers import BertTokenizer 246 | from roformer import RoFormerForMaskedLM, TFRoFormerForMaskedLM 247 | 248 | text = "今天[MASK]很好,我[MASK]去公园玩。" 249 | tokenizer = BertTokenizer.from_pretrained("junnyu/roformer_v2_chinese_char_base") 250 | pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_v2_chinese_char_base") 251 | tf_model = TFRoFormerForMaskedLM.from_pretrained( 252 | "junnyu/roformer_v2_chinese_char_base", from_pt=True 253 | ) 254 | pt_inputs = tokenizer(text, return_tensors="pt") 255 | tf_inputs = tokenizer(text, return_tensors="tf") 256 | # pytorch 257 | with torch.no_grad(): 258 | pt_outputs = pt_model(**pt_inputs).logits[0] 259 | pt_outputs_sentence = "pytorch: " 260 | for i, id in enumerate(tokenizer.encode(text)): 261 | if id == tokenizer.mask_token_id: 262 | tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1]) 263 | pt_outputs_sentence += "[" + "||".join(tokens) + "]" 264 | else: 265 | pt_outputs_sentence += "".join( 266 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) 267 | ) 268 | print(pt_outputs_sentence) 269 | # tf 270 | tf_outputs = tf_model(**tf_inputs, training=False).logits[0] 271 | tf_outputs_sentence = "tf: " 272 | for i, id in enumerate(tokenizer.encode(text)): 273 | if id == tokenizer.mask_token_id: 274 | tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1]) 275 | tf_outputs_sentence += "[" + 
"||".join(tokens) + "]" 276 | else: 277 | tf_outputs_sentence += "".join( 278 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) 279 | ) 280 | print(tf_outputs_sentence) 281 | # small 282 | # pytorch: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。 283 | # tf: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。 284 | # base 285 | # pytorch: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。 286 | # tf: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。 287 | # large 288 | # pytorch: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。 289 | # tf: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。 290 | ``` 291 | 292 | 293 | 294 | ## Roformer-v1 MLM测试 295 | 296 | ```python 297 | import torch 298 | import tensorflow as tf 299 | from transformers import RoFormerForMaskedLM, RoFormerTokenizer, TFRoFormerForMaskedLM 300 | 301 | text = "今天[MASK]很好,我[MASK]去公园玩。" 302 | tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base") 303 | pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base") 304 | tf_model = TFRoFormerForMaskedLM.from_pretrained( 305 | "junnyu/roformer_chinese_base", from_pt=True 306 | ) 307 | pt_inputs = tokenizer(text, return_tensors="pt") 308 | tf_inputs = tokenizer(text, return_tensors="tf") 309 | # pytorch 310 | with torch.no_grad(): 311 | pt_outputs = pt_model(**pt_inputs).logits[0] 312 | pt_outputs_sentence = "pytorch: " 313 | for i, id in enumerate(tokenizer.encode(text)): 314 | if id == tokenizer.mask_token_id: 315 | tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1]) 316 | pt_outputs_sentence += "[" + "||".join(tokens) + "]" 317 | else: 318 | pt_outputs_sentence += "".join( 319 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) 320 | ) 321 | print(pt_outputs_sentence) 322 | # tf 323 | tf_outputs = tf_model(**tf_inputs, training=False).logits[0] 324 | tf_outputs_sentence = "tf: " 325 | for i, id in enumerate(tokenizer.encode(text)): 326 | if id == tokenizer.mask_token_id: 327 | tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1]) 328 | tf_outputs_sentence += "[" + "||".join(tokens) + "]" 329 | else: 330 | tf_outputs_sentence += "".join( 331 | tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) 332 | ) 333 | print(tf_outputs_sentence) 334 | # pytorch: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。 335 | # tf: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。 336 | 337 | ``` 338 | 339 | 340 | 341 | ## 手动权重转换 342 | 343 | ```bash 344 | python convert_roformer_original_tf_checkpoint_to_pytorch.py \ 345 | --tf_checkpoint_path=xxxxxx/chinese_roformer_L-12_H-768_A-12/bert_model.ckpt \ 346 | --bert_config_file=pretrained_models/chinese_roformer_base/config.json \ 347 | --pytorch_dump_path=pretrained_models/chinese_roformer_base/pytorch_model.bin 348 | ``` 349 | 350 | 351 | 352 | ## tf与pytorch精度对齐 353 | 354 | ```python 355 | small版本 356 | bert4keras vs pytorch 357 | mean diff : tensor(5.9108e-07) 358 | max diff : tensor(5.7220e-06) 359 | bert4keras vs tf2.0 360 | mean diff : tensor(4.5976e-07) 361 | max diff : tensor(3.5763e-06) 362 | 363 | base版本 364 | python compare_model.py 365 | bert4keras vs pytorch 366 | mean diff : tensor(4.3340e-07) 367 | max diff : tensor(5.7220e-06) 368 | bert4keras vs tf2.0 369 | mean diff : tensor(3.4319e-07) 370 | max diff : tensor(5.2452e-06) 371 | ``` 372 | 373 | 374 | 375 | ## 参考 376 | 377 | https://github.com/pengming617/bert_classification 378 | 379 | https://github.com/bojone/bert4keras 380 | 381 | https://github.com/ZhuiyiTechnology/roformer 382 | 383 | 
https://github.com/lonePatient/NeZha_Chinese_PyTorch 384 | 385 | https://github.com/lonePatient/TorchBlocks 386 | 387 | https://github.com/huggingface/transformers 388 | 389 | 390 | 391 | ## Citation 392 | 393 | Bibtex: 394 | 395 | ```tex 396 | 397 | @misc{su2021roformer, 398 | title={RoFormer: Enhanced Transformer with Rotary Position Embedding}, 399 | author={Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu}, 400 | year={2021}, 401 | eprint={2104.09864}, 402 | archivePrefix={arXiv}, 403 | primaryClass={cs.CL} 404 | } 405 | 406 | ``` 407 | 408 | ```tex 409 | @techreport{roformerv2, 410 | title={RoFormerV2: A Faster and Better RoFormer - ZhuiyiAI}, 411 | author={Jianlin Su, Shengfeng Pan, Bo Wen, Yunfeng Liu}, 412 | year={2022}, 413 | url="https://github.com/ZhuiyiTechnology/roformer-v2", 414 | } 415 | ``` 416 | -------------------------------------------------------------------------------- /src/roformer/tokenization_roformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for RoFormer.""" 16 | 17 | import collections 18 | import os 19 | from typing import List, Optional, Tuple 20 | 21 | from transformers.models.bert.tokenization_bert import ( 22 | BasicTokenizer, 23 | WordpieceTokenizer, 24 | load_vocab, 25 | ) 26 | from transformers.tokenization_utils import PreTrainedTokenizer 27 | from transformers.utils import logging 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 32 | 33 | PRETRAINED_VOCAB_FILES_MAP = { 34 | "vocab_file": { 35 | "junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt", 36 | "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt", 37 | "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt", 38 | "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt", 39 | "junnyu/roformer_chinese_sim_char_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_small/resolve/main/vocab.txt", 40 | "junnyu/roformer_chinese_sim_char_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_base/resolve/main/vocab.txt", 41 | "junnyu/roformer_chinese_sim_char_ft_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small/resolve/main/vocab.txt", 42 | "junnyu/roformer_chinese_sim_char_ft_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_base/resolve/main/vocab.txt", 43 | "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt", 44 | "junnyu/roformer_small_generator": 
"https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt", 45 | "junnyu/roformer_base_wwm_cluecorpussmall": "https://huggingface.co/junnyu/roformer_base_wwm_cluecorpussmall/resolve/main/vocab.txt", 46 | "junnyu/roformer_v2_chinese_char_small": "https://huggingface.co/junnyu/roformer_v2_chinese_char_small/resolve/main/vocab.txt", 47 | "junnyu/roformer_v2_chinese_char_base": "https://huggingface.co/junnyu/roformer_v2_chinese_char_base/resolve/main/vocab.txt", 48 | "junnyu/roformer_v2_chinese_char_large": "https://huggingface.co/junnyu/roformer_v2_chinese_char_large/resolve/main/vocab.txt", 49 | # See all RoFormer models at https://huggingface.co/models?filter=roformer 50 | } 51 | } 52 | 53 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 54 | "junnyu/roformer_chinese_small": 1536, 55 | "junnyu/roformer_chinese_base": 1536, 56 | "junnyu/roformer_chinese_char_small": 512, 57 | "junnyu/roformer_chinese_char_base": 512, 58 | "junnyu/roformer_chinese_sim_char_small": 512, 59 | "junnyu/roformer_chinese_sim_char_base": 512, 60 | "junnyu/roformer_chinese_sim_char_ft_small": 512, 61 | "junnyu/roformer_chinese_sim_char_ft_base": 512, 62 | "junnyu/roformer_small_discriminator": 128, 63 | "junnyu/roformer_small_generator": 128, 64 | "junnyu/roformer_base_wwm_cluecorpussmall": 512, 65 | "junnyu/roformer_v2_chinese_char_small": 512, 66 | "junnyu/roformer_v2_chinese_char_base": 512, 67 | "junnyu/roformer_v2_chinese_char_large": 512, 68 | } 69 | 70 | 71 | PRETRAINED_INIT_CONFIGURATION = { 72 | "junnyu/roformer_chinese_small": {"do_lower_case": True}, 73 | "junnyu/roformer_chinese_base": {"do_lower_case": True}, 74 | "junnyu/roformer_chinese_char_small": {"do_lower_case": True}, 75 | "junnyu/roformer_chinese_char_base": {"do_lower_case": True}, 76 | "junnyu/roformer_chinese_sim_char_small": {"do_lower_case": True}, 77 | "junnyu/roformer_chinese_sim_char_base": {"do_lower_case": True}, 78 | "junnyu/roformer_chinese_sim_char_ft_small": {"do_lower_case": True}, 79 | "junnyu/roformer_chinese_sim_char_ft_base": {"do_lower_case": True}, 80 | "junnyu/roformer_small_discriminator": {"do_lower_case": True}, 81 | "junnyu/roformer_small_generator": {"do_lower_case": True}, 82 | "junnyu/roformer_base_wwm_cluecorpussmall": {"do_lower_case": True}, 83 | "junnyu/roformer_v2_chinese_char_small": {"do_lower_case": True}, 84 | "junnyu/roformer_v2_chinese_char_base": {"do_lower_case": True}, 85 | "junnyu/roformer_v2_chinese_char_large": {"do_lower_case": True}, 86 | } 87 | 88 | 89 | class RoFormerTokenizer(PreTrainedTokenizer): 90 | r""" 91 | Construct a RoFormer tokenizer. Based on `Rust Jieba `. 92 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. 93 | Users should refer to this superclass for more information regarding those methods. 94 | Args: 95 | vocab_file (:obj:`str`): 96 | File containing the vocabulary. 97 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 98 | Whether or not to lowercase the input when tokenizing. 99 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): 100 | Whether or not to do basic tokenization before WordPiece. 101 | never_split (:obj:`Iterable`, `optional`): 102 | Collection of tokens which will never be split during tokenization. Only has an effect when 103 | :obj:`do_basic_tokenize=True` 104 | unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): 105 | The unknown token. 
A token that is not in the vocabulary cannot be converted to an ID and is set to be this 106 | token instead. 107 | sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): 108 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 109 | sequence classification or for a text and a question for question answering. It is also used as the last 110 | token of a sequence built with special tokens. 111 | pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): 112 | The token used for padding, for example when batching sequences of different lengths. 113 | cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): 114 | The classifier token which is used when doing sequence classification (classification of the whole sequence 115 | instead of per-token classification). It is the first token of the sequence when built with special tokens. 116 | mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): 117 | The token used for masking values. This is the token used when training this model with masked language 118 | modeling. This is the token which the model will try to predict. 119 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 120 | Whether or not to tokenize Chinese characters. 121 | This should likely be deactivated for Japanese (see this `issue 122 | `__). 123 | strip_accents: (:obj:`bool`, `optional`): 124 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 125 | value for :obj:`lowercase` (as in the original BERT). 126 | Example:: 127 | >>> from transformers import RoFormerTokenizer 128 | >>> tokenizer = RoFormerTokenizer.from_pretrained('junnyu/roformer_chinese_base') 129 | >>> tokenizer.tokenize("今天天气非常好。") 130 | # ['今', '天', '天', '气', '非常', '好', '。'] 131 | """ 132 | vocab_files_names = VOCAB_FILES_NAMES 133 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 134 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 135 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 136 | 137 | def __init__( 138 | self, 139 | vocab_file, 140 | do_lower_case=True, 141 | do_basic_tokenize=True, 142 | never_split=None, 143 | unk_token="[UNK]", 144 | sep_token="[SEP]", 145 | pad_token="[PAD]", 146 | cls_token="[CLS]", 147 | mask_token="[MASK]", 148 | tokenize_chinese_chars=True, 149 | strip_accents=None, 150 | **kwargs, 151 | ): 152 | super().__init__( 153 | do_lower_case=do_lower_case, 154 | do_basic_tokenize=do_basic_tokenize, 155 | never_split=never_split, 156 | unk_token=unk_token, 157 | sep_token=sep_token, 158 | pad_token=pad_token, 159 | cls_token=cls_token, 160 | mask_token=mask_token, 161 | tokenize_chinese_chars=tokenize_chinese_chars, 162 | strip_accents=strip_accents, 163 | **kwargs, 164 | ) 165 | 166 | if not os.path.isfile(vocab_file): 167 | raise ValueError( 168 | f"Can't find a vocabulary file at path '{vocab_file}'. 
To load the vocabulary from a Google pretrained " 169 | "model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" 170 | ) 171 | self.vocab = load_vocab(vocab_file) 172 | self.ids_to_tokens = collections.OrderedDict( 173 | [(ids, tok) for tok, ids in self.vocab.items()] 174 | ) 175 | self.do_basic_tokenize = do_basic_tokenize 176 | if do_basic_tokenize: 177 | self.basic_tokenizer = BasicTokenizer( 178 | do_lower_case=do_lower_case, 179 | never_split=never_split, 180 | tokenize_chinese_chars=tokenize_chinese_chars, 181 | strip_accents=strip_accents, 182 | ) 183 | self.wordpiece_tokenizer = WordpieceTokenizer( 184 | vocab=self.vocab, unk_token=self.unk_token 185 | ) 186 | try: 187 | import rjieba 188 | except ImportError: 189 | raise ImportError( 190 | "You need to install rjieba to use RoFormerTokenizer. " 191 | "See https://pypi.org/project/rjieba/ for installation." 192 | ) 193 | self.jieba = rjieba 194 | 195 | @property 196 | def do_lower_case(self): 197 | return self.basic_tokenizer.do_lower_case 198 | 199 | @property 200 | def vocab_size(self): 201 | return len(self.vocab) 202 | 203 | def __getstate__(self): 204 | state = self.__dict__.copy() 205 | state["jieba"] = None 206 | return state 207 | 208 | def __setstate__(self, d): 209 | self.__dict__ = d 210 | import rjieba 211 | 212 | self.jieba = rjieba 213 | 214 | def get_vocab(self): 215 | return dict(self.vocab, **self.added_tokens_encoder) 216 | 217 | def _tokenize(self, text, use_jieba=True): 218 | split_tokens = [] 219 | if use_jieba: 220 | for wholword in self.jieba.cut(text, False): 221 | if wholword in self.vocab: 222 | split_tokens.append(wholword) 223 | else: 224 | # use bert tokenizer to _tokenize 225 | char_list = self._tokenize(wholword, use_jieba=False) 226 | split_tokens.extend(char_list) 227 | else: 228 | if self.do_basic_tokenize: 229 | for token in self.basic_tokenizer.tokenize( 230 | text, never_split=self.all_special_tokens 231 | ): 232 | # If the token is part of the never_split set 233 | if token in self.basic_tokenizer.never_split: 234 | split_tokens.append(token) 235 | else: 236 | split_tokens += self.wordpiece_tokenizer.tokenize(token) 237 | else: 238 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 239 | return split_tokens 240 | 241 | def _convert_token_to_id(self, token): 242 | """Converts a token (str) in an id using the vocab.""" 243 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 244 | 245 | def _convert_id_to_token(self, index): 246 | """Converts an index (integer) in a token (str) using the vocab.""" 247 | return self.ids_to_tokens.get(index, self.unk_token) 248 | 249 | def convert_tokens_to_string(self, tokens): 250 | """Converts a sequence of tokens (string) in a single string.""" 251 | out_string = " ".join(tokens).replace(" ##", "").strip() 252 | return out_string 253 | 254 | def build_inputs_with_special_tokens( 255 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 256 | ) -> List[int]: 257 | """ 258 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 259 | adding special tokens. A RoFormer sequence has the following format: 260 | - single sequence: ``[CLS] X [SEP]`` 261 | - pair of sequences: ``[CLS] A [SEP] B [SEP]`` 262 | Args: 263 | token_ids_0 (:obj:`List[int]`): 264 | List of IDs to which the special tokens will be added. 265 | token_ids_1 (:obj:`List[int]`, `optional`): 266 | Optional second list of IDs for sequence pairs. 
267 | Returns: 268 | :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 269 | """ 270 | if token_ids_1 is None: 271 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 272 | cls = [self.cls_token_id] 273 | sep = [self.sep_token_id] 274 | return cls + token_ids_0 + sep + token_ids_1 + sep 275 | 276 | def get_special_tokens_mask( 277 | self, 278 | token_ids_0: List[int], 279 | token_ids_1: Optional[List[int]] = None, 280 | already_has_special_tokens: bool = False, 281 | ) -> List[int]: 282 | """ 283 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 284 | special tokens using the tokenizer ``prepare_for_model`` method. 285 | Args: 286 | token_ids_0 (:obj:`List[int]`): 287 | List of IDs. 288 | token_ids_1 (:obj:`List[int]`, `optional`): 289 | Optional second list of IDs for sequence pairs. 290 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): 291 | Whether or not the token list is already formatted with special tokens for the model. 292 | Returns: 293 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 294 | """ 295 | 296 | if already_has_special_tokens: 297 | return super().get_special_tokens_mask( 298 | token_ids_0=token_ids_0, 299 | token_ids_1=token_ids_1, 300 | already_has_special_tokens=True, 301 | ) 302 | 303 | if token_ids_1 is not None: 304 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 305 | return [1] + ([0] * len(token_ids_0)) + [1] 306 | 307 | def create_token_type_ids_from_sequences( 308 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 309 | ) -> List[int]: 310 | """ 311 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. A RoFormer 312 | sequence pair mask has the following format: 313 | :: 314 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 315 | | first sequence | second sequence | 316 | If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). 317 | Args: 318 | token_ids_0 (:obj:`List[int]`): 319 | List of IDs. 320 | token_ids_1 (:obj:`List[int]`, `optional`): 321 | Optional second list of IDs for sequence pairs. 322 | Returns: 323 | :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given 324 | sequence(s). 325 | """ 326 | sep = [self.sep_token_id] 327 | cls = [self.cls_token_id] 328 | if token_ids_1 is None: 329 | return len(cls + token_ids_0 + sep) * [0] 330 | return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] 331 | 332 | def save_vocabulary( 333 | self, save_directory: str, filename_prefix: Optional[str] = None 334 | ) -> Tuple[str]: 335 | index = 0 336 | if os.path.isdir(save_directory): 337 | vocab_file = os.path.join( 338 | save_directory, 339 | (filename_prefix + "-" if filename_prefix else "") 340 | + VOCAB_FILES_NAMES["vocab_file"], 341 | ) 342 | else: 343 | vocab_file = ( 344 | filename_prefix + "-" if filename_prefix else "" 345 | ) + save_directory 346 | with open(vocab_file, "w", encoding="utf-8") as writer: 347 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 348 | if index != token_index: 349 | logger.warning( 350 | f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." 351 | " Please check that the vocabulary is not corrupted!" 
352 | ) 353 | index = token_index 354 | writer.write(token + "\n") 355 | index += 1 356 | return (vocab_file,) -------------------------------------------------------------------------------- /examples/clue/classification/clue_11.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """A Chinese Language Understanding Evaluation Benchmark (CLUE) benchmark.""" 18 | 19 | 20 | import json 21 | import os 22 | import re 23 | import textwrap 24 | 25 | import datasets 26 | 27 | _CLUE_CITATION = """\ 28 | @misc{xu2020clue, 29 | title={CLUE: A Chinese Language Understanding Evaluation Benchmark}, 30 | author={Liang Xu and Xuanwei Zhang and Lu Li and Hai Hu and Chenjie Cao and Weitang Liu and Junyi Li and Yudong Li and Kai Sun and Yechen Xu and Yiming Cui and Cong Yu and Qianqian Dong and Yin Tian and Dian Yu and Bo Shi and Jun Zeng and Rongzhao Wang and Weijian Xie and Yanting Li and Yina Patterson and Zuoyu Tian and Yiwen Zhang and He Zhou and Shaoweihua Liu and Qipeng Zhao and Cong Yue and Xinrui Zhang and Zhengliang Yang and Zhenzhong Lan}, 31 | year={2020}, 32 | eprint={2004.05986}, 33 | archivePrefix={arXiv}, 34 | primaryClass={cs.CL} 35 | } 36 | """ 37 | 38 | _CLUE_DESCRIPTION = """\ 39 | CLUE, A Chinese Language Understanding Evaluation Benchmark 40 | (https://www.cluebenchmarks.com/) is a collection of resources for training, 41 | evaluating, and analyzing Chinese language understanding systems. 42 | """ 43 | 44 | 45 | class ClueConfig(datasets.BuilderConfig): 46 | """BuilderConfig for CLUE.""" 47 | 48 | def __init__( 49 | self, 50 | data_url, 51 | text_features=None, 52 | label_column=None, 53 | data_dir="", 54 | citation="", 55 | url="", 56 | label_classes=None, 57 | process_label=lambda x: x, 58 | **kwargs, 59 | ): 60 | """BuilderConfig for CLUE. 61 | Args: 62 | text_features: `dict[string, string]`, map from the name of the feature 63 | dict for each text field to the name of the column in the tsv file 64 | label_column: `string`, name of the column in the tsv file corresponding 65 | to the label 66 | data_url: `string`, url to download the zip file from 67 | data_dir: `string`, the path to the folder containing the tsv files in the 68 | downloaded zip 69 | citation: `string`, citation for the data set 70 | url: `string`, url for information about the data set 71 | label_classes: `list[string]`, the list of classes if the label is 72 | categorical. If not provided, then the label will be of type 73 | `datasets.Value('float32')`. 74 | process_label: `Function[string, any]`, function taking in the raw value 75 | of the label and processing it to the form required by the label feature 76 | **kwargs: keyword arguments forwarded to super. 
77 | """ 78 | super(ClueConfig, self).__init__( 79 | version=datasets.Version("1.0.0", ""), **kwargs 80 | ) 81 | self.text_features = text_features 82 | self.label_column = label_column 83 | self.label_classes = label_classes 84 | self.data_url = data_url 85 | self.data_dir = data_dir 86 | self.citation = citation 87 | self.url = url 88 | self.process_label = process_label 89 | 90 | 91 | class Clue(datasets.GeneratorBasedBuilder): 92 | """A Chinese Language Understanding Evaluation Benchmark (CLUE) benchmark.""" 93 | 94 | BUILDER_CONFIGS = [ 95 | ClueConfig( 96 | name="afqmc", 97 | description=textwrap.dedent( 98 | """\ 99 | Ant Financial Question Matching Corpus is a dataset for Chinese 100 | question matching (similar to QQP). 101 | """ 102 | ), 103 | text_features={"sentence1": "sentence1", "sentence2": "sentence2"}, 104 | label_classes=["0", "1"], 105 | label_column="label", 106 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/afqmc_public.zip", 107 | url="https://dc.cloud.alipay.com/index#/topic/data?id=8", 108 | ), 109 | ClueConfig( 110 | name="tnews", 111 | description=textwrap.dedent( 112 | """\ 113 | Toutiao Short Text Classification for News is a dataset for Chinese 114 | short news classification. 115 | """ 116 | ), 117 | text_features={"sentence": "sentence"}, 118 | label_classes=[ 119 | "100", 120 | "101", 121 | "102", 122 | "103", 123 | "104", 124 | "106", 125 | "107", 126 | "108", 127 | "109", 128 | "110", 129 | "112", 130 | "113", 131 | "114", 132 | "115", 133 | "116", 134 | ], 135 | label_column="label", 136 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip", 137 | url="https://github.com/skdjfla/toutiao-text-classfication-dataset", 138 | ), 139 | ClueConfig( 140 | name="iflytek", 141 | description=textwrap.dedent( 142 | """\ 143 | IFLYTEK Long Text Classification for News is a dataset for Chinese 144 | long text classification. The text is crawled from an app market. 145 | """ 146 | ), 147 | text_features={"sentence": "sentence"}, 148 | label_classes=[str(label) for label in range(119)], 149 | label_column="label", 150 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/iflytek_public.zip", 151 | ), 152 | ClueConfig( 153 | name="cmnli", 154 | description=textwrap.dedent( 155 | """\ 156 | Chinese Multi-Genre NLI is a dataset for Chinese Natural Language 157 | Inference. It consists of XNLI (Chinese subset) and translated MNLI. 158 | """ 159 | ), 160 | text_features={"sentence1": "sentence1", "sentence2": "sentence2"}, 161 | label_classes=["neutral", "entailment", "contradiction"], 162 | label_column="label", 163 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/cmnli_public.zip", 164 | data_dir="cmnli_public", 165 | ), 166 | ClueConfig( 167 | name="cluewsc2020", 168 | description=textwrap.dedent( 169 | """\ 170 | CLUE Winograd Scheme Challenge (CLUEWSC 2020) is a Chinese WSC dataset. 171 | The text is from contemporary literature and annotated by human experts. 172 | The task is to determine which noun the pronoun in the sentence refers to. 173 | The question appears in the form of true and false discrimination. 
174 | """ 175 | ), 176 | text_features={"text": "text", "target": "target"}, 177 | label_classes=["false", "true"], 178 | label_column="label", 179 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/cluewsc2020_public.zip", 180 | ), 181 | ClueConfig( 182 | name="csl", 183 | description=textwrap.dedent( 184 | """\ 185 | Chinese Scientific Literature Dataset (CSL) is taken from the abstracts of 186 | Chinese papers and their keywords. The papers are selected from some core 187 | journals of Chinese social sciences and natural sciences. TF-IDF is used to 188 | generate a mixture of fake keywords and real keywords in the paper to construct 189 | abstract-keyword pairs. The task goal is to judge whether the keywords are 190 | all real keywords based on the abstract. 191 | """ 192 | ), 193 | text_features={"abst": "abst", "keyword": "keyword", "corpus_id": "id"}, 194 | label_classes=["0", "1"], 195 | label_column="label", 196 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/csl_public.zip", 197 | url="https://github.com/P01son6415/CSL", 198 | ), 199 | ClueConfig( 200 | name="cmrc2018", 201 | description=textwrap.dedent( 202 | """\ 203 | CMRC2018 is the first Chinese Span-Extraction Machine Reading Comprehension 204 | Dataset. The task requires to set up a system that reads context, 205 | question and extract the answer from the context (the answer is a continuous 206 | span in the context). 207 | """ 208 | ), 209 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/cmrc2018_public.zip", 210 | url="https://hfl-rc.github.io/cmrc2018/", 211 | citation=textwrap.dedent( 212 | """\ 213 | @article{cmrc2018-dataset, 214 | title={A Span-Extraction Dataset for Chinese Machine Reading Comprehension}, 215 | author={Cui, Yiming and Liu, Ting and Xiao, Li and Chen, Zhipeng and Ma, Wentao and Che, Wanxiang and Wang, Shijin and Hu, Guoping}, 216 | journal={arXiv preprint arXiv:1810.07366}, 217 | year={2018} 218 | }""" 219 | ), 220 | ), 221 | ClueConfig( 222 | name="drcd", 223 | description=textwrap.dedent( 224 | """\ 225 | Delta Reading Comprehension Dataset (DRCD) belongs to the general field of traditional 226 | Chinese machine reading comprehension data set. This data set is expected to become a 227 | standard Chinese reading comprehension data set suitable for transfer learning. 228 | """ 229 | ), 230 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/drcd_public.zip", 231 | url="https://github.com/DRCKnowledgeTeam/DRCD", 232 | ), 233 | ClueConfig( 234 | name="chid", 235 | description=textwrap.dedent( 236 | """\ 237 | Chinese IDiom Dataset for Cloze Test (CHID) contains many masked idioms in the text. 238 | The candidates contain similar idioms to the real ones. 
239 | """ 240 | ), 241 | text_features={"candidates": "candidates", "content": "content"}, 242 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/chid_public.zip", 243 | url="https://arxiv.org/abs/1906.01265", 244 | citation=textwrap.dedent( 245 | """\ 246 | @article{Zheng_2019, 247 | title={ChID: A Large-scale Chinese IDiom Dataset for Cloze Test}, 248 | url={http://dx.doi.org/10.18653/v1/P19-1075}, 249 | DOI={10.18653/v1/p19-1075}, 250 | journal={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics}, 251 | publisher={Association for Computational Linguistics}, 252 | author={Zheng, Chujie and Huang, Minlie and Sun, Aixin}, 253 | year={2019} 254 | }""" 255 | ), 256 | ), 257 | ClueConfig( 258 | name="c3", 259 | description=textwrap.dedent( 260 | """\ 261 | Multiple-Choice Chinese Machine Reading Comprehension (C3, or C^3) is a Chinese 262 | multi-choice reading comprehension data set, including mixed type data sets 263 | such as dialogue and long text. Both the training and validation sets are 264 | the concatenation of the dialogue and long-text subsets. 265 | """ 266 | ), 267 | text_features={"candidates": "candidates", "content": "content"}, 268 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/c3_public.zip", 269 | url="https://arxiv.org/abs/1904.09679", 270 | citation=textwrap.dedent( 271 | """\ 272 | @article{sun2020investigating, 273 | author = {Kai Sun and 274 | Dian Yu and 275 | Dong Yu and 276 | Claire Cardie}, 277 | title = {Investigating Prior Knowledge for Challenging Chinese Machine Reading 278 | Comprehension}, 279 | journal = {Trans. Assoc. Comput. Linguistics}, 280 | volume = {8}, 281 | pages = {141--155}, 282 | year = {2020}, 283 | url = {https://transacl.org/ojs/index.php/tacl/article/view/1882} 284 | }""" 285 | ), 286 | ), 287 | ClueConfig( 288 | name="ocnli", 289 | description=textwrap.dedent( 290 | """\ 291 | OCNLI stands for Original Chinese Natural Language Inference. It is a corpus for 292 | Chinese Natural Language Inference, collected following closely the procedures of MNLI, 293 | but with enhanced strategies aiming for more challenging inference pairs. We want to 294 | emphasize we did not use human/machine translation in creating the dataset, and thus 295 | our Chinese texts are original and not translated. 296 | """ 297 | ), 298 | text_features={"sentence1": "sentence1", "sentence2": "sentence2"}, 299 | label_classes=["neutral", "entailment", "contradiction"], 300 | label_column="label", 301 | data_url="https://github.com/CLUEbenchmark/OCNLI/archive/02d55cb3c7dc984682677b8dd81db6a1e4710720.zip", 302 | data_dir="OCNLI-02d55cb3c7dc984682677b8dd81db6a1e4710720/data/ocnli", 303 | url="https://arxiv.org/abs/2010.05444", 304 | citation=textwrap.dedent( 305 | """\ 306 | @inproceedings{ocnli, 307 | title={OCNLI: Original Chinese Natural Language Inference}, 308 | author={Hai Hu and Kyle Richardson and Liang Xu and Lu Li and Sandra Kuebler and Larry Moss}, 309 | booktitle={Findings of EMNLP}, 310 | year={2020}, 311 | url={https://arxiv.org/abs/2010.05444} 312 | }""" 313 | ), 314 | ), 315 | ClueConfig( 316 | name="diagnostics", 317 | description=textwrap.dedent( 318 | """\ 319 | Diagnostic set, used to evaluate the performance of different models on 9 Chinese language 320 | phenomena summarized by linguists. 321 | Use the model trained on CMNLI to directly predict the result on this diagnostic set. 
322 | """ 323 | ), 324 | text_features={"sentence1": "premise", "sentence2": "hypothesis"}, 325 | label_classes=["neutral", "entailment", "contradiction"], 326 | label_column="label", 327 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/clue_diagnostics_public.zip", 328 | ), 329 | ] 330 | 331 | def _info(self): 332 | if self.config.name in [ 333 | "afqmc", 334 | "tnews", 335 | "iflytek", 336 | "cmnli", 337 | "diagnostics", 338 | "ocnli", 339 | ]: 340 | features = { 341 | text_feature: datasets.Value("string") 342 | for text_feature in self.config.text_features.keys() 343 | } 344 | if self.config.label_classes: 345 | features["label"] = datasets.features.ClassLabel( 346 | names=self.config.label_classes 347 | ) 348 | else: 349 | features["label"] = datasets.Value("float32") 350 | features["idx"] = datasets.Value("int32") 351 | elif self.config.name == "cluewsc2020": 352 | features = { 353 | "idx": datasets.Value("int32"), 354 | "text": datasets.Value("string"), 355 | "label": datasets.features.ClassLabel(names=["true", "false"]), 356 | "target": { 357 | "span1_text": datasets.Value("string"), 358 | "span2_text": datasets.Value("string"), 359 | "span1_index": datasets.Value("int32"), 360 | "span2_index": datasets.Value("int32"), 361 | }, 362 | } 363 | elif self.config.name == "csl": 364 | features = { 365 | "idx": datasets.Value("int32"), 366 | "corpus_id": datasets.Value("int32"), 367 | "abst": datasets.Value("string"), 368 | "label": datasets.features.ClassLabel(names=self.config.label_classes), 369 | "keyword": datasets.Sequence(datasets.Value("string")), 370 | } 371 | elif self.config.name in ["cmrc2018", "drcd"]: 372 | features = { 373 | "id": datasets.Value("string"), 374 | "context": datasets.Value("string"), 375 | "question": datasets.Value("string"), 376 | "answers": datasets.Sequence( 377 | { 378 | "text": datasets.Value("string"), 379 | "answer_start": datasets.Value("int32"), 380 | } 381 | ), 382 | } 383 | elif self.config.name == "chid": 384 | features = { 385 | "idx": datasets.Value("int32"), 386 | "candidates": datasets.Sequence(datasets.Value("string")), 387 | "content": datasets.Sequence(datasets.Value("string")), 388 | "answers": datasets.features.Sequence( 389 | { 390 | "text": datasets.Value("string"), 391 | "candidate_id": datasets.Value("int32"), 392 | } 393 | ), 394 | } 395 | elif self.config.name == "c3": 396 | features = { 397 | "id": datasets.Value("int32"), 398 | "context": datasets.Sequence(datasets.Value("string")), 399 | "question": datasets.Value("string"), 400 | "choice": datasets.Sequence(datasets.Value("string")), 401 | "answer": datasets.Value("string"), 402 | } 403 | else: 404 | raise NotImplementedError( 405 | "This task is not implemented. If you believe" 406 | " this task was recently added to the CLUE benchmark, " 407 | "please open a GitHub issue and we will add it." 
408 | ) 409 | 410 | return datasets.DatasetInfo( 411 | description=_CLUE_DESCRIPTION, 412 | features=datasets.Features(features), 413 | homepage=self.config.url, 414 | citation=self.config.citation + "\n" + _CLUE_CITATION, 415 | ) 416 | 417 | def _split_generators(self, dl_manager): 418 | dl_dir = dl_manager.download_and_extract(self.config.data_url) 419 | data_dir = os.path.join(dl_dir, self.config.data_dir) 420 | 421 | if self.config.name in {"chid", "c3"}: 422 | test_file = "test1.1.json" 423 | elif self.config.name == "diagnostics": 424 | test_file = "diagnostics_test.json" 425 | else: 426 | test_file = "test.json" 427 | 428 | test_split = datasets.SplitGenerator( 429 | name=datasets.Split.TEST, 430 | gen_kwargs={ 431 | "data_file": os.path.join(data_dir, test_file), 432 | "split": "test", 433 | }, 434 | ) 435 | 436 | split_list = [test_split] 437 | 438 | if self.config.name != "diagnostics": 439 | train_split = datasets.SplitGenerator( 440 | name=datasets.Split.TRAIN, 441 | gen_kwargs={ 442 | "data_file": os.path.join( 443 | data_dir or "", 444 | "train.json" if self.config.name != "c3" else "d-train.json", 445 | ), 446 | "split": "train", 447 | }, 448 | ) 449 | val_split = datasets.SplitGenerator( 450 | name=datasets.Split.VALIDATION, 451 | gen_kwargs={ 452 | "data_file": os.path.join( 453 | data_dir or "", 454 | "dev.json" if self.config.name != "c3" else "d-dev.json", 455 | ), 456 | "split": "dev", 457 | }, 458 | ) 459 | split_list += [train_split, val_split] 460 | 461 | if self.config.name == "cmrc2018": 462 | split_list.append( 463 | datasets.SplitGenerator( 464 | name=datasets.Split("trial"), 465 | gen_kwargs={ 466 | "data_file": os.path.join(data_dir or "", "trial.json"), 467 | "split": "trial", 468 | }, 469 | ) 470 | ) 471 | 472 | return split_list 473 | 474 | def _generate_examples(self, data_file, split): 475 | process_label = self.config.process_label 476 | label_classes = self.config.label_classes 477 | 478 | if self.config.name == "chid" and split != "test": 479 | answer_file = os.path.join( 480 | os.path.dirname(data_file), f"{split}_answer.json" 481 | ) 482 | answer_dict = json.load(open(answer_file, encoding="utf8")) 483 | 484 | if self.config.name == "c3": 485 | if split == "test": 486 | files = [data_file] 487 | else: 488 | data_dir = os.path.dirname(data_file) 489 | files = [ 490 | os.path.join(data_dir, f"{typ}-{split}.json") for typ in ["d", "m"] 491 | ] 492 | data = [] 493 | for f in files: 494 | data_subset = json.load(open(f, encoding="utf8")) 495 | data += data_subset 496 | for idx, entry in enumerate(data): 497 | for qidx, question in enumerate(entry[1]): 498 | example = { 499 | "id": idx if split != "test" else int(question["id"]), 500 | "context": entry[0], 501 | "question": question["question"], 502 | "choice": question["choice"], 503 | "answer": question["answer"] if split != "test" else "", 504 | } 505 | yield f"{idx}_{qidx}", example 506 | 507 | else: 508 | with open(data_file, encoding="utf8") as f: 509 | if self.config.name in ["cmrc2018", "drcd"]: 510 | data = json.load(f) 511 | for example in data["data"]: 512 | for paragraph in example["paragraphs"]: 513 | context = paragraph["context"].strip() 514 | for qa in paragraph["qas"]: 515 | question = qa["question"].strip() 516 | id_ = qa["id"] 517 | 518 | answer_starts = [ 519 | answer["answer_start"] for answer in qa["answers"] 520 | ] 521 | answers = [ 522 | answer["text"].strip() for answer in qa["answers"] 523 | ] 524 | 525 | yield id_, { 526 | "context": context, 527 | "question": question, 528 | 
"id": id_, 529 | "answers": { 530 | "answer_start": answer_starts, 531 | "text": answers, 532 | }, 533 | } 534 | 535 | else: 536 | for n, line in enumerate(f): 537 | row = json.loads(line) 538 | example = { 539 | feat: row[col] 540 | for feat, col in self.config.text_features.items() 541 | } 542 | example["idx"] = ( 543 | n 544 | if self.config.name != "diagnostics" 545 | else int(row["index"]) 546 | ) 547 | 548 | if ( 549 | self.config.name == "chid" 550 | ): # CHID has a separate gold label file 551 | contents = example["content"] 552 | candidates = example["candidates"] 553 | idiom_list = [] 554 | if split != "test": 555 | for content in contents: 556 | idioms = re.findall(r"#idiom\d+#", content) 557 | for idiom in idioms: 558 | idiom_list.append( 559 | { 560 | "candidate_id": answer_dict[idiom], 561 | "text": candidates[answer_dict[idiom]], 562 | } 563 | ) 564 | example["answers"] = idiom_list 565 | 566 | elif self.config.label_column in row: 567 | label = row[self.config.label_column] 568 | # Notice: some labels in CMNLI and OCNLI are invalid. We drop these data. 569 | if self.config.name in ["cmnli", "ocnli"] and label == "-": 570 | continue 571 | # For some tasks, the label is represented as 0 and 1 in the tsv 572 | # files and needs to be cast to integer to work with the feature. 573 | if label_classes and label not in label_classes: 574 | label = int(label) if label else None 575 | example["label"] = process_label(label) 576 | else: 577 | example["label"] = process_label(-1) 578 | 579 | if self.config.name == "cluewsc2020": 580 | text = example["text"] 581 | s1 = example["target"]["span1_index"] 582 | e1 = s1 + len(example["target"]["span1_text"]) 583 | s2 = example["target"]["span2_index"] 584 | e2 = s2 + len(example["target"]["span2_text"]) 585 | if s1 < s2: 586 | text = ( 587 | text[:s1] 588 | + "_" 589 | + text[s1:e1] 590 | + "_" 591 | + text[e1:s2] 592 | + "[" 593 | + text[s2:e2] 594 | + "]" 595 | + text[e2:] 596 | ) 597 | else: 598 | text = ( 599 | text[:s2] 600 | + "[" 601 | + text[s2:e2] 602 | + "]" 603 | + text[e2:s1] 604 | + "_" 605 | + text[s1:e1] 606 | + "_" 607 | + text[e1:] 608 | ) 609 | example["text"] = text 610 | 611 | # Filter out corrupted rows. 612 | for value in example.values(): 613 | if value is None: 614 | break 615 | else: 616 | yield example["idx"], example 617 | -------------------------------------------------------------------------------- /examples/clue/classification/clue_10.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """A Chinese Language Understanding Evaluation Benchmark (CLUE) benchmark.""" 18 | 19 | 20 | import json 21 | import os 22 | import re 23 | import textwrap 24 | 25 | import datasets 26 | 27 | _CLUE_CITATION = """\ 28 | @misc{xu2020clue, 29 | title={CLUE: A Chinese Language Understanding Evaluation Benchmark}, 30 | author={Liang Xu and Xuanwei Zhang and Lu Li and Hai Hu and Chenjie Cao and Weitang Liu and Junyi Li and Yudong Li and Kai Sun and Yechen Xu and Yiming Cui and Cong Yu and Qianqian Dong and Yin Tian and Dian Yu and Bo Shi and Jun Zeng and Rongzhao Wang and Weijian Xie and Yanting Li and Yina Patterson and Zuoyu Tian and Yiwen Zhang and He Zhou and Shaoweihua Liu and Qipeng Zhao and Cong Yue and Xinrui Zhang and Zhengliang Yang and Zhenzhong Lan}, 31 | year={2020}, 32 | eprint={2004.05986}, 33 | archivePrefix={arXiv}, 34 | primaryClass={cs.CL} 35 | } 36 | """ 37 | 38 | _CLUE_DESCRIPTION = """\ 39 | CLUE, A Chinese Language Understanding Evaluation Benchmark 40 | (https://www.cluebenchmarks.com/) is a collection of resources for training, 41 | evaluating, and analyzing Chinese language understanding systems. 42 | """ 43 | 44 | 45 | class ClueConfig(datasets.BuilderConfig): 46 | """BuilderConfig for CLUE.""" 47 | 48 | def __init__( 49 | self, 50 | data_url, 51 | text_features=None, 52 | label_column=None, 53 | data_dir="", 54 | citation="", 55 | url="", 56 | label_classes=None, 57 | process_label=lambda x: x, 58 | **kwargs, 59 | ): 60 | """BuilderConfig for CLUE. 61 | Args: 62 | text_features: `dict[string, string]`, map from the name of the feature 63 | dict for each text field to the name of the column in the tsv file 64 | label_column: `string`, name of the column in the tsv file corresponding 65 | to the label 66 | data_url: `string`, url to download the zip file from 67 | data_dir: `string`, the path to the folder containing the tsv files in the 68 | downloaded zip 69 | citation: `string`, citation for the data set 70 | url: `string`, url for information about the data set 71 | label_classes: `list[string]`, the list of classes if the label is 72 | categorical. If not provided, then the label will be of type 73 | `datasets.Value('float32')`. 74 | process_label: `Function[string, any]`, function taking in the raw value 75 | of the label and processing it to the form required by the label feature 76 | **kwargs: keyword arguments forwarded to super. 77 | """ 78 | super(ClueConfig, self).__init__( 79 | version=datasets.Version("1.0.0", ""), **kwargs 80 | ) 81 | self.text_features = text_features 82 | self.label_column = label_column 83 | self.label_classes = label_classes 84 | self.data_url = data_url 85 | self.data_dir = data_dir 86 | self.citation = citation 87 | self.url = url 88 | self.process_label = process_label 89 | 90 | 91 | class Clue(datasets.GeneratorBasedBuilder): 92 | """A Chinese Language Understanding Evaluation Benchmark (CLUE) benchmark.""" 93 | 94 | BUILDER_CONFIGS = [ 95 | ClueConfig( 96 | name="afqmc", 97 | description=textwrap.dedent( 98 | """\ 99 | Ant Financial Question Matching Corpus is a dataset for Chinese 100 | question matching (similar to QQP). 
101 | """ 102 | ), 103 | text_features={"sentence1": "sentence1", "sentence2": "sentence2"}, 104 | label_classes=["0", "1"], 105 | label_column="label", 106 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/afqmc_public.zip", 107 | url="https://dc.cloud.alipay.com/index#/topic/data?id=8", 108 | ), 109 | ClueConfig( 110 | name="tnews", 111 | description=textwrap.dedent( 112 | """\ 113 | Toutiao Short Text Classification for News is a dataset for Chinese 114 | short news classification. 115 | """ 116 | ), 117 | text_features={"sentence": "sentence"}, 118 | label_classes=[ 119 | "100", 120 | "101", 121 | "102", 122 | "103", 123 | "104", 124 | "106", 125 | "107", 126 | "108", 127 | "109", 128 | "110", 129 | "112", 130 | "113", 131 | "114", 132 | "115", 133 | "116", 134 | ], 135 | label_column="label", 136 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip", 137 | url="https://github.com/skdjfla/toutiao-text-classfication-dataset", 138 | ), 139 | ClueConfig( 140 | name="iflytek", 141 | description=textwrap.dedent( 142 | """\ 143 | IFLYTEK Long Text Classification for News is a dataset for Chinese 144 | long text classification. The text is crawled from an app market. 145 | """ 146 | ), 147 | text_features={"sentence": "sentence"}, 148 | label_classes=[str(label) for label in range(119)], 149 | label_column="label", 150 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/iflytek_public.zip", 151 | ), 152 | ClueConfig( 153 | name="cmnli", 154 | description=textwrap.dedent( 155 | """\ 156 | Chinese Multi-Genre NLI is a dataset for Chinese Natural Language 157 | Inference. It consists of XNLI (Chinese subset) and translated MNLI. 158 | """ 159 | ), 160 | text_features={"sentence1": "sentence1", "sentence2": "sentence2"}, 161 | label_classes=["neutral", "entailment", "contradiction"], 162 | label_column="label", 163 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/cmnli_public.zip", 164 | data_dir="cmnli_public", 165 | ), 166 | ClueConfig( 167 | name="cluewsc2020", 168 | description=textwrap.dedent( 169 | """\ 170 | CLUE Winograd Scheme Challenge (CLUEWSC 2020) is a Chinese WSC dataset. 171 | The text is from contemporary literature and annotated by human experts. 172 | The task is to determine which noun the pronoun in the sentence refers to. 173 | The question appears in the form of true and false discrimination. 174 | """ 175 | ), 176 | text_features={"text": "text", "target": "target"}, 177 | label_classes=["false", "true"], 178 | label_column="label", 179 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/cluewsc2020_public.zip", 180 | ), 181 | ClueConfig( 182 | name="csl", 183 | description=textwrap.dedent( 184 | """\ 185 | Chinese Scientific Literature Dataset (CSL) is taken from the abstracts of 186 | Chinese papers and their keywords. The papers are selected from some core 187 | journals of Chinese social sciences and natural sciences. TF-IDF is used to 188 | generate a mixture of fake keywords and real keywords in the paper to construct 189 | abstract-keyword pairs. The task goal is to judge whether the keywords are 190 | all real keywords based on the abstract. 
191 | """ 192 | ), 193 | text_features={"abst": "abst", "keyword": "keyword", "corpus_id": "id"}, 194 | label_classes=["0", "1"], 195 | label_column="label", 196 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/csl_public.zip", 197 | url="https://github.com/P01son6415/CSL", 198 | ), 199 | ClueConfig( 200 | name="cmrc2018", 201 | description=textwrap.dedent( 202 | """\ 203 | CMRC2018 is the first Chinese Span-Extraction Machine Reading Comprehension 204 | Dataset. The task requires to set up a system that reads context, 205 | question and extract the answer from the context (the answer is a continuous 206 | span in the context). 207 | """ 208 | ), 209 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/cmrc2018_public.zip", 210 | url="https://hfl-rc.github.io/cmrc2018/", 211 | citation=textwrap.dedent( 212 | """\ 213 | @article{cmrc2018-dataset, 214 | title={A Span-Extraction Dataset for Chinese Machine Reading Comprehension}, 215 | author={Cui, Yiming and Liu, Ting and Xiao, Li and Chen, Zhipeng and Ma, Wentao and Che, Wanxiang and Wang, Shijin and Hu, Guoping}, 216 | journal={arXiv preprint arXiv:1810.07366}, 217 | year={2018} 218 | }""" 219 | ), 220 | ), 221 | ClueConfig( 222 | name="drcd", 223 | description=textwrap.dedent( 224 | """\ 225 | Delta Reading Comprehension Dataset (DRCD) belongs to the general field of traditional 226 | Chinese machine reading comprehension data set. This data set is expected to become a 227 | standard Chinese reading comprehension data set suitable for transfer learning. 228 | """ 229 | ), 230 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/drcd_public.zip", 231 | url="https://github.com/DRCKnowledgeTeam/DRCD", 232 | ), 233 | ClueConfig( 234 | name="chid", 235 | description=textwrap.dedent( 236 | """\ 237 | Chinese IDiom Dataset for Cloze Test (CHID) contains many masked idioms in the text. 238 | The candidates contain similar idioms to the real ones. 239 | """ 240 | ), 241 | text_features={"candidates": "candidates", "content": "content"}, 242 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/chid_public.zip", 243 | url="https://arxiv.org/abs/1906.01265", 244 | citation=textwrap.dedent( 245 | """\ 246 | @article{Zheng_2019, 247 | title={ChID: A Large-scale Chinese IDiom Dataset for Cloze Test}, 248 | url={http://dx.doi.org/10.18653/v1/P19-1075}, 249 | DOI={10.18653/v1/p19-1075}, 250 | journal={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics}, 251 | publisher={Association for Computational Linguistics}, 252 | author={Zheng, Chujie and Huang, Minlie and Sun, Aixin}, 253 | year={2019} 254 | }""" 255 | ), 256 | ), 257 | ClueConfig( 258 | name="c3", 259 | description=textwrap.dedent( 260 | """\ 261 | Multiple-Choice Chinese Machine Reading Comprehension (C3, or C^3) is a Chinese 262 | multi-choice reading comprehension data set, including mixed type data sets 263 | such as dialogue and long text. Both the training and validation sets are 264 | the concatenation of the dialogue and long-text subsets. 
265 | """ 266 | ), 267 | text_features={"candidates": "candidates", "content": "content"}, 268 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/c3_public.zip", 269 | url="https://arxiv.org/abs/1904.09679", 270 | citation=textwrap.dedent( 271 | """\ 272 | @article{sun2020investigating, 273 | author = {Kai Sun and 274 | Dian Yu and 275 | Dong Yu and 276 | Claire Cardie}, 277 | title = {Investigating Prior Knowledge for Challenging Chinese Machine Reading 278 | Comprehension}, 279 | journal = {Trans. Assoc. Comput. Linguistics}, 280 | volume = {8}, 281 | pages = {141--155}, 282 | year = {2020}, 283 | url = {https://transacl.org/ojs/index.php/tacl/article/view/1882} 284 | }""" 285 | ), 286 | ), 287 | ClueConfig( 288 | name="ocnli", 289 | description=textwrap.dedent( 290 | """\ 291 | OCNLI stands for Original Chinese Natural Language Inference. It is a corpus for 292 | Chinese Natural Language Inference, collected following closely the procedures of MNLI, 293 | but with enhanced strategies aiming for more challenging inference pairs. We want to 294 | emphasize we did not use human/machine translation in creating the dataset, and thus 295 | our Chinese texts are original and not translated. 296 | """ 297 | ), 298 | text_features={"sentence1": "sentence1", "sentence2": "sentence2"}, 299 | label_classes=["neutral", "entailment", "contradiction"], 300 | label_column="label", 301 | data_url="https://github.com/CLUEbenchmark/OCNLI/archive/02d55cb3c7dc984682677b8dd81db6a1e4710720.zip", 302 | data_dir="OCNLI-02d55cb3c7dc984682677b8dd81db6a1e4710720/data/ocnli", 303 | url="https://arxiv.org/abs/2010.05444", 304 | citation=textwrap.dedent( 305 | """\ 306 | @inproceedings{ocnli, 307 | title={OCNLI: Original Chinese Natural Language Inference}, 308 | author={Hai Hu and Kyle Richardson and Liang Xu and Lu Li and Sandra Kuebler and Larry Moss}, 309 | booktitle={Findings of EMNLP}, 310 | year={2020}, 311 | url={https://arxiv.org/abs/2010.05444} 312 | }""" 313 | ), 314 | ), 315 | ClueConfig( 316 | name="diagnostics", 317 | description=textwrap.dedent( 318 | """\ 319 | Diagnostic set, used to evaluate the performance of different models on 9 Chinese language 320 | phenomena summarized by linguists. 321 | Use the model trained on CMNLI to directly predict the result on this diagnostic set. 
322 | """ 323 | ), 324 | text_features={"sentence1": "premise", "sentence2": "hypothesis"}, 325 | label_classes=["neutral", "entailment", "contradiction"], 326 | label_column="label", 327 | data_url="https://storage.googleapis.com/cluebenchmark/tasks/clue_diagnostics_public.zip", 328 | ), 329 | ] 330 | 331 | def _info(self): 332 | if self.config.name in [ 333 | "afqmc", 334 | "tnews", 335 | "iflytek", 336 | "cmnli", 337 | "diagnostics", 338 | "ocnli", 339 | ]: 340 | features = { 341 | text_feature: datasets.Value("string") 342 | for text_feature in self.config.text_features.keys() 343 | } 344 | if self.config.label_classes: 345 | features["label"] = datasets.features.ClassLabel( 346 | names=self.config.label_classes 347 | ) 348 | else: 349 | features["label"] = datasets.Value("float32") 350 | features["idx"] = datasets.Value("int32") 351 | elif self.config.name == "cluewsc2020": 352 | features = { 353 | "idx": datasets.Value("int32"), 354 | "text": datasets.Value("string"), 355 | "label": datasets.features.ClassLabel(names=["true", "false"]), 356 | "target": { 357 | "span1_text": datasets.Value("string"), 358 | "span2_text": datasets.Value("string"), 359 | "span1_index": datasets.Value("int32"), 360 | "span2_index": datasets.Value("int32"), 361 | }, 362 | } 363 | elif self.config.name == "csl": 364 | features = { 365 | "idx": datasets.Value("int32"), 366 | "corpus_id": datasets.Value("int32"), 367 | "abst": datasets.Value("string"), 368 | "label": datasets.features.ClassLabel(names=self.config.label_classes), 369 | "keyword": datasets.Sequence(datasets.Value("string")), 370 | } 371 | elif self.config.name in ["cmrc2018", "drcd"]: 372 | features = { 373 | "id": datasets.Value("string"), 374 | "context": datasets.Value("string"), 375 | "question": datasets.Value("string"), 376 | "answers": datasets.Sequence( 377 | { 378 | "text": datasets.Value("string"), 379 | "answer_start": datasets.Value("int32"), 380 | } 381 | ), 382 | } 383 | elif self.config.name == "chid": 384 | features = { 385 | "idx": datasets.Value("int32"), 386 | "candidates": datasets.Sequence(datasets.Value("string")), 387 | "content": datasets.Sequence(datasets.Value("string")), 388 | "answers": datasets.features.Sequence( 389 | { 390 | "text": datasets.Value("string"), 391 | "candidate_id": datasets.Value("int32"), 392 | } 393 | ), 394 | } 395 | elif self.config.name == "c3": 396 | features = { 397 | "id": datasets.Value("int32"), 398 | "context": datasets.Sequence(datasets.Value("string")), 399 | "question": datasets.Value("string"), 400 | "choice": datasets.Sequence(datasets.Value("string")), 401 | "answer": datasets.Value("string"), 402 | } 403 | else: 404 | raise NotImplementedError( 405 | "This task is not implemented. If you believe" 406 | " this task was recently added to the CLUE benchmark, " 407 | "please open a GitHub issue and we will add it." 
408 | ) 409 | 410 | return datasets.DatasetInfo( 411 | description=_CLUE_DESCRIPTION, 412 | features=datasets.Features(features), 413 | homepage=self.config.url, 414 | citation=self.config.citation + "\n" + _CLUE_CITATION, 415 | ) 416 | 417 | def _split_generators(self, dl_manager): 418 | dl_dir = dl_manager.download_and_extract(self.config.data_url) 419 | data_dir = os.path.join(dl_dir, self.config.data_dir) 420 | 421 | if self.config.name in {"chid", "c3"}: 422 | test_file = "test1.1.json" 423 | elif self.config.name == "diagnostics": 424 | test_file = "diagnostics_test.json" 425 | else: 426 | if self.config.name in {"tnews", "cluewsc2020"}: 427 | test_file = "test1.0.json" 428 | else: 429 | test_file = "test.json" 430 | 431 | test_split = datasets.SplitGenerator( 432 | name=datasets.Split.TEST, 433 | gen_kwargs={ 434 | "data_file": os.path.join(data_dir, test_file), 435 | "split": "test", 436 | }, 437 | ) 438 | 439 | split_list = [test_split] 440 | 441 | if self.config.name != "diagnostics": 442 | train_split = datasets.SplitGenerator( 443 | name=datasets.Split.TRAIN, 444 | gen_kwargs={ 445 | "data_file": os.path.join( 446 | data_dir or "", 447 | "train.json" if self.config.name != "c3" else "d-train.json", 448 | ), 449 | "split": "train", 450 | }, 451 | ) 452 | val_split = datasets.SplitGenerator( 453 | name=datasets.Split.VALIDATION, 454 | gen_kwargs={ 455 | "data_file": os.path.join( 456 | data_dir or "", 457 | "dev.json" if self.config.name != "c3" else "d-dev.json", 458 | ), 459 | "split": "dev", 460 | }, 461 | ) 462 | split_list += [train_split, val_split] 463 | 464 | if self.config.name == "cmrc2018": 465 | split_list.append( 466 | datasets.SplitGenerator( 467 | name=datasets.Split("trial"), 468 | gen_kwargs={ 469 | "data_file": os.path.join(data_dir or "", "trial.json"), 470 | "split": "trial", 471 | }, 472 | ) 473 | ) 474 | 475 | return split_list 476 | 477 | def _generate_examples(self, data_file, split): 478 | process_label = self.config.process_label 479 | label_classes = self.config.label_classes 480 | 481 | if self.config.name == "chid" and split != "test": 482 | answer_file = os.path.join( 483 | os.path.dirname(data_file), f"{split}_answer.json" 484 | ) 485 | answer_dict = json.load(open(answer_file, encoding="utf8")) 486 | 487 | if self.config.name == "c3": 488 | if split == "test": 489 | files = [data_file] 490 | else: 491 | data_dir = os.path.dirname(data_file) 492 | files = [ 493 | os.path.join(data_dir, f"{typ}-{split}.json") for typ in ["d", "m"] 494 | ] 495 | data = [] 496 | for f in files: 497 | data_subset = json.load(open(f, encoding="utf8")) 498 | data += data_subset 499 | for idx, entry in enumerate(data): 500 | for qidx, question in enumerate(entry[1]): 501 | example = { 502 | "id": idx if split != "test" else int(question["id"]), 503 | "context": entry[0], 504 | "question": question["question"], 505 | "choice": question["choice"], 506 | "answer": question["answer"] if split != "test" else "", 507 | } 508 | yield f"{idx}_{qidx}", example 509 | 510 | else: 511 | with open(data_file, encoding="utf8") as f: 512 | if self.config.name in ["cmrc2018", "drcd"]: 513 | data = json.load(f) 514 | for example in data["data"]: 515 | for paragraph in example["paragraphs"]: 516 | context = paragraph["context"].strip() 517 | for qa in paragraph["qas"]: 518 | question = qa["question"].strip() 519 | id_ = qa["id"] 520 | 521 | answer_starts = [ 522 | answer["answer_start"] for answer in qa["answers"] 523 | ] 524 | answers = [ 525 | answer["text"].strip() for answer in 
qa["answers"] 526 | ] 527 | 528 | yield id_, { 529 | "context": context, 530 | "question": question, 531 | "id": id_, 532 | "answers": { 533 | "answer_start": answer_starts, 534 | "text": answers, 535 | }, 536 | } 537 | 538 | else: 539 | for n, line in enumerate(f): 540 | row = json.loads(line) 541 | example = { 542 | feat: row[col] 543 | for feat, col in self.config.text_features.items() 544 | } 545 | example["idx"] = ( 546 | n 547 | if self.config.name != "diagnostics" 548 | else int(row["index"]) 549 | ) 550 | 551 | if ( 552 | self.config.name == "chid" 553 | ): # CHID has a separate gold label file 554 | contents = example["content"] 555 | candidates = example["candidates"] 556 | idiom_list = [] 557 | if split != "test": 558 | for content in contents: 559 | idioms = re.findall(r"#idiom\d+#", content) 560 | for idiom in idioms: 561 | idiom_list.append( 562 | { 563 | "candidate_id": answer_dict[idiom], 564 | "text": candidates[answer_dict[idiom]], 565 | } 566 | ) 567 | example["answers"] = idiom_list 568 | 569 | elif self.config.label_column in row: 570 | label = row[self.config.label_column] 571 | # Notice: some labels in CMNLI and OCNLI are invalid. We drop these data. 572 | if self.config.name in ["cmnli", "ocnli"] and label == "-": 573 | continue 574 | # For some tasks, the label is represented as 0 and 1 in the tsv 575 | # files and needs to be cast to integer to work with the feature. 576 | if label_classes and label not in label_classes: 577 | label = int(label) if label else None 578 | example["label"] = process_label(label) 579 | else: 580 | example["label"] = process_label(-1) 581 | 582 | if self.config.name == "cluewsc2020": 583 | text = example["text"] 584 | s1 = example["target"]["span1_index"] 585 | e1 = s1 + len(example["target"]["span1_text"]) 586 | s2 = example["target"]["span2_index"] 587 | e2 = s2 + len(example["target"]["span2_text"]) 588 | if s1 < s2: 589 | text = ( 590 | text[:s1] 591 | + "_" 592 | + text[s1:e1] 593 | + "_" 594 | + text[e1:s2] 595 | + "[" 596 | + text[s2:e2] 597 | + "]" 598 | + text[e2:] 599 | ) 600 | else: 601 | text = ( 602 | text[:s2] 603 | + "[" 604 | + text[s2:e2] 605 | + "]" 606 | + text[e2:s1] 607 | + "_" 608 | + text[s1:e1] 609 | + "_" 610 | + text[e1:] 611 | ) 612 | example["text"] = text 613 | 614 | # Filter out corrupted rows. 615 | for value in example.values(): 616 | if value is None: 617 | break 618 | else: 619 | yield example["idx"], example 620 | -------------------------------------------------------------------------------- /examples/clue/classification/run_clue_no_trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ Finetuning a 🤗 Transformers model for sequence classification on CLUE.""" 16 | import argparse 17 | import json 18 | import logging 19 | import math 20 | import os 21 | import random 22 | from pathlib import Path 23 | 24 | import datasets 25 | import torch 26 | import transformers 27 | from accelerate import Accelerator 28 | from accelerate.utils import set_seed 29 | from datasets import load_dataset, load_metric 30 | from huggingface_hub import Repository 31 | from torch.utils.data import DataLoader 32 | from tqdm.auto import tqdm 33 | from transformers import ( 34 | Adafactor, 35 | AdamW, 36 | AutoConfig, 37 | AutoModelForSequenceClassification, 38 | AutoTokenizer, 39 | DataCollatorWithPadding, 40 | PretrainedConfig, 41 | SchedulerType, 42 | default_data_collator, 43 | get_scheduler, 44 | ) 45 | from transformers.utils import get_full_repo_name 46 | from transformers.utils.versions import require_version 47 | 48 | from roformer import RoFormerConfig, RoFormerForSequenceClassification 49 | 50 | logger = logging.getLogger(__name__) 51 | 52 | require_version( 53 | "datasets>=1.8.0", 54 | "To fix: pip install -r examples/pytorch/text-classification/requirements.txt", 55 | ) 56 | 57 | task_to_keys = { 58 | "iflytek": ("sentence", None), 59 | "tnews": ("sentence", None), 60 | "afqmc": ("sentence1", "sentence2"), 61 | "cmnli": ("sentence1", "sentence2"), 62 | "ocnli": ("sentence1", "sentence2"), 63 | "cluewsc2020": ("text", None), 64 | "csl": ("keyword", "abst"), 65 | } 66 | 67 | 68 | def parse_args(): 69 | parser = argparse.ArgumentParser( 70 | description="Finetune a transformers model on a text classification task" 71 | ) 72 | parser.add_argument( 73 | "--task_name", 74 | type=str, 75 | default=None, 76 | help="The name of the CLUE task to train on.", 77 | choices=list(task_to_keys.keys()), 78 | ) 79 | parser.add_argument( 80 | "--train_file", 81 | type=str, 82 | default=None, 83 | help="A csv or a json file containing the training data.", 84 | ) 85 | parser.add_argument( 86 | "--validation_file", 87 | type=str, 88 | default=None, 89 | help="A csv or a json file containing the validation data.", 90 | ) 91 | parser.add_argument( 92 | "--max_length", 93 | type=int, 94 | default=128, 95 | help=( 96 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 97 | " sequences shorter will be padded if `--pad_to_max_lengh` is passed." 98 | ), 99 | ) 100 | parser.add_argument( 101 | "--pad_to_max_length", 102 | action="store_true", 103 | help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", 104 | ) 105 | parser.add_argument( 106 | "--model_name_or_path", 107 | type=str, 108 | help="Path to pretrained model or model identifier from huggingface.co/models.", 109 | required=True, 110 | ) 111 | parser.add_argument( 112 | "--use_slow_tokenizer", 113 | action="store_true", 114 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 115 | ) 116 | parser.add_argument( 117 | "--per_device_train_batch_size", 118 | type=int, 119 | default=8, 120 | help="Batch size (per device) for the training dataloader.", 121 | ) 122 | parser.add_argument( 123 | "--per_device_eval_batch_size", 124 | type=int, 125 | default=8, 126 | help="Batch size (per device) for the evaluation dataloader.", 127 | ) 128 | parser.add_argument( 129 | "--learning_rate", 130 | type=float, 131 | default=5e-5, 132 | help="Initial learning rate (after the potential warmup period) to use.", 133 | ) 134 | parser.add_argument( 135 | "--weight_decay", type=float, default=0.0, help="Weight decay to use." 136 | ) 137 | parser.add_argument( 138 | "--num_train_epochs", 139 | type=int, 140 | default=3, 141 | help="Total number of training epochs to perform.", 142 | ) 143 | parser.add_argument( 144 | "--max_train_steps", 145 | type=int, 146 | default=None, 147 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 148 | ) 149 | parser.add_argument( 150 | "--gradient_accumulation_steps", 151 | type=int, 152 | default=1, 153 | help="Number of updates steps to accumulate before performing a backward/update pass.", 154 | ) 155 | parser.add_argument( 156 | "--lr_scheduler_type", 157 | type=SchedulerType, 158 | default="linear", 159 | help="The scheduler type to use.", 160 | choices=[ 161 | "linear", 162 | "cosine", 163 | "cosine_with_restarts", 164 | "polynomial", 165 | "constant", 166 | "constant_with_warmup", 167 | ], 168 | ) 169 | parser.add_argument( 170 | "--num_warmup_steps_or_radios", 171 | type=eval, 172 | default=0.1, 173 | help="Number of steps for the warmup in the lr scheduler.", 174 | ) 175 | parser.add_argument( 176 | "--output_dir", type=str, default=None, help="Where to store the final model." 177 | ) 178 | parser.add_argument( 179 | "--seed", type=int, default=None, help="A seed for reproducible training." 180 | ) 181 | parser.add_argument( 182 | "--push_to_hub", 183 | action="store_true", 184 | help="Whether or not to push the model to the Hub.", 185 | ) 186 | parser.add_argument( 187 | "--hub_model_id", 188 | type=str, 189 | help="The name of the repository to keep in sync with the local `output_dir`.", 190 | ) 191 | parser.add_argument( 192 | "--hub_token", type=str, help="The token to use to push to the Model Hub." 193 | ) 194 | parser.add_argument( 195 | "--checkpointing_steps", 196 | type=str, 197 | default=None, 198 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 199 | ) 200 | parser.add_argument( 201 | "--resume_from_checkpoint", 202 | type=str, 203 | default=None, 204 | help="If the training should continue from a checkpoint folder.", 205 | ) 206 | parser.add_argument( 207 | "--with_tracking", 208 | action="store_true", 209 | help="Whether to load in all available experiment trackers from the environment and use them for logging.", 210 | ) 211 | parser.add_argument( 212 | "--max_grad_norm", default=None, type=float, help="Max gradient norm." 
213 | ) 214 | parser.add_argument( 215 | "--logging_steps", 216 | type=int, 217 | default=100, 218 | help="logging_steps.", 219 | ) 220 | parser.add_argument( 221 | "--adam_epsilon", default=1e-8, type=float, help="Epsilon for AdamW optimizer." 222 | ) 223 | args = parser.parse_args() 224 | 225 | # Sanity checks 226 | if ( 227 | args.task_name is None 228 | and args.train_file is None 229 | and args.validation_file is None 230 | ): 231 | raise ValueError("Need either a task name or a training/validation file.") 232 | else: 233 | if args.train_file is not None: 234 | extension = args.train_file.split(".")[-1] 235 | assert extension in [ 236 | "csv", 237 | "json", 238 | ], "`train_file` should be a csv or a json file." 239 | if args.validation_file is not None: 240 | extension = args.validation_file.split(".")[-1] 241 | assert extension in [ 242 | "csv", 243 | "json", 244 | ], "`validation_file` should be a csv or a json file." 245 | 246 | if args.push_to_hub: 247 | assert ( 248 | args.output_dir is not None 249 | ), "Need an `output_dir` to create a repo when `--push_to_hub` is passed." 250 | 251 | return args 252 | 253 | 254 | def main(): 255 | args = parse_args() 256 | 257 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 258 | # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment 259 | accelerator = ( 260 | Accelerator(log_with="tensorboard", logging_dir=args.output_dir) 261 | if args.with_tracking 262 | else Accelerator() 263 | ) 264 | # Make one log on every process with the configuration for debugging. 265 | logging.basicConfig( 266 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 267 | datefmt="%m/%d/%Y %H:%M:%S", 268 | level=logging.INFO, 269 | ) 270 | 271 | # Setup logging, we only want one process per machine to log things on the screen. 272 | # accelerator.is_local_main_process is only True for one process per machine. 273 | logger.setLevel( 274 | logging.INFO if accelerator.is_local_main_process else logging.ERROR 275 | ) 276 | if accelerator.is_local_main_process: 277 | if args.output_dir is not None: 278 | os.makedirs(args.output_dir, exist_ok=True) 279 | logger.addHandler( 280 | logging.FileHandler( 281 | os.path.join(args.output_dir, "training.log"), "w", encoding="utf-8" 282 | ) 283 | ) 284 | datasets.utils.logging.set_verbosity_warning() 285 | transformers.utils.logging.set_verbosity_info() 286 | else: 287 | datasets.utils.logging.set_verbosity_error() 288 | transformers.utils.logging.set_verbosity_error() 289 | 290 | logger.info(accelerator.state) 291 | # If passed along, set the training seed now. 
292 |     if args.seed is not None:
293 |         set_seed(args.seed)
294 | 
295 |     # Handle the repository creation
296 |     if accelerator.is_main_process:
297 |         if args.push_to_hub:
298 |             if args.hub_model_id is None:
299 |                 repo_name = get_full_repo_name(
300 |                     Path(args.output_dir).name, token=args.hub_token
301 |                 )
302 |             else:
303 |                 repo_name = args.hub_model_id
304 |             repo = Repository(args.output_dir, clone_from=repo_name)
305 | 
306 |             with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
307 |                 if "step_*" not in gitignore:
308 |                     gitignore.write("step_*\n")
309 |                 if "epoch_*" not in gitignore:
310 |                     gitignore.write("epoch_*\n")
311 |         elif args.output_dir is not None:
312 |             os.makedirs(args.output_dir, exist_ok=True)
313 |     accelerator.wait_for_everyone()
314 | 
315 |     # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
316 |     # or specify a CLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
317 | 
318 |     # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
319 |     # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
320 |     # label if at least two columns are provided.
321 | 
322 |     # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
323 |     # single column. You can easily tweak this behavior (see below)
324 | 
325 |     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
326 |     # download the dataset.
327 |     if args.task_name is not None:
328 |         # Downloading and loading a dataset from the hub.
329 |         raw_datasets = load_dataset(
330 |             "clue_11.py", args.task_name, cache_dir="./clue_caches"
331 |         )
332 |     else:
333 |         # Loading the dataset from local csv or json file.
334 |         data_files = {}
335 |         if args.train_file is not None:
336 |             data_files["train"] = args.train_file
337 |         if args.validation_file is not None:
338 |             data_files["validation"] = args.validation_file
339 |         extension = (
340 |             args.train_file if args.train_file is not None else args.validation_file
341 |         ).split(".")[-1]
342 |         raw_datasets = load_dataset(extension, data_files=data_files)
343 |     # See more about loading any type of standard or custom dataset at
344 |     # https://huggingface.co/docs/datasets/loading_datasets.html.
345 | 
346 |     # Labels
347 |     if args.task_name is not None:
348 |         label_list = raw_datasets["train"].features["label"].names
349 |         num_labels = len(label_list)
350 | 
351 |     # Load pretrained model and tokenizer
352 |     #
353 |     # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
354 |     # download model & vocab.
355 |     config = RoFormerConfig.from_pretrained(
356 |         args.model_name_or_path,
357 |         num_labels=num_labels,
358 |         finetuning_task=args.task_name,
359 |         summary_type="first",
360 |     )
361 |     tokenizer = AutoTokenizer.from_pretrained(
362 |         args.model_name_or_path, use_fast=not args.use_slow_tokenizer
363 |     )
364 |     model = RoFormerForSequenceClassification.from_pretrained(
365 |         args.model_name_or_path,
366 |         from_tf=bool(".ckpt" in args.model_name_or_path),
367 |         config=config,
368 |     )
369 |     # Preprocessing the datasets
370 |     if args.task_name is not None:
371 |         sentence1_key, sentence2_key = task_to_keys[args.task_name]
372 |     else:
373 |         # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
374 | non_label_column_names = [ 375 | name for name in raw_datasets["train"].column_names if name != "label" 376 | ] 377 | if ( 378 | "sentence1" in non_label_column_names 379 | and "sentence2" in non_label_column_names 380 | ): 381 | sentence1_key, sentence2_key = "sentence1", "sentence2" 382 | else: 383 | if len(non_label_column_names) >= 2: 384 | sentence1_key, sentence2_key = non_label_column_names[:2] 385 | else: 386 | sentence1_key, sentence2_key = non_label_column_names[0], None 387 | 388 | # Some models have set the order of the labels to use, so let's make sure we do use it. 389 | label_to_id = None 390 | if ( 391 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 392 | and args.task_name is not None 393 | ): 394 | # Some have all caps in their config, some don't. 395 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 396 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 397 | logger.info( 398 | f"The configuration of the model provided the following label correspondence: {label_name_to_id}. " 399 | "Using it!" 400 | ) 401 | label_to_id = { 402 | i: label_name_to_id[label_list[i]] for i in range(num_labels) 403 | } 404 | else: 405 | logger.warning( 406 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 407 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 408 | "\nIgnoring the model labels as a result.", 409 | ) 410 | elif args.task_name is None: 411 | label_to_id = {v: i for i, v in enumerate(label_list)} 412 | 413 | if label_to_id is not None: 414 | model.config.label2id = label_to_id 415 | model.config.id2label = {id: label for label, id in config.label2id.items()} 416 | elif args.task_name is not None: 417 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 418 | model.config.id2label = {id: label for label, id in config.label2id.items()} 419 | 420 | padding = "max_length" if args.pad_to_max_length else False 421 | 422 | def preprocess_function(examples): 423 | # Tokenize the texts 424 | if sentence1_key == "keyword": 425 | k1 = [";".join(l) for l in examples[sentence1_key]] 426 | else: 427 | k1 = examples[sentence1_key] 428 | texts = (k1,) if sentence2_key is None else (k1, examples[sentence2_key]) 429 | result = tokenizer( 430 | *texts, 431 | padding=padding, 432 | max_length=args.max_length, 433 | truncation=True, 434 | return_token_type_ids=False, 435 | ) 436 | 437 | if "label" in examples: 438 | if label_to_id is not None: 439 | # Map labels to IDs (not necessary for CLUE tasks) 440 | result["labels"] = [label_to_id[l] for l in examples["label"]] 441 | else: 442 | # In all cases, rename the column to labels because the model will expect that. 
443 | result["labels"] = examples["label"] 444 | return result 445 | 446 | with accelerator.main_process_first(): 447 | processed_datasets = raw_datasets.map( 448 | preprocess_function, 449 | batched=True, 450 | remove_columns=raw_datasets["train"].column_names, 451 | desc="Running tokenizer on dataset", 452 | ) 453 | 454 | train_dataset = processed_datasets["train"] 455 | eval_dataset = processed_datasets["validation"] 456 | 457 | # Log a few random samples from the training set: 458 | for index in random.sample(range(len(train_dataset)), 3): 459 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 460 | 461 | # DataLoaders creation: 462 | if args.pad_to_max_length: 463 | # If padding was already done ot max length, we use the default data collator that will just convert everything 464 | # to tensors. 465 | data_collator = default_data_collator 466 | else: 467 | # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of 468 | # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple 469 | # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). 470 | data_collator = DataCollatorWithPadding( 471 | tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None) 472 | ) 473 | 474 | train_dataloader = DataLoader( 475 | train_dataset, 476 | shuffle=True, 477 | collate_fn=data_collator, 478 | batch_size=args.per_device_train_batch_size, 479 | ) 480 | eval_dataloader = DataLoader( 481 | eval_dataset, 482 | collate_fn=data_collator, 483 | batch_size=args.per_device_eval_batch_size, 484 | ) 485 | 486 | # Optimizer 487 | # Split weights in two groups, one with weight decay and the other not. 488 | no_decay = ["bias", "LayerNorm.weight", "norm"] 489 | optimizer_grouped_parameters = [ 490 | { 491 | "params": [ 492 | p 493 | for n, p in model.named_parameters() 494 | if not any(nd in n for nd in no_decay) 495 | ], 496 | "weight_decay": args.weight_decay, 497 | }, 498 | { 499 | "params": [ 500 | p 501 | for n, p in model.named_parameters() 502 | if any(nd in n for nd in no_decay) 503 | ], 504 | "weight_decay": 0.0, 505 | }, 506 | ] 507 | # optimizer = Adafactor(optimizer_grouped_parameters, 508 | # lr=args.learning_rate, beta1=0.9, relative_step=False) 509 | optimizer = AdamW( 510 | optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon 511 | ) 512 | 513 | # Scheduler and math around the number of training steps. 514 | num_update_steps_per_epoch = math.ceil( 515 | len(train_dataloader) / args.gradient_accumulation_steps 516 | ) 517 | if args.max_train_steps is None: 518 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 519 | else: 520 | args.num_train_epochs = math.ceil( 521 | args.max_train_steps / num_update_steps_per_epoch 522 | ) 523 | 524 | # compute the number of warmup steps 525 | args.num_warmup_steps = ( 526 | math.ceil(args.max_train_steps * args.num_warmup_steps_or_radios) 527 | if isinstance(args.num_warmup_steps_or_radios, float) 528 | else args.num_warmup_steps_or_radios 529 | ) 530 | lr_scheduler = get_scheduler( 531 | name=args.lr_scheduler_type, 532 | optimizer=optimizer, 533 | num_warmup_steps=args.num_warmup_steps, 534 | num_training_steps=args.max_train_steps, 535 | ) 536 | 537 | # Prepare everything with our `accelerator`. 
538 | ( 539 | model, 540 | optimizer, 541 | train_dataloader, 542 | eval_dataloader, 543 | lr_scheduler, 544 | ) = accelerator.prepare( 545 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 546 | ) 547 | 548 | # Figure out how many steps we should save the Accelerator states 549 | if hasattr(args.checkpointing_steps, "isdigit"): 550 | checkpointing_steps = args.checkpointing_steps 551 | if args.checkpointing_steps.isdigit(): 552 | checkpointing_steps = int(args.checkpointing_steps) 553 | else: 554 | checkpointing_steps = None 555 | 556 | # We need to initialize the trackers we use, and also store our configuration 557 | if args.with_tracking: 558 | experiment_config = vars(args) 559 | # TensorBoard cannot log Enums, need the raw value 560 | experiment_config["lr_scheduler_type"] = experiment_config[ 561 | "lr_scheduler_type" 562 | ].value 563 | accelerator.init_trackers("clue_no_trainer", experiment_config) 564 | 565 | # Get the metric function 566 | metric = load_metric("accuracy.py") 567 | all_metric = [] 568 | max_metric = 0.0 569 | # Train! 570 | total_batch_size = ( 571 | args.per_device_train_batch_size 572 | * accelerator.num_processes 573 | * args.gradient_accumulation_steps 574 | ) 575 | 576 | logger.info("***** Running training *****") 577 | logger.info(f" Num examples = {len(train_dataset)}") 578 | logger.info(f" Num Epochs = {args.num_train_epochs}") 579 | logger.info( 580 | f" Instantaneous batch size per device = {args.per_device_train_batch_size}" 581 | ) 582 | logger.info( 583 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" 584 | ) 585 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 586 | logger.info(f" Total optimization steps = {args.max_train_steps}") 587 | # Only show the progress bar once on each machine. 
588 |     progress_bar = tqdm(
589 |         range(args.max_train_steps), disable=not accelerator.is_local_main_process
590 |     )
591 |     completed_steps = 0
592 |     starting_epoch = 0
593 |     # Potentially load in the weights and states from a previous save
594 |     if args.resume_from_checkpoint:
595 |         if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
596 |             accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
597 |             accelerator.load_state(args.resume_from_checkpoint)
598 |             path = os.path.basename(args.resume_from_checkpoint)
599 |         else:
600 |             # Get the most recent checkpoint
601 |             dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
602 |             dirs.sort(key=os.path.getctime)
603 |             # Sorts folders by date modified, most recent checkpoint is the last
604 |             path = dirs[-1]
605 |         # Extract `epoch_{i}` or `step_{i}`
606 |         training_difference = os.path.splitext(path)[0]
607 | 
608 |         if "epoch" in training_difference:
609 |             starting_epoch = int(training_difference.replace("epoch_", "")) + 1
610 |             resume_step = None
611 |         else:
612 |             resume_step = int(training_difference.replace("step_", ""))
613 |             starting_epoch = resume_step // len(train_dataloader)
614 |             resume_step -= starting_epoch * len(train_dataloader)
615 | 
616 |     for epoch in range(starting_epoch, args.num_train_epochs):
617 |         model.train()
618 |         if args.with_tracking:
619 |             total_loss = 0
620 |         for step, batch in enumerate(train_dataloader):
621 |             # We need to skip steps until we reach the resumed step
622 |             if args.resume_from_checkpoint and epoch == starting_epoch:
623 |                 if resume_step is not None and step < resume_step:
624 |                     completed_steps += 1
625 |                     continue
626 |             outputs = model(**batch)
627 |             loss = outputs.loss
628 |             # We keep track of the loss at each epoch
629 |             if args.with_tracking:
630 |                 total_loss += loss.detach().float()
631 |             loss = loss / args.gradient_accumulation_steps
632 |             accelerator.backward(loss)
633 |             if (
634 |                 step % args.gradient_accumulation_steps == 0
635 |                 or step == len(train_dataloader) - 1
636 |             ):
637 |                 if args.max_grad_norm is not None:
638 |                     accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
639 |                 optimizer.step()
640 |                 lr_scheduler.step()
641 |                 optimizer.zero_grad()
642 |                 progress_bar.update(1)
643 |                 completed_steps += 1
644 | 
645 |                 # add logging_steps
646 |                 if args.logging_steps > 0 and completed_steps % args.logging_steps == 0:
647 |                     logger.info(
648 |                         "completed_steps {} - loss: {:.8f}".format(
649 |                             completed_steps,
650 |                             loss.item(),
651 |                         )
652 |                     )
653 |                     accelerator.log({"loss": loss.item()}, step=completed_steps)
654 | 
655 |             if isinstance(checkpointing_steps, int):
656 |                 if completed_steps % checkpointing_steps == 0:
657 |                     output_dir = f"step_{completed_steps}"
658 |                     if args.output_dir is not None:
659 |                         output_dir = os.path.join(args.output_dir, output_dir)
660 |                     accelerator.save_state(output_dir)
661 | 
662 |             if completed_steps >= args.max_train_steps:
663 |                 break
664 | 
665 |         model.eval()
666 |         samples_seen = 0
667 |         with torch.no_grad():
668 |             for step, batch in enumerate(eval_dataloader):
669 |                 outputs = model(**batch)
670 |                 predictions = outputs.logits.argmax(dim=-1)
671 |                 predictions, references = accelerator.gather(
672 |                     (predictions, batch["labels"])
673 |                 )
674 |                 # If we are in a multiprocess environment, the last batch has duplicates
675 |                 if accelerator.num_processes > 1:
676 |                     if step == len(eval_dataloader) - 1:
677 |                         predictions = predictions[
678 |                             : len(eval_dataloader.dataset) - samples_seen
679 |                         ]
680 |                         references =
references[ 681 | : len(eval_dataloader.dataset) - samples_seen 682 | ] 683 | else: 684 | samples_seen += references.shape[0] 685 | metric.add_batch( 686 | predictions=predictions, 687 | references=references, 688 | ) 689 | 690 | eval_metric = metric.compute() 691 | all_metric.append(eval_metric["accuracy"]) 692 | 693 | logger.info(f"epoch {epoch}: {eval_metric}") 694 | 695 | if args.with_tracking: 696 | accelerator.log( 697 | { 698 | "accuracy": eval_metric, 699 | "train_loss": total_loss, 700 | "epoch": epoch, 701 | "step": completed_steps, 702 | }, 703 | ) 704 | 705 | if args.push_to_hub and epoch < args.num_train_epochs - 1: 706 | accelerator.wait_for_everyone() 707 | unwrapped_model = accelerator.unwrap_model(model) 708 | unwrapped_model.save_pretrained( 709 | args.output_dir, 710 | is_main_process=accelerator.is_main_process, 711 | save_function=accelerator.save, 712 | ) 713 | if accelerator.is_main_process: 714 | tokenizer.save_pretrained(args.output_dir) 715 | repo.push_to_hub( 716 | commit_message=f"Training in progress epoch {epoch}", 717 | blocking=False, 718 | auto_lfs_prune=True, 719 | ) 720 | 721 | if args.checkpointing_steps == "epoch": 722 | output_dir = f"epoch_{epoch}" 723 | if args.output_dir is not None: 724 | output_dir = os.path.join(args.output_dir, output_dir) 725 | accelerator.save_state(output_dir) 726 | 727 | if eval_metric["accuracy"] >= max_metric: 728 | max_metric = eval_metric["accuracy"] 729 | output_dir = "epoch_best" 730 | if args.output_dir is not None: 731 | output_dir = os.path.join(args.output_dir, output_dir) 732 | accelerator.wait_for_everyone() 733 | accelerator.unwrap_model(model).save_pretrained( 734 | output_dir, 735 | is_main_process=accelerator.is_main_process, 736 | save_function=accelerator.save, 737 | ) 738 | if accelerator.is_main_process: 739 | tokenizer.save_pretrained(output_dir) 740 | 741 | # if args.output_dir is not None: 742 | # accelerator.wait_for_everyone() 743 | # unwrapped_model = accelerator.unwrap_model(model) 744 | # unwrapped_model.save_pretrained( 745 | # args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save 746 | # ) 747 | # if accelerator.is_main_process: 748 | # tokenizer.save_pretrained(args.output_dir) 749 | # if args.push_to_hub: 750 | # repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) 751 | 752 | if args.output_dir is not None: 753 | with open( 754 | os.path.join(args.output_dir, "all_results.json"), "w", encoding="utf-8" 755 | ) as f: 756 | json.dump( 757 | {"eval_max_accuracy": max(all_metric), "all_metric": all_metric}, f 758 | ) 759 | 760 | 761 | if __name__ == "__main__": 762 | main() 763 | --------------------------------------------------------------------------------
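run_clue_no_trainer.py above saves the checkpoint with the best validation accuracy (model weights plus tokenizer) to the epoch_best subdirectory of --output_dir. A minimal inference sketch for such a checkpoint is shown below; the checkpoint path and the input sentence are placeholders, and token_type_ids is dropped to mirror the return_token_type_ids=False setting used during training.

import torch
from transformers import AutoTokenizer

from roformer import RoFormerForSequenceClassification

checkpoint = "outputs/tnews/epoch_best"  # placeholder: <--output_dir>/epoch_best from a finished run
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = RoFormerForSequenceClassification.from_pretrained(checkpoint)
model.eval()

# Tokenize a single sentence the same way the training script does (max_length=128, truncation).
inputs = tokenizer("这是一条新闻标题", max_length=128, truncation=True, return_tensors="pt")
inputs.pop("token_type_ids", None)  # the training script passed return_token_type_ids=False
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[int(logits.argmax(dim=-1))])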