├── docs ├── images │ ├── pipeline.png │ └── performance.png └── index.html ├── tokenization ├── tokenize_wiki_bert.bash ├── tokenize_wiki103_bert.bash ├── tokenize_wiki103_roberta.bash ├── tokenize_wiki_roberta.bash ├── to_hdf5.py └── tokenize_dataset.py ├── data └── wiki │ ├── tools │ ├── segment_th.py │ ├── tokenize.sh │ └── remove_accent.py │ ├── install-tools.sh │ ├── get_data_cased_untokenized.bash │ └── get_data_cased.bash ├── configs ├── bert-12L-768H.json ├── bert-12L-512H.json ├── bert-6L-512H.json ├── roberta-12L-768H.json ├── roberta-6L-512H.json ├── bert-3L-768H.json ├── bert-6L-768H.json ├── roberta-3L-768H.json ├── roberta-6L-768H.json ├── roberta-12L-512H.json ├── bert_wiki.txt └── roberta_wiki.txt ├── LICENSE ├── .gitignore ├── model.py ├── README.md ├── data.py ├── param.py ├── flop_computation.py ├── run_grow_distributed.py ├── run_lm_distributed.py ├── ligo.py └── bert_model.py /docs/images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/LiGO/HEAD/docs/images/pipeline.png -------------------------------------------------------------------------------- /docs/images/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/LiGO/HEAD/docs/images/performance.png -------------------------------------------------------------------------------- /tokenization/tokenize_wiki_bert.bash: -------------------------------------------------------------------------------- 1 | DATA_DIR=data/wiki-cased 2 | TOKENIZER=bert-base-uncased 3 | python tokenization/tokenize_dataset.py $DATA_DIR en.valid.raw $TOKENIZER 4 | python tokenization/tokenize_dataset.py $DATA_DIR en.test.raw $TOKENIZER 5 | python tokenization/tokenize_dataset.py $DATA_DIR en.train.raw $TOKENIZER 6 | -------------------------------------------------------------------------------- /tokenization/tokenize_wiki103_bert.bash: -------------------------------------------------------------------------------- 1 | DATA_DIR=data/wiki103-cased 2 | TOKENIZER=bert-base-uncased 3 | python tokenization/tokenize_dataset.py $DATA_DIR wiki.valid.raw $TOKENIZER 4 | python tokenization/tokenize_dataset.py $DATA_DIR wiki.test.raw $TOKENIZER 5 | python tokenization/tokenize_dataset.py $DATA_DIR wiki.train.raw $TOKENIZER 6 | -------------------------------------------------------------------------------- /tokenization/tokenize_wiki103_roberta.bash: -------------------------------------------------------------------------------- 1 | DATA_DIR=data/wiki103-cased 2 | TOKENIZER=roberta-base 3 | python tokenization/tokenize_dataset.py $DATA_DIR wiki.valid.raw $TOKENIZER 4 | python tokenization/tokenize_dataset.py $DATA_DIR wiki.test.raw $TOKENIZER 5 | python tokenization/tokenize_dataset.py $DATA_DIR wiki.train.raw $TOKENIZER 6 | -------------------------------------------------------------------------------- /tokenization/tokenize_wiki_roberta.bash: -------------------------------------------------------------------------------- 1 | DATA_DIR=data/wiki-cased-untokenized/ 2 | TOKENIZER=roberta-base 3 | python tokenization/tokenize_dataset.py $DATA_DIR en.valid.raw $TOKENIZER 4 | python tokenization/tokenize_dataset.py $DATA_DIR en.test.raw $TOKENIZER 5 | python tokenization/tokenize_dataset.py $DATA_DIR en.train.raw $TOKENIZER 6 | -------------------------------------------------------------------------------- /data/wiki/tools/segment_th.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import sys 9 | from pythainlp.tokenize import word_tokenize 10 | 11 | for line in sys.stdin.readlines(): 12 | line = line.rstrip('\n') 13 | print(' '.join(word_tokenize(line))) 14 | -------------------------------------------------------------------------------- /configs/bert-12L-768H.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 12, 14 | "type_vocab_size": 2, 15 | "vocab_size": 30522 16 | } 17 | -------------------------------------------------------------------------------- /configs/bert-12L-512H.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 512, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 2048, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 8, 13 | "num_hidden_layers": 12, 14 | "type_vocab_size": 2, 15 | "vocab_size": 30522 16 | } 17 | -------------------------------------------------------------------------------- /configs/bert-6L-512H.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 512, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 2048, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 8, 13 | "num_hidden_layers": 6, 14 | "type_vocab_size": 2, 15 | "vocab_size": 30522, 16 | "tie_word_embeddings": false 17 | } 18 | -------------------------------------------------------------------------------- /configs/roberta-12L-768H.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 12, 14 | "type_vocab_size": 2, 15 | "vocab_size": 50265 16 | } 17 | -------------------------------------------------------------------------------- /configs/roberta-6L-512H.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 512, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 2048, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 8, 13 | "num_hidden_layers": 6, 14 | "type_vocab_size": 2, 15 | "vocab_size": 50265, 16 | "tie_word_embeddings": false 17 | } 18 | 
-------------------------------------------------------------------------------- /configs/bert-3L-768H.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 3, 14 | "type_vocab_size": 2, 15 | "vocab_size": 30522, 16 | "tie_word_embeddings": false 17 | } 18 | -------------------------------------------------------------------------------- /configs/bert-6L-768H.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 6, 14 | "type_vocab_size": 2, 15 | "vocab_size": 30522, 16 | "tie_word_embeddings": false 17 | } 18 | -------------------------------------------------------------------------------- /configs/roberta-3L-768H.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 3, 14 | "type_vocab_size": 2, 15 | "vocab_size": 50265, 16 | "tie_word_embeddings": false 17 | } 18 | -------------------------------------------------------------------------------- /configs/roberta-6L-768H.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 6, 14 | "type_vocab_size": 2, 15 | "vocab_size": 50265, 16 | "tie_word_embeddings": false 17 | } 18 | -------------------------------------------------------------------------------- /configs/roberta-12L-512H.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 512, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 2048, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 8, 13 | "num_hidden_layers": 12, 14 | "type_vocab_size": 2, 15 | "vocab_size": 50265, 16 | "tie_word_embeddings": false 17 | } 18 | -------------------------------------------------------------------------------- /configs/bert_wiki.txt: -------------------------------------------------------------------------------- 1 | tokenizer_name = bert-base-uncased 2 | model_type = bert 3 | block_size = 126 4 | 5 | do_train = True 6 | train_data_file = data/wiki-cased/en.train.raw 7 | do_eval = True 8 | eval_data_file 
= data/wiki-cased/en.valid.raw 9 | col_data = True 10 | split_sent = True 11 | shuffle = True 12 | mlm = True 13 | 14 | per_gpu_train_batch_size = 64 15 | per_gpu_eval_batch_size = 64 16 | gradient_accumulation_steps = 1 17 | max_steps = 400000 18 | learning_rate = 2e-4 19 | weight_decay = 0.1 20 | warmup_steps = 10000 21 | 22 | logging_steps = 1000 23 | ckpt_steps = 1000 24 | 25 | should_continue = True 26 | -------------------------------------------------------------------------------- /configs/roberta_wiki.txt: -------------------------------------------------------------------------------- 1 | tokenizer_name = roberta-base 2 | model_type = roberta 3 | block_size = 126 4 | 5 | do_train = True 6 | train_data_file = data/wiki-cased-untokenized/en.train.raw 7 | do_eval = True 8 | eval_data_file = data/wiki-cased-untokenized/en.valid.raw 9 | col_data = True 10 | split_sent = True 11 | shuffle = True 12 | mlm = True 13 | 14 | per_gpu_train_batch_size = 64 15 | per_gpu_eval_batch_size = 64 16 | gradient_accumulation_steps = 1 17 | max_steps = 400000 18 | learning_rate = 2e-4 19 | weight_decay = 0.1 20 | warmup_steps = 10000 21 | 22 | logging_steps = 1000 23 | ckpt_steps = 1000 24 | 25 | should_continue = True 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 VITA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/wiki/tools/tokenize.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | # Tokenize text data in various languages 9 | # Usage: e.g. 
cat wiki.ar | tokenize.sh ar 10 | 11 | set -e 12 | 13 | N_THREADS=8 14 | 15 | lg=$1 16 | TOOLS_PATH=$2 17 | 18 | # moses 19 | MOSES=$TOOLS_PATH/mosesdecoder 20 | REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl 21 | NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl 22 | REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl 23 | TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl 24 | 25 | # Chinese 26 | if [ "$lg" = "zh" ]; then 27 | $TOOLS_PATH/stanford-segmenter-*/segment.sh pku /dev/stdin UTF-8 0 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR 28 | # Thai 29 | elif [ "$lg" = "th" ]; then 30 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | python $TOOLS_PATH/segment_th.py 31 | # Japanese 32 | elif [ "$lg" = "ja" ]; then 33 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | kytea -notags 34 | # other languages 35 | else 36 | cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape -threads $N_THREADS -l $lg 37 | fi 38 | -------------------------------------------------------------------------------- /data/wiki/tools/remove_accent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import sys 9 | import unicodedata 10 | import six 11 | 12 | 13 | def convert_to_unicode(text): 14 | """ 15 | Converts `text` to Unicode (if it's not already), assuming UTF-8 input. 16 | """ 17 | # six_ensure_text is copied from https://github.com/benjaminp/six 18 | def six_ensure_text(s, encoding='utf-8', errors='strict'): 19 | if isinstance(s, six.binary_type): 20 | return s.decode(encoding, errors) 21 | elif isinstance(s, six.text_type): 22 | return s 23 | else: 24 | raise TypeError("not expecting type '%s'" % type(s)) 25 | 26 | return six_ensure_text(text, encoding="utf-8", errors="ignore") 27 | 28 | 29 | def run_strip_accents(text): 30 | """ 31 | Strips accents from a piece of text. 32 | """ 33 | text = unicodedata.normalize("NFD", text) 34 | output = [] 35 | for char in text: 36 | cat = unicodedata.category(char) 37 | if cat == "Mn": 38 | continue 39 | output.append(char) 40 | return "".join(output) 41 | 42 | 43 | for line in sys.stdin: 44 | line = convert_to_unicode(line.rstrip()) 45 | line = run_strip_accents(line) 46 | print(u'%s' % line) 47 | -------------------------------------------------------------------------------- /data/wiki/install-tools.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | set -e 9 | 10 | # data path 11 | TOOLS_PATH=$1 12 | 13 | # tools 14 | MOSES_DIR=mosesdecoder 15 | FASTBPE_DIR=fastBPE 16 | FASTBPE=fast 17 | WMT16_SCRIPTS=wmt16-scripts 18 | 19 | # tools path 20 | mkdir -p $TOOLS_PATH 21 | 22 | # Copy the scripts to TOOLS_PATH 23 | cp -r data/wiki/tools/* $TOOLS_PATH 24 | 25 | 26 | # 27 | # Download and install tools 28 | # 29 | 30 | old=$(pwd) 31 | cd $TOOLS_PATH 32 | 33 | 34 | # Download Moses 35 | if [ ! -d "$MOSES_DIR" ]; then 36 | echo "Cloning Moses from GitHub repository..." 
37 | git clone https://github.com/moses-smt/mosesdecoder.git 38 | fi 39 | 40 | # Download fastBPE 41 | if [ ! -d "$FASTBPE_DIR" ]; then 42 | echo "Cloning fastBPE from GitHub repository..." 43 | git clone https://github.com/glample/fastBPE 44 | fi 45 | 46 | # Compile fastBPE 47 | if [ ! -f "$FASTBPE_DIR/$FASTBPE" ]; then 48 | echo "Compiling fastBPE..." 49 | cd fastBPE 50 | g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast 51 | cd .. 52 | fi 53 | 54 | # Download Sennrich's tools 55 | if [ ! -d "$WMT16_SCRIPTS" ]; then 56 | echo "Cloning WMT16 preprocessing scripts..." 57 | git clone https://github.com/rsennrich/wmt16-scripts.git 58 | fi 59 | 60 | # Download WikiExtractor 61 | if [ ! -d wikiextractor ]; then 62 | echo "Cloning WikiExtractor from GitHub repository..." 63 | git clone https://github.com/attardi/wikiextractor.git 64 | cd wikiextractor 65 | git checkout e4abb4cbd019b0257824ee47c23dd163919b731b 66 | cd .. 67 | fi 68 | 69 | cd $old 70 | 71 | # # Chinese segmenter 72 | # if ! ls $TOOLS_PATH/stanford-segmenter-* 1> /dev/null 2>&1; then 73 | # echo "Stanford segmenter not found at $TOOLS_PATH/stanford-segmenter-*" 74 | # echo "Please install Stanford segmenter in $TOOLS_PATH" 75 | # exit 1 76 | # fi 77 | # 78 | # # Thai tokenizer 79 | # if ! python -c 'import pkgutil; exit(not pkgutil.find_loader("pythainlp"))'; then 80 | # echo "pythainlp package not found in python" 81 | # echo "Please install pythainlp (pip install pythainlp)" 82 | # exit 1 83 | # fi 84 | # 85 | -------------------------------------------------------------------------------- /data/wiki/get_data_cased_untokenized.bash: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # Copy frrom https://github.com/facebookresearch/XLM 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | # 10 | # Usage: ./get-data-wiki.sh $lg (en) 11 | # 12 | 13 | set -e 14 | 15 | lg=$1 # input language 16 | 17 | # data path 18 | WIKI_PATH=data/wiki-cased-untokenized 19 | MAIN_PATH=$WIKI_PATH 20 | 21 | # tools paths 22 | TOOLS_PATH=$MAIN_PATH/tools 23 | TOKENIZE=$TOOLS_PATH/tokenize.sh 24 | REMOVE_ACCENT=$TOOLS_PATH/remove_accent.py 25 | 26 | # Wiki data 27 | WIKI_DUMP_NAME=${lg}wiki-latest-pages-articles.xml.bz2 28 | WIKI_DUMP_LINK=https://dumps.wikimedia.org/${lg}wiki/latest/$WIKI_DUMP_NAME 29 | 30 | # install tools 31 | data/wiki/install-tools.sh $TOOLS_PATH 32 | 33 | # create Wiki paths 34 | mkdir -p $WIKI_PATH/bz2 35 | mkdir -p $WIKI_PATH/txt 36 | 37 | # download Wikipedia dump 38 | if [ ! -f $WIKI_PATH/bz2/enwiki-latest-pages-articles.xml.bz2 ]; then 39 | echo "Downloading $lg Wikipedia dump from $WIKI_DUMP_LINK ..." 40 | wget -c $WIKI_DUMP_LINK -P $WIKI_PATH/bz2/ 41 | echo "Downloaded $WIKI_DUMP_NAME in $WIKI_PATH/bz2/$WIKI_DUMP_NAME" 42 | fi 43 | 44 | # extract and tokenize Wiki data 45 | #cd $MAIN_PATH 46 | echo "*** Cleaning and tokenizing $lg Wikipedia dump ... ***" 47 | if [ ! 
-f $WIKI_PATH/txt/$lg.all.raw ]; then 48 | python $TOOLS_PATH/wikiextractor/WikiExtractor.py $WIKI_PATH/bz2/$WIKI_DUMP_NAME --processes 24 -q -o - \ 49 | | sed "/^\s*\$/d" \ 50 | | grep -v "^\$" \ 52 | | python $REMOVE_ACCENT \ 53 | > $WIKI_PATH/txt/$lg.all.raw 54 | fi 55 | echo "*** Not Tokenized ( but + accent-removal) $lg Wikipedia dump to $WIKI_PATH/txt/train.${lg} ***" 56 | 57 | # split into train / valid / test 58 | echo "*** Split into train / valid / test ***" 59 | split_data() { 60 | NLINES=`wc -l $1 | awk -F " " '{print $1}'`; 61 | NTRAIN=$((NLINES - 10000)); 62 | NVAL=$((NTRAIN + 5000)); 63 | cat $1 | head -$NTRAIN > $2; 64 | cat $1 | head -$NVAL | tail -5000 > $3; 65 | cat $1 | tail -5000 > $4; 66 | } 67 | split_data $WIKI_PATH/txt/$lg.all.raw $WIKI_PATH/txt/$lg.train.raw $WIKI_PATH/txt/$lg.valid.raw $WIKI_PATH/txt/$lg.test.raw 68 | 69 | # File structure 70 | mv $WIKI_PATH/txt/* $WIKI_PATH/ 71 | rm -rf $WIKI_PATH/bz2 72 | rm -rf $WIKI_PATH/txt 73 | -------------------------------------------------------------------------------- /data/wiki/get_data_cased.bash: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # Copy frrom https://github.com/facebookresearch/XLM 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | # 10 | # Usage: ./get-data-wiki.sh $lg (en) 11 | # 12 | 13 | set -e 14 | 15 | lg=$1 # input language 16 | 17 | # data path 18 | WIKI_PATH=data/wiki-cased 19 | MAIN_PATH=$WIKI_PATH 20 | 21 | # tools paths 22 | TOOLS_PATH=$MAIN_PATH/tools 23 | TOKENIZE=$TOOLS_PATH/tokenize.sh 24 | REMOVE_ACCENT=$TOOLS_PATH/remove_accent.py 25 | 26 | # Wiki data 27 | WIKI_DUMP_NAME=${lg}wiki-latest-pages-articles.xml.bz2 28 | WIKI_DUMP_LINK=https://dumps.wikimedia.org/${lg}wiki/latest/$WIKI_DUMP_NAME 29 | 30 | # install tools 31 | data/wiki/install-tools.sh $TOOLS_PATH 32 | 33 | # create Wiki paths 34 | mkdir -p $WIKI_PATH/bz2 35 | mkdir -p $WIKI_PATH/txt 36 | 37 | # download Wikipedia dump 38 | echo "Downloading $lg Wikipedia dump from $WIKI_DUMP_LINK ..." 39 | wget -c $WIKI_DUMP_LINK -P $WIKI_PATH/bz2/ 40 | echo "Downloaded $WIKI_DUMP_NAME in $WIKI_PATH/bz2/$WIKI_DUMP_NAME" 41 | 42 | # extract and tokenize Wiki data 43 | echo "*** Cleaning and tokenizing $lg Wikipedia dump ... ***" 44 | #python -m $TOOLS_PATH/wikiextractor/wikiextractor/WikiExtractor $WIKI_PATH/bz2/$WIKI_DUMP_NAME --processes 24 -q -o - \ 45 | if [ ! 
-f $WIKI_PATH/txt/$lg.all.raw ]; then 46 | python $TOOLS_PATH/wikiextractor/WikiExtractor.py $WIKI_PATH/bz2/$WIKI_DUMP_NAME --processes 24 -q -o - \ 47 | | sed "/^\s*\$/d" \ 48 | | grep -v "^\$" \ 50 | | $TOKENIZE $lg $TOOLS_PATH \ 51 | | python $REMOVE_ACCENT \ 52 | > $WIKI_PATH/txt/$lg.all.raw 53 | fi 54 | echo "*** Tokenized ( + accent-removal) $lg Wikipedia dump to $WIKI_PATH/txt/train.${lg} ***" 55 | 56 | # split into train / valid / test 57 | echo "*** Split into train / valid / test ***" 58 | split_data() { 59 | NLINES=`wc -l $1 | awk -F " " '{print $1}'`; 60 | NTRAIN=$((NLINES - 10000)); 61 | NVAL=$((NTRAIN + 5000)); 62 | cat $1 | head -$NTRAIN > $2; 63 | cat $1 | head -$NVAL | tail -5000 > $3; 64 | cat $1 | tail -5000 > $4; 65 | } 66 | split_data $WIKI_PATH/txt/$lg.all.raw $WIKI_PATH/txt/$lg.train.raw $WIKI_PATH/txt/$lg.valid.raw $WIKI_PATH/txt/$lg.test.raw 67 | 68 | # File structure 69 | mv $WIKI_PATH/txt/* $WIKI_PATH/ 70 | rm -rf $WIKI_PATH/bz2 71 | rm -rf $WIKI_PATH/txt 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /tokenization/to_hdf5.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import tqdm 4 | 5 | from transformers import AutoTokenizer 6 | 7 | 8 | def validate_hdf5(fname, tokenizer_name): 9 | print("--------------------------------------------") 10 | print("Start to valid the hdf5 file", fname + '.' + tokenizer_name + '.hdf5') 11 | 12 | with open(fname) as f: 13 | lines = [] 14 | for line in f: 15 | if 'wiki' in fname: 16 | # Wiki103: remove document title 17 | if line.startswith(' = '): 18 | continue 19 | # Full Wiki: Remove the too short lines. 20 | if len(line.strip().split(' ')) < 5: 21 | continue 22 | 23 | if len(line.strip()) == 0: 24 | # Always drop empty line 25 | continue 26 | lines.append(line) 27 | 28 | # Use the slow tokenizer to validate the results of the fast tokenizer. 29 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 30 | 31 | h5_file = h5py.File(fname + '.' + tokenizer_name + '.hdf5', 'r') 32 | tokens = h5_file['tokens'] 33 | 34 | print("Start to check the first 10 lines:") 35 | ids = [] 36 | for line in lines[:10]: 37 | ids.extend(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))) 38 | ids = np.array(ids) 39 | first_tokens = np.array(tokens[:len(ids)]) 40 | if np.array_equal(ids, first_tokens): 41 | print("PASS") 42 | else: 43 | print(' '.join(tokenizer.convert_ids_to_tokens(ids))) 44 | print() 45 | print(' '.join(tokenizer.convert_ids_to_tokens(first_tokens))) 46 | assert False, "FAIL" 47 | 48 | print("Start to check the last 10 lines:") 49 | ids = [] 50 | for line in lines[-10:]: 51 | ids.extend(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))) 52 | ids = np.array(ids) 53 | last_tokens = np.array(tokens[-len(ids):]) 54 | if np.array_equal(ids, last_tokens): 55 | print("PASS") 56 | else: 57 | print(' '.join(tokenizer.convert_ids_to_tokens(ids))) 58 | print(' '.join(tokenizer.convert_ids_to_tokens(last_tokens))) 59 | assert False, "FAIL" 60 | print("--------------------------------------------") 61 | 62 | 63 | def to_hdf5(fname, tokenizer_name, validate=True): 64 | print("Process %s" % fname) 65 | 66 | h5_file = h5py.File(fname + '.' 
+ tokenizer_name + '.hdf5', 'w') 67 | dset = h5_file.create_dataset("tokens", 68 | (0,), 69 | maxshape=(None,), 70 | dtype='int32') 71 | 72 | dump_interval = 1000000 73 | dump_iter = 0 74 | with open('%s.%s' % (fname, tokenizer_name)) as f: 75 | lines = 0 76 | tokens = [] 77 | for line in tqdm.tqdm(f): 78 | for token in map(int, line.split(' ')): 79 | tokens.append(token) 80 | if len(tokens) >= dump_interval: 81 | dset.resize((dump_iter + len(tokens),)) 82 | dset[dump_iter: dump_iter + len(tokens)] = tokens 83 | dump_iter += len(tokens) 84 | tokens = [] 85 | lines += 1 86 | 87 | dset.resize((dump_iter + len(tokens),)) 88 | dset[dump_iter: dump_iter + len(tokens)] = tokens 89 | dump_iter += len(tokens) 90 | 91 | assert len(dset) == dump_iter 92 | h5_file.close() 93 | 94 | if validate: 95 | validate_hdf5(fname, tokenizer_name) 96 | 97 | print() 98 | 99 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn import CrossEntropyLoss, MSELoss, SmoothL1Loss 6 | from torch import nn 7 | from transformers import ( 8 | AutoConfig, 9 | BertConfig 10 | ) 11 | 12 | # from transformers.modeling_bert import BertOnlyMLMHead 13 | # from transformers.models.bert.modeling_bert import BertOnlyMLMHead 14 | from transformers.models.roberta.modeling_roberta import RobertaForMaskedLM, RobertaLMHead 15 | from bert_model import BertForMaskedLM, BertOnlyMLMHead 16 | 17 | BertLayerNorm = torch.nn.LayerNorm 18 | 19 | 20 | # The GLUE function is copied from huggingface transformers: 21 | # https://github.com/huggingface/transformers/blob/c6acd246ec90857b70f449dcbcb1543f150821fc/src/transformers/activations.py 22 | def _gelu_python(x): 23 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created. 
24 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 25 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 26 | Also see https://arxiv.org/abs/1606.08415 27 | """ 28 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 29 | 30 | 31 | if torch.__version__ < "1.4.0": 32 | gelu = _gelu_python 33 | else: 34 | gelu = F.gelu 35 | 36 | 37 | class SimpleBertForMaskedLM(BertForMaskedLM): 38 | 39 | def __init__(self, config, args=None): 40 | super().__init__(config, args=args) 41 | 42 | def forward( 43 | self, 44 | input_ids=None, 45 | attention_mask=None, 46 | token_type_ids=None, 47 | position_ids=None, 48 | head_mask=None, 49 | inputs_embeds=None, 50 | masked_lm_labels=None, 51 | encoder_hidden_states=None, 52 | encoder_attention_mask=None, 53 | lm_labels=None, 54 | **kwargs 55 | ): 56 | outputs = self.bert( 57 | input_ids, 58 | attention_mask=attention_mask, 59 | token_type_ids=token_type_ids, 60 | position_ids=position_ids, 61 | head_mask=head_mask, 62 | inputs_embeds=inputs_embeds, 63 | encoder_hidden_states=encoder_hidden_states, 64 | encoder_attention_mask=encoder_attention_mask, 65 | ) 66 | sequence_output = outputs[0] 67 | 68 | prediction_scores = self.cls(sequence_output) 69 | loss_fct = CrossEntropyLoss() 70 | token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 71 | 72 | return {'loss': token_loss, 'lm_loss': token_loss} 73 | 74 | class SimpleRobertaForMaskedLM(RobertaForMaskedLM): 75 | 76 | def __init__(self, config, args=None): 77 | super().__init__(config) 78 | 79 | def forward( 80 | self, 81 | input_ids=None, 82 | attention_mask=None, 83 | token_type_ids=None, 84 | position_ids=None, 85 | head_mask=None, 86 | inputs_embeds=None, 87 | masked_lm_labels=None, 88 | encoder_hidden_states=None, 89 | encoder_attention_mask=None, 90 | lm_labels=None, 91 | **kwargs 92 | ): 93 | outputs = self.roberta( 94 | input_ids, 95 | attention_mask=attention_mask, 96 | token_type_ids=token_type_ids, 97 | position_ids=position_ids, 98 | head_mask=head_mask, 99 | inputs_embeds=inputs_embeds, 100 | encoder_hidden_states=encoder_hidden_states, 101 | encoder_attention_mask=encoder_attention_mask, 102 | ) 103 | sequence_output = outputs[0] 104 | 105 | prediction_scores = self.lm_head(sequence_output) 106 | loss_fct = CrossEntropyLoss() 107 | token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 108 | 109 | return {'loss': token_loss, 'lm_loss': token_loss} 110 | -------------------------------------------------------------------------------- /tokenization/tokenize_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyleft 2020 project COL. 
3 | 4 | import argparse 5 | from pathlib import Path 6 | 7 | from transformers import AutoTokenizer 8 | import time 9 | 10 | from to_hdf5 import to_hdf5 11 | 12 | def tokenize_dataset(data_dir, fname, tokenizer_name, lines_are_sents=False): 13 | data_path = Path(data_dir) 14 | 15 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True) 16 | 17 | f = open(data_path / fname) 18 | g = open((data_path / ('%s.%s' % (fname, tokenizer_name))), 'w') 19 | 20 | # Statistics 21 | dcmt_cnt = 0 22 | token_cnt = 0 23 | line_cnt = 0 24 | line_starts = [] 25 | 26 | # Logging and dumping hyper-parameters 27 | cache = '' 28 | log_interval = log_iter = 1000000 29 | dump_interval = dump_iter = 100000 30 | start_time = time.time() 31 | 32 | for i, line in enumerate(f): 33 | # Identify the start of documents, ignore it. 34 | if 'wiki103' in data_dir: 35 | if line.startswith(' = '): 36 | dcmt_cnt += 1 37 | continue 38 | elif 'wiki' in data_dir: 39 | if len(line.strip().split(' ')) == 1: 40 | dcmt_cnt += 1 41 | continue 42 | 43 | if 'wiki' in data_dir: 44 | # Remove too short lines. Book corpus does not need this. 45 | if len(line.strip().split(' ')) < 5: 46 | continue 47 | 48 | # Drop empty line (1) 49 | if len(line.strip()) == 0: 50 | continue 51 | 52 | tokenized_line = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line)) 53 | # tokenized_line = tokenizer.encode(line, add_special_tokens=False) 54 | if len(tokenized_line) == 0: # Drop empty line (2) 55 | continue 56 | 57 | line_cnt += 1 58 | line_starts.append(token_cnt) 59 | if i < 5: 60 | print() 61 | print('Line:', line) 62 | print('Tokens:', ' '.join(tokenizer.convert_ids_to_tokens(tokenized_line))) 63 | token_cnt += len(tokenized_line) 64 | cache += ' '.join(map(str, tokenized_line)) + '\n' 65 | 66 | if (token_cnt + 1) > dump_iter: 67 | g.write(cache) 68 | cache = '' 69 | dump_iter += dump_interval 70 | 71 | if (token_cnt + 1) > log_iter: 72 | used_time = time.time() - start_time 73 | print("Process %d tokens in %d seconds, %0.4f tokens per second." % ( 74 | token_cnt, used_time, token_cnt / used_time)) 75 | log_iter += log_interval 76 | 77 | # Deal with the last remaining tokens. 78 | line_starts.append(token_cnt) 79 | g.write(cache) 80 | 81 | # Dump Line starts 82 | identifier = 'sent' if lines_are_sents else 'line' 83 | with open(data_path / ('%s.%s.%s' % (fname, tokenizer_name, identifier)), 'w') as f: 84 | for line_start in line_starts: 85 | f.write(str(line_start) + "\n") 86 | 87 | f.close() 88 | g.close() 89 | print(f"Documents: {dcmt_cnt}, Lines: {line_cnt}, Words: {token_cnt} in dataset {fname}") 90 | 91 | to_hdf5(str(data_path / fname), tokenizer_name) 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser() 96 | 97 | # Required parameters 98 | parser.add_argument( 99 | "datadir", default=None, type=str, help="The input training data file (a text file)." 100 | ) 101 | parser.add_argument( 102 | "fname", default=None, type=str, help="The input training data file (a text file)." 103 | ) 104 | parser.add_argument( 105 | "tokenizer_name", default=None, type=str, help="The input training data file (a text file)." 106 | ) 107 | parser.add_argument( 108 | "--lines-are-sents", action='store_true', 109 | help="Add this if the line are already segmented to sentences, instead of paragraphs." 
110 | ) 111 | 112 | param = parser.parse_args() 113 | 114 | tokenize_dataset( 115 | param.datadir, 116 | param.fname, 117 | param.tokenizer_name, 118 | param.lines_are_sents, 119 | ) 120 | 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning to Grow Pretrained Models for Efficient Transformer Training 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 4 | 5 | The official implementation of ICLR 2023 paper [Learning to Grow Pretrained Models for Efficient Transformer Training](https://arxiv.org/abs/2303.00980). 6 | 7 | [Peihao Wang](https://peihaowang.github.io/)1, 8 | [Rameswar Panda](https://rpand002.github.io/)2, 9 | [Lucas Torroba Hennigen](https://ltorroba.github.io/)4, 10 | [Philip Greengard](http://www.columbia.edu/~pg2118/)3, 11 | [Leonid Karlinsky](https://scholar.google.com/citations?user=WbO7tjYAAAAJ&hl=en)2, 12 | [Rogerio Feris](http://rogerioferis.com/)2, 13 | [David Cox](https://mitibmwatsonailab.mit.edu/people/david-cox/)2, 14 | [Zhangyang (Atlas) Wang](https://vita-group.github.io/)1, 15 | [Yoon Kim](https://people.csail.mit.edu/yoonkim/)4 16 | 17 | 1University of Texas at Austin, 2MIT-IBM Watson Lab, 3Columbia University, 4MIT 18 | 19 | ![](./docs/images/pipeline.png) 20 | 21 | ## Getting Started 22 | 23 | ### Dependency 24 | 25 | To run our code, the following libraries are required: 26 | 27 | ``` 28 | torch 29 | torchvision 30 | transformers==4.21.0 31 | tensorboardX 32 | 33 | # For GLUE evaluation 34 | sklearn 35 | 36 | # Faiss supports fast indexing. 37 | # The code has a torch-implemented GPU indexing, so do not worry if you cannot install faiss. 38 | faiss-gpu>=1.6.3 39 | 40 | # Spacy is used for sentence segmentation, where the sentences are the input to the cross-modality matching model. 41 | spacy 42 | 43 | # A higher h5py version to support h5py.VirtualLayout 44 | h5py>=2.10.0 45 | ``` 46 | 47 | ### Data Preparation 48 | 49 | We re-use the data preparation pipeline provided by [Vokenization](https://github.com/airsplay/vokenization#vokenization-vokenization). 50 | 51 | **1. Download and Pre-Process Pure-Language Data** 52 | 53 | We provide scripts to get the English Wikipedia dataset. 54 | 55 | The scripts to download and process Wiki data are modified from [XLM](https://github.com/facebookresearch/XLM). Note that the data processing pipelines for BERT and RoBERTa are different as they use different tokenizers. 56 | 57 | To get data for training BERT, use the following command: 58 | ``` 59 | bash data/wiki/get_data_cased.bash en 60 | ``` 61 | 62 | RoBERTa requires an untokenized version of English Wikipedia, so please use the following command: 63 | ``` 64 | bash data/wiki/get_data_cased_untokenized.bash en 65 | ``` 66 | 67 | **2. Tokenize Language Data** 68 | 69 | We next tokenize the language corpus. This will locally save three files: `<fname>.<tokenizer_name>`, `<fname>.<tokenizer_name>.hdf5`, and `<fname>.<tokenizer_name>.line`.
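Each `.hdf5` file stores the whole split as a single flat `tokens` array of int32 token ids (see `tokenization/to_hdf5.py`), and the `.line` file records the token offset at which each line starts. After running the commands below, the output can be sanity-checked with a few lines of Python. This is only a minimal sketch (not part of the training pipeline), and the path assumes the BERT pipeline on the cased Wiki data:

```
import h5py
from transformers import AutoTokenizer

# Example path: produced by tokenize_wiki_bert.bash; adjust to your dataset/tokenizer.
path = "data/wiki-cased/en.valid.raw.bert-base-uncased.hdf5"
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

with h5py.File(path, "r") as f:
    tokens = f["tokens"]  # one flat int32 array of token ids for the whole split
    print("total tokens:", len(tokens))
    # Decode the first few ids back to word pieces as a quick sanity check.
    print(" ".join(tokenizer.convert_ids_to_tokens([int(i) for i in tokens[:32]])))
```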
70 | Use the corresponding command below to tokenize the dataset: 71 | 72 | ``` 73 | # For BERT 74 | bash tokenization/tokenize_wiki_bert.bash 75 | 76 | # For RoBERTa 77 | bash tokenization/tokenize_wiki_roberta.bash 78 | ``` 79 | 80 | ## Usage 81 | 82 | Please use the following commands to pretrain BERT or RoBERTa: 83 | 84 | ### Training BERT from Scratch 85 | 86 | (6L, 512H) BERT 87 | 88 | ``` 89 | python run_lm_distributed.py --config configs/bert_wiki.txt --config_name configs/bert-6L-512H.json --output_dir <output_dir> --max_steps 400000 --warmup_steps 10000 --should_continue 90 | ``` 91 | 92 | (12L, 768H) BERT 93 | 94 | ``` 95 | python run_lm_distributed.py --config configs/bert_wiki.txt --config_name configs/bert-12L-768H.json --output_dir <output_dir> --max_steps 400000 --warmup_steps 10000 --should_continue 96 | ``` 97 | 98 | ### Training BERT with LiGO 99 | 100 | First train a LiGO operator using the following command: 101 | 102 | ``` 103 | python run_grow_distributed.py --config configs/bert_wiki.txt --config_name configs/bert-12L-768H.json --output_dir <output_dir> --tune_width --tune_depth --source_model_path <path_to_source_model> --fuse_init_scheme stackbert_noisy rand --max_steps 100 --logging_steps 100 --ckpt_steps 100 --should_continue 104 | ``` 105 | 106 | Then use the pre-trained LiGO operator to grow the model: 107 | 108 | ``` 109 | python run_lm_distributed.py --config configs/bert_wiki.txt --config_name configs/bert-12L-768H.json --output_dir <output_dir> --grow_scheme ligo --source_model_path <path_to_source_model> --pretrained_ligo_path <path_to_ligo_ckpt> --fuse_init_scheme stackbert_noisy rand --learning_rate 2e-4 --warmup_steps 0 --should_continue 110 | ``` 111 | 112 | ### Training RoBERTa from Scratch 113 | 114 | (6L, 512H) RoBERTa 115 | 116 | ``` 117 | python run_lm_distributed.py --config configs/roberta_wiki.txt --config_name configs/roberta-6L-512H.json --per_gpu_train_batch_size 64 --gradient_accumulation_steps 4 --learning_rate 2e-4 --output_dir <output_dir> --should_continue 118 | ``` 119 | 120 | (12L, 768H) RoBERTa 121 | 122 | ``` 123 | python run_lm_distributed.py --config configs/roberta_wiki.txt --config_name configs/roberta-12L-768H.json --per_gpu_train_batch_size 64 --gradient_accumulation_steps 4 --learning_rate 2e-4 --output_dir <output_dir> --should_continue 124 | ``` 125 | 126 | Note that the argument `--gradient_accumulation_steps 4` is necessary to guarantee that the batch size of RoBERTa is 4 times that of BERT. Alternatively, one can use 4 times the number of GPUs to achieve the same batch size. 127 | 128 | ### Training RoBERTa with LiGO 129 | 130 | ``` 131 | # Train LiGO 132 | python run_grow_distributed.py --config configs/roberta_wiki.txt --config_name configs/roberta-12L-768H.json --per_gpu_train_batch_size 64 --gradient_accumulation_steps 4 --learning_rate 2e-4 --output_dir <output_dir> --tune_width --tune_depth --source_model_path <path_to_source_model> --fuse_init_scheme stackbert_noisy rand --max_steps 100 --logging_steps 100 --ckpt_steps 100 --should_continue 133 | 134 | # Apply the pre-trained LiGO operator to grow the model 135 | python run_lm_distributed.py --config configs/roberta_wiki.txt --config_name configs/roberta-12L-768H.json --per_gpu_train_batch_size 64 --gradient_accumulation_steps 4 --output_dir <output_dir> --grow_scheme ligo --source_model_path <path_to_source_model> --pretrained_ligo_path <path_to_ligo_ckpt> --fuse_init_scheme stackbert_noisy rand --learning_rate 2e-4 --warmup_steps 10000 --should_continue 136 | ``` 137 | 138 | ## Citation 139 | 140 | This repository is based on the project [Vokenization](https://github.com/airsplay/vokenization#vokenization-vokenization). 141 | If you find this repository or our work helpful for your own research, please cite our paper.
142 | 143 | ``` 144 | @inproceedings{wang2023learning, 145 | title={Learning to grow pretrained models for efficient transformer training}, 146 | author={Wang, Peihao and Panda, Rameswar and Hennigen, Lucas Torroba and Greengard, Philip and Karlinsky, Leonid and Feris, Rogerio and Cox, David Daniel and Wang, Zhangyang and Kim, Yoon}, 147 | booktitle={International Conference on Learning Representations}, 148 | year={2023}, 149 | url={https://openreview.net/forum?id=cDYRS5iZ16f}, 150 | } 151 | ``` 152 | 153 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | 5 | import h5py 6 | import torch 7 | from torch.utils.data import DataLoader, Dataset 8 | import tqdm 9 | 10 | 11 | class CoLDataset(Dataset): 12 | IGNORE_ID = -100 13 | sent_strategy = 'first' 14 | 15 | def __init__(self, file_path, tokenizer_name, tokenizer, block_size=512, 16 | split_sent=False, voken_dir=None, suffix=None, verbose=False, 17 | voken_ablation=None): 18 | 19 | # Open token's hdf5 20 | token_path = file_path + '.' + tokenizer_name + '.hdf5' 21 | assert os.path.isfile(token_path) 22 | if verbose: 23 | print("-------- Load Data -------") 24 | print("Load tokens from", token_path) 25 | self.token_hdf5 = h5py.File(token_path, 'r') 26 | self.tokenizer = tokenizer 27 | self.tokens = self.token_hdf5['tokens'] 28 | self.verbose = verbose 29 | self.voken_ablation = voken_ablation 30 | self._iter_cnt = 0 31 | 32 | # Open voken's hdf5 and load voken ids 33 | if voken_dir is not None: 34 | assert suffix is not None, 'Please provide suffix of the voken, e.g., vg_nococo.5000.' 35 | self.sent_level = 'sent' in voken_dir 36 | dset_fname = os.path.split(file_path)[-1] 37 | voken_path = os.path.join(voken_dir, f"{dset_fname}.{suffix}.hdf5") 38 | voken_ids_path = os.path.join(voken_dir, f"{dset_fname}.{suffix}.ids") 39 | if verbose: 40 | print("Load vokens from", voken_path) 41 | self.voken_hdf5 = h5py.File(voken_path, 'r') 42 | self.vokens = self.voken_hdf5['vokens'] 43 | assert len(self.vokens) == len(self.tokens) 44 | self._voken_ids = list( 45 | map(lambda x: x.strip(), 46 | open(voken_ids_path).readlines()) 47 | ) 48 | if verbose: 49 | print("\t with voken size", self.voken_size) 50 | print("\t top 5 voken ids are:", self._voken_ids[:5]) 51 | else: 52 | self.vokens = None 53 | 54 | # Split for every block_size tokens 55 | # The last block without full length will be dropped. 
56 | num_tokens = len(self.tokens) 57 | self.starts = list(range(0, num_tokens, block_size)) 58 | self.batches = list(zip(self.starts[:-1], self.starts[1:])) 59 | 60 | manual_filtered =False 61 | if "en.train.raw" in file_path and tokenizer_name == "bert-base-uncased": 62 | self.batches = manual_filter(self.batches) 63 | if verbose: 64 | print("Data: Mannually filter the range for counties.") 65 | manual_filtered = True 66 | 67 | # batch_info 68 | if verbose: 69 | print("Split sent with block size", block_size) 70 | print(f"Total batches: {len(self.batches)}") 71 | print(f"Total tokens: {len(self.tokens)}") 72 | if voken_dir is not None: 73 | print(f"Total vokens: {len(self.vokens)}") 74 | if voken_ablation is not None: 75 | print("The model will process voken ablation strategy:", voken_ablation) 76 | print() 77 | 78 | block_check(self.batches, block_size, fixed_size=True, manual_filtered=manual_filtered) 79 | if self.voken_ablation == 'token': 80 | self._voken_ids = list(range(30522)) 81 | 82 | @property 83 | def voken_size(self): 84 | return len(self._voken_ids) 85 | 86 | @property 87 | def voken_ids(self): 88 | return copy.copy(self._voken_ids) 89 | 90 | def assert_equal_vokens(self, dataset): 91 | assert self.voken_size == dataset.voken_size 92 | for vid, vid1 in zip(self.voken_ids, dataset.voken_ids): 93 | assert vid == vid1 94 | 95 | def __len__(self): 96 | return len(self.batches) - 1 97 | 98 | def __getitem__(self, item): 99 | token_start, token_end = self.batches[item] 100 | if self._iter_cnt < 5 and self.verbose: 101 | print(f"Data Loader: data iteration {self._iter_cnt}, with range {token_start} to {token_end}.") 102 | self._iter_cnt += 1 103 | tokens = list(self.tokens[token_start: token_end]) 104 | token_tensor = torch.tensor( 105 | self.tokenizer.build_inputs_with_special_tokens(tokens), 106 | dtype=torch.long) 107 | if self.vokens is not None: 108 | vokens = list(self.vokens[token_start: token_end]) 109 | 110 | vokens = self.maybe_do_sent_level(vokens) 111 | vokens = self.maybe_do_ablation_study(vokens, tokens) 112 | 113 | voken_tensor = torch.tensor( 114 | [self.IGNORE_ID] + vokens + [self.IGNORE_ID], 115 | dtype=torch.long 116 | ) 117 | 118 | return token_tensor, voken_tensor 119 | else: 120 | return token_tensor 121 | 122 | def maybe_do_sent_level(self, vokens): 123 | if not self.sent_level: 124 | return vokens 125 | else: 126 | if self.sent_strategy == 'all': 127 | vokens = [ 128 | (-voken-1 if voken < 0 else voken) 129 | for voken in vokens 130 | ] 131 | elif self.sent_strategy == 'first': 132 | vokens = [ 133 | (self.IGNORE_ID if voken < 0 else voken) 134 | for voken in vokens 135 | ] 136 | return vokens 137 | 138 | def maybe_do_ablation_study(self, vokens, tokens): 139 | if self.voken_ablation is None: 140 | return vokens 141 | else: 142 | if self._iter_cnt < 5 and self.verbose: 143 | print("Before voken ablation: ", vokens) 144 | if self.voken_ablation == 'random': 145 | vokens = [random.randint(0, self.voken_size - 1) 146 | for _ in range(len(vokens))] 147 | elif self.voken_ablation == 'shuffle': 148 | random.shuffle(vokens) 149 | elif self.voken_ablation == 'reverse': 150 | vokens = vokens[::-1] 151 | elif self.voken_ablation == 'token': 152 | vokens = tokens 153 | if self._iter_cnt < 5 and self.verbose: 154 | print("After voken ablation: ", vokens) 155 | return vokens 156 | 157 | def get_item_info(self, item): 158 | token_start = self.batches[item] 159 | token_end = self.batches[item + 1] 160 | return token_start, token_end 161 | 162 | def __del__(self): 163 | 
self.token_hdf5.close() 164 | if self.vokens is not None: 165 | self.voken_hdf5.close() 166 | 167 | 168 | FORBIDDEN_RANGE = ( 169 | 119314944, # Start of iter 3700 170 | 187053048 # End of iter 5800 171 | ) 172 | 173 | 174 | def intersect(x, y): 175 | x1, x2 = x 176 | y1, y2 = y 177 | if x2 <= y1 or x2 >= y2: 178 | # Case 1: [ x )[ y ) 179 | # Case 2: [ y )[ x ) 180 | return False 181 | return True 182 | 183 | 184 | def manual_filter(batches): 185 | batches = list(filter( 186 | lambda x: not intersect(x, FORBIDDEN_RANGE), 187 | batches 188 | )) 189 | return batches 190 | 191 | 192 | def block_check(batches, block_size, fixed_size=False, manual_filtered=False): 193 | """ 194 | Check whether the batches satisfy following requirements. 195 | 1. Monotonic 196 | 2. Mutually exclusive 197 | 3. Range < block_size 198 | """ 199 | last_end = 0 200 | for start_token, end_token in batches: 201 | assert last_end <= start_token 202 | if fixed_size: 203 | assert (end_token - start_token) == block_size, 'len([%d, %d)) != %d' % (start_token, end_token, block_size) 204 | else: 205 | assert (end_token - start_token) <= block_size, 'len([%d, %d)) > %d' % (start_token, end_token, block_size) 206 | if manual_filtered: 207 | assert not intersect((start_token, end_token), FORBIDDEN_RANGE) 208 | last_end = end_token 209 | 210 | 211 | def get_voken_feats(dataset: CoLDataset, feat_dir: str): 212 | """ 213 | Load pre-extracted visual features regarding img_ids of vokens. 214 | """ 215 | set2id2feat = {} 216 | voken_feats = [] 217 | for voken_id in dataset.voken_ids: 218 | voken_img_set, voken_img_id = voken_id.split('/') 219 | if voken_img_set not in set2id2feat: 220 | img_ids = list(map( 221 | lambda x: x.rstrip(), 222 | open(os.path.join(feat_dir, f"{voken_img_set}.ids")) 223 | )) 224 | img_feats = h5py.File( 225 | os.path.join(feat_dir, f"{voken_img_set}.hdf5"), 'r' 226 | )['keys'][:] 227 | id2feat = {} 228 | assert len(img_ids) == len(img_feats) 229 | for img_id, img_feat in zip(img_ids, img_feats): 230 | id2feat[img_id] = img_feat 231 | set2id2feat[voken_img_set] = id2feat 232 | voken_feats.append(set2id2feat[voken_img_set][voken_img_id]) 233 | return voken_feats 234 | 235 | 236 | 237 | -------------------------------------------------------------------------------- /param.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import configargparse 3 | 4 | 5 | def process_args(): 6 | # parser = argparse.ArgumentParser() 7 | parser = configargparse.ArgumentParser() 8 | 9 | # Path to load default configs 10 | parser.add_argument('--config', is_config_file=True, help='Config file path') 11 | 12 | # Datasets 13 | parser.add_argument( 14 | "--train_data_file", default=None, type=str, 15 | help="The input training data file (a text file).") 16 | parser.add_argument( 17 | "--eval_data_file", default=None, type=str, 18 | help="An optional input evaluation data file to evaluate the perplexity on (a text file).") 19 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 20 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") 21 | 22 | # Data loader 23 | parser.add_argument("--col_data", action="store_true", help="Using the specific dataset object in data.py") 24 | parser.add_argument("--split_sent", action="store_true", help="Overwrite the cached training and evaluation sets") 25 | parser.add_argument("--shuffle", action="store_true", help="Shuffle the training dataset") 26 | 
parser.add_argument( 27 | "--block_size", default=-1, type=int, 28 | help="Optional input sequence length after tokenization." 29 | "The training dataset will be truncated in block of this size for training." 30 | "Default to the model max input length for single sentence inputs (take into account special tokens).", 31 | ) 32 | 33 | # Logging and Saving 34 | parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") 35 | parser.add_argument("--ckpt_steps", type=int, default=1000, help="Checkpoint every X updates steps.") 36 | parser.add_argument( 37 | "--output_dir", type=str, required=True, 38 | help="The output directory where the model predictions and checkpoints will be written.",) 39 | parser.add_argument( 40 | "--overwrite_output_dir", action="store_true", 41 | help="Overwrite the content of the output directory") 42 | 43 | # Model types 44 | parser.add_argument( 45 | "--model_type", type=str, help="The model architecture to be trained or fine-tuned.",) 46 | parser.add_argument( 47 | "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir") 48 | parser.add_argument( 49 | "--model_name_or_path", default=None, type=str, 50 | help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",) 51 | parser.add_argument( 52 | "--config_name", default=None, type=str, required=True, 53 | help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",) 54 | parser.add_argument( 55 | "--tokenizer_name", default=None, type=str, 56 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",) 57 | parser.add_argument( 58 | "--cache_dir", default=None, type=str, 59 | help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",) 60 | parser.add_argument( 61 | "--overwrite_cache", action="store_true", 62 | help="Overwrite the cached training and evaluation sets") 63 | 64 | # MLM tasks 65 | parser.add_argument( 66 | "--mlm", action="store_true", help="Train with masked-language modeling loss instead of language modeling.") 67 | parser.add_argument( 68 | "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss") 69 | parser.add_argument( 70 | "--mlm_ratio", type=float, default=1., help="The ratio of mlm loss in the total loss.") 71 | 72 | # Model growth params 73 | parser.add_argument( 74 | "--source_model_path", nargs='*', default=None, type=str, 75 | help="Path to load the source ckpt.") 76 | parser.add_argument( 77 | "--grow_scheme", default='none', type=str, 78 | choices=['none', 'ligo'], 79 | help="Method to grow the model: [none, ligo]") 80 | parser.add_argument("--tune_depth", action='store_true', default=False) 81 | parser.add_argument("--tune_width", action='store_true', default=False) 82 | parser.add_argument( 83 | "--fuse_init_scheme", nargs='*', default=['rand'], type=str, 84 | choices=['rand', 'rand_softmax', 'stackbert', 'stackbert_noisy', 'sel', 'sel_noisy'], 85 | help="Initialization of LiGO operator." 86 | ) 87 | parser.add_argument( 88 | "--fuse_init_noise", nargs='*', type=float, default=[0.03], 89 | help="Noise scale to randomly initialize LiGO operator."
90 | ) 91 | parser.add_argument("--fuse_tie_param", action='store_true', default=True, help="Turn on parameter tying for LiGO.") 92 | parser.add_argument("--no_fuse_tie_param", action='store_false', dest='fuse_tie_param', default=False, 93 | help="Turn off parameter tying for LiGO.") 94 | parser.set_defaults(fuse_tie_param=True) 95 | parser.add_argument("--tune_small_model", action='store_true', default=False, 96 | help="Extra feature: Enabling tuning of small model parameters when training LiGO.") 97 | parser.add_argument("--tune_residual", action='store_true', default=False, 98 | help="Extra feature: Adding bias terms in LiGO operator.") 99 | parser.add_argument("--tune_residual_noise", type=float, default=0.01, 100 | help="Extra feature: Noise scale to initialize bias terms in LiGO operator.") 101 | parser.add_argument("--learning_rate_res", default=None, type=float, help="Extra feature: The initial learning rate for learning bias in LiGO.") 102 | parser.add_argument("--weight_decay_res", default=None, type=float, help="Extra feature: Weight decay for learning bias in LiGO.") 103 | 104 | parser.add_argument( 105 | "--pretrained_ligo_path", default=None, type=str, 106 | help="Path to load the checkpoint of LiGO parameter.") 107 | 108 | # Batch Size and Training Steps 109 | parser.add_argument("--seed", type=int, default=95, help="random seed for initialization") 110 | parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.") 111 | parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation.") 112 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, 113 | help="Number of updates steps to accumulate before performing a backward/update pass.",) 114 | parser.add_argument("--max_steps", default=100000, type=int, 115 | help="Total number of training steps to perform. Override num_train_epochs.",) 116 | 117 | # Layer drop 118 | parser.add_argument("--layer_drop", action='store_true', default=False, help="Turn on layer drop.") 119 | parser.add_argument("--layer_drop_rate", type=float, default=False, help="The drop propability to drop a layer.") 120 | parser.add_argument("--layer_drop_lin_decay", action='store_true', default=False, 121 | help="Turn on linear decay of survival prob. 
The --layer_drop_rate will specify the rate for the last layer") 122 | 123 | # Token drop 124 | parser.add_argument("--token_drop", action='store_true', default=False, help="Turn on token drop.") 125 | parser.add_argument("--token_drop_rate", type=float, default=0.0, help="The probability of dropping a token.") 126 | parser.add_argument("--token_drop_start", type=int, default=2, help="The layer index to separate tokens") 127 | parser.add_argument("--token_drop_end", type=int, default=-1, help="The layer index to merge tokens") 128 | 129 | # Optimizer 130 | parser.add_argument("--lamb", action="store_true", help='Use the LAMB optimizer in apex') 131 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 132 | parser.add_argument("--warmup_ratio", default=0., type=float, help="Fraction of total steps for linear warmup (used instead of --warmup_steps).") 133 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 134 | parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.") 135 | parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.") 136 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 137 | 138 | # Scheduler 139 | parser.add_argument("--scheduler_type", default='linear', type=str, help="Type of lr scheduler.", choices=['linear', 'cosine', 'poly']) 140 | parser.add_argument("--scheduler_cosine_cycles", default=0.5, type=float, help="Number of cycles for cosine lr scheduler.") 141 | parser.add_argument("--scheduler_poly_power", default=1.0, type=float, help="Power of polynomial lr scheduler.") 142 | 143 | # Distributed Training 144 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 145 | parser.add_argument("--nodes", type=int, default=1, help="Number of nodes for distributed training.") 146 | parser.add_argument("--nr", type=int, default=0, help="Rank of this node (0-indexed).") 147 | 148 | # Half Precision 149 | parser.add_argument( 150 | "--fp16", action="store_true", 151 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",) 152 | parser.add_argument( 153 | "--fp16_opt_level", type=str, default="O1", 154 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 155 | "See details at https://nvidia.github.io/apex/amp.html",) 156 | 157 | return parser 158 | -------------------------------------------------------------------------------- /flop_computation.py: -------------------------------------------------------------------------------- 1 | """Computes the flops needed for training/running transformer networks.""" 2 | 3 | import collections 4 | 5 | # We checked this code with TensorFlow's FLOPs counting, although we had to 6 | # correct for this issue: https://github.com/tensorflow/tensorflow/issues/22071 7 | # Assumptions going into the FLOPs counting 8 | # - An "operation" is a mathematical operation, not a machine instruction. So 9 | # an "exp" takes one op, like an add, even though in practice an exp 10 | # might be slower. This is not too bad an assumption because 11 | # matrix-multiplies dominate the compute for most models, so minor details 12 | # about activation functions don't matter too much. Similarly, we count 13 | # matrix-multiplies as 2*m*n flops instead of m*n, as one might 14 | # if considering fused multiply-add ops. 15 | # - Backward pass takes the same number of FLOPs as forward pass.
No exactly 16 | # right (e.g., for softmax cross entropy loss the backward pass is faster). 17 | # Importantly, it really is the same for matrix-multiplies, which is most of 18 | # the compute anyway. 19 | # - We assume "dense" embedding lookups (i.e., multiplication by a one-hot 20 | # vector). On some hardware accelerators, these dense operations are 21 | # actually faster than sparse lookups. 22 | # Please open a github issue if you spot a problem with this code! 23 | 24 | # I am not sure if the below constants are 100% right, but they are only applied 25 | # to O(hidden_size) activations, which is generally a lot less compute than the 26 | # matrix-multiplies, which are O(hidden_size^2), so they don't affect the total 27 | # number of FLOPs much. 28 | 29 | # random number, >=, multiply activations by dropout mask, multiply activations 30 | # by correction (1 / (1 - dropout_rate)) 31 | DROPOUT_FLOPS = 4 32 | 33 | # compute mean activation (sum), computate variance of activation 34 | # (square and sum), bias (add), scale (multiply) 35 | LAYER_NORM_FLOPS = 5 36 | 37 | # GELU: 0.5 * x * (1 + tanh(sqrt(2 / np.pi) * (x + 0.044715 * pow(x, 3)))) 38 | ACTIVATION_FLOPS = 8 39 | 40 | # max/substract (for stability), exp, sum, divide 41 | SOFTMAX_FLOPS = 5 42 | 43 | 44 | class TransformerHparams(object): 45 | """Computes the train/inference FLOPs for transformers.""" 46 | 47 | def __init__(self, h, l, s=512, v=30522, e=None, i=None, heads=None, 48 | head_size=None, output_frac=0.15625, sparse_embed_lookup=False, 49 | decoder=False): 50 | self.h = h # hidden size 51 | self.l = l # number of layers 52 | self.s = s # sequence length 53 | self.v = v # vocab size 54 | self.e = h if e is None else e # embedding size 55 | self.i = h * 4 if i is None else i # intermediate size 56 | self.kqv = h if head_size is None else head_size * heads # attn proj sizes 57 | self.heads = max(h // 64, 1) if heads is None else heads # attention heads 58 | self.output_frac = output_frac # percent of tokens using an output softmax 59 | self.sparse_embed_lookup = sparse_embed_lookup # sparse embedding lookups 60 | self.decoder = decoder # decoder has extra attn to encoder states 61 | 62 | def get_block_flops(self, seq_len=None): 63 | seq_len = self.s if seq_len is None else seq_len 64 | """Get the forward-pass FLOPs for a single transformer block.""" 65 | attn_mul = 2 if self.decoder else 1 66 | block_flops = dict( 67 | kqv=3 * 2 * self.h * self.kqv * attn_mul, 68 | kqv_bias=3 * self.kqv * attn_mul, 69 | attention_scores=2 * self.kqv * seq_len * attn_mul, 70 | attn_softmax=SOFTMAX_FLOPS * seq_len * self.heads * attn_mul, 71 | attention_dropout=DROPOUT_FLOPS * seq_len * self.heads * attn_mul, 72 | attention_scale=seq_len * self.heads * attn_mul, 73 | attention_weighted_avg_values=2 * self.h * seq_len * attn_mul, 74 | attn_output=2 * self.h * self.h * attn_mul, 75 | attn_output_bias=self.h * attn_mul, 76 | attn_output_dropout=DROPOUT_FLOPS * self.h * attn_mul, 77 | attn_output_residual=self.h * attn_mul, 78 | attn_output_layer_norm=LAYER_NORM_FLOPS * attn_mul, 79 | intermediate=2 * self.h * self.i, 80 | intermediate_act=ACTIVATION_FLOPS * self.i, 81 | intermediate_bias=self.i, 82 | output=2 * self.h * self.i, 83 | output_bias=self.h, 84 | output_dropout=DROPOUT_FLOPS * self.h, 85 | output_residual=self.h, 86 | output_layer_norm=LAYER_NORM_FLOPS * self.h, 87 | ) 88 | return sum(block_flops.values()) * seq_len 89 | 90 | def get_embedding_flops(self, output=False, seq_len=None): 91 | seq_len = self.s if seq_len is None else 
seq_len 92 | """Get the forward-pass FLOPs the transformer inputs or output softmax.""" 93 | embedding_flops = {} 94 | if output or (not self.sparse_embed_lookup): 95 | embedding_flops["main_multiply"] = 2 * self.e * self.v 96 | # input embedding post-processing 97 | if not output: 98 | embedding_flops.update(dict( 99 | tok_type_and_position=2 * self.e * (seq_len + 2), 100 | add_tok_type_and_position=2 * self.e, 101 | emb_layer_norm=LAYER_NORM_FLOPS * self.e, 102 | emb_dropout=DROPOUT_FLOPS * self.e 103 | )) 104 | # projection layer if e != h 105 | if self.e != self.h or output: 106 | embedding_flops.update(dict( 107 | hidden_kernel=2 * self.h * self.e, 108 | hidden_bias=self.e if output else self.h 109 | )) 110 | # extra hidden layer and output softmax 111 | if output: 112 | embedding_flops.update(dict( 113 | hidden_activation=ACTIVATION_FLOPS * self.e, 114 | hidden_layernorm=LAYER_NORM_FLOPS * self.e, 115 | output_softmax=SOFTMAX_FLOPS * self.v, 116 | output_target_word=2 * self.v 117 | )) 118 | return self.output_frac * sum(embedding_flops.values()) * seq_len 119 | return sum(embedding_flops.values()) * seq_len 120 | 121 | def get_binary_classification_flops(self, seq_len=None): 122 | seq_len = self.s if seq_len is None else seq_len 123 | classification_flops = dict( 124 | hidden=2 * self.h * self.h, 125 | hidden_bias=self.h, 126 | hidden_act=ACTIVATION_FLOPS * self.h, 127 | logits=2 * self.h 128 | ) 129 | return sum(classification_flops.values()) * seq_len 130 | 131 | def get_train_flops(self, batch_size, train_steps, seq_len=None, discriminator=False): 132 | """Get the FLOPs for pre-training the transformer.""" 133 | # 2* for forward/backward pass 134 | return 2 * batch_size * train_steps * ( 135 | (self.l * self.get_block_flops(seq_len=seq_len)) + 136 | self.get_embedding_flops(output=False, seq_len=seq_len) + 137 | (self.get_binary_classification_flops(seq_len=seq_len) if discriminator else 138 | self.get_embedding_flops(output=True, seq_len=seq_len)) 139 | ) 140 | 141 | def get_infer_flops(self): 142 | """Get the FLOPs for running inference with the transformer on a 143 | classification task.""" 144 | return ((self.l * self.get_block_flops()) + 145 | self.get_embedding_flops(output=False) + 146 | self.get_binary_classification_flops()) 147 | 148 | class FusedTransformerHparams(TransformerHparams): 149 | 150 | def __init__(self, layer_small, layer_hidden, h, l, s=512, v=30522, e=None, i=None, heads=None, 151 | head_size=None, output_frac=0.15625, sparse_embed_lookup=False, 152 | decoder=False): 153 | 154 | super().__init__(h=h, l=l, s=s, v=v, e=e, i=i, heads=heads, 155 | head_size=head_size, output_frac=output_frac, sparse_embed_lookup=sparse_embed_lookup, 156 | decoder=decoder) 157 | self.ls = layer_small 158 | self.hs = layer_hidden 159 | 160 | def get_block_flops(self, seq_len=None): 161 | seq_len = self.s if seq_len is None else seq_len 162 | """Get the forward-pass FLOPs for a single transformer block.""" 163 | attn_mul = 2 if self.decoder else 1 164 | block_flops = dict( 165 | kqv=3 * 2 * self.h * self.kqv * attn_mul, 166 | kqv_bias=3 * self.kqv * attn_mul, 167 | attention_scores=2 * self.kqv * seq_len * attn_mul, 168 | attn_softmax=SOFTMAX_FLOPS * seq_len * self.heads * attn_mul, 169 | attention_dropout=DROPOUT_FLOPS * seq_len * self.heads * attn_mul, 170 | attention_scale=seq_len * self.heads * attn_mul, 171 | attention_weighted_avg_values=2 * self.h * seq_len * attn_mul, 172 | attn_output=2 * self.h * self.h * attn_mul, 173 | attn_output_bias=self.h * attn_mul, 174 | 
attn_output_dropout=DROPOUT_FLOPS * self.h * attn_mul, 175 | attn_output_residual=self.h * attn_mul, 176 | attn_output_layer_norm=LAYER_NORM_FLOPS * attn_mul, 177 | intermediate=2 * self.h * self.i, 178 | intermediate_act=ACTIVATION_FLOPS * self.i, 179 | intermediate_bias=self.i, 180 | output=2 * self.h * self.i, 181 | output_bias=self.h, 182 | output_dropout=DROPOUT_FLOPS * self.h, 183 | output_residual=self.h, 184 | output_layer_norm=LAYER_NORM_FLOPS * self.h 185 | ) 186 | fuse_flops = dict( 187 | kqv=self.h * self.kqv * self.ls * 2 + (self.h-self.hs) * self.hs * self.hs * 2 * 2, 188 | kqv_bias=self.kqv * self.ls + (self.h-self.hs) * self.hs, 189 | attn_output=self.h * self.h * self.ls * 2 + (self.h-self.hs) * self.hs * self.hs * 2 * 2, 190 | attn_output_bias=self.h * self.ls + (self.h-self.hs) * self.hs, 191 | attn_output_layer_norm=self.h * self.ls * 2 + (self.h-self.hs) * self.hs * 2, 192 | intermediate=self.h * self.i * self.ls * 2 + (self.h-self.hs)* self.hs * self.i * 2 * 2, 193 | intermediate_bias=self.i * self.ls + (self.h-self.hs) * self.i, 194 | output=self.h * self.i * self.ls * 2 + (self.h-self.hs) * self.i * self.h * 2, 195 | output_bias=self.h * self.ls * 2 + (self.h-self.hs) * self.h, 196 | output_layer_norm=self.h * self.ls * 2 + (self.h-self.hs) * self.h 197 | ) 198 | return sum(block_flops.values()) * seq_len + sum(fuse_flops.values()) 199 | 200 | def get_flops_computer(config): 201 | return TransformerHparams( 202 | config.hidden_size, # hidden size 203 | config.num_hidden_layers, # layers 204 | s=config.max_position_embeddings, # len of sequence 205 | v=config.vocab_size, # vocab size 206 | i=config.intermediate_size, # ff intermediate hidden size 207 | heads=config.num_attention_heads, # heads/head size 208 | ) 209 | 210 | def get_train_flops_by_config(config, batch_size, num_steps, seq_len=128): 211 | return TransformerHparams( 212 | config.hidden_size, # hidden size 213 | config.num_hidden_layers, # layers 214 | s=seq_len, # len of sequence 215 | v=config.vocab_size, # vocab size 216 | i=config.intermediate_size, # ff intermediate hidden size 217 | heads=config.num_attention_heads, # heads/head size 218 | ).get_train_flops(batch_size, num_steps) # 1M steps with batch size 2048 219 | 220 | def get_infer_flops_by_config(config, seq_len=128): 221 | return TransformerHparams( 222 | config.hidden_size, # hidden size 223 | config.num_hidden_layers, # layers 224 | s=seq_len, # len of sequence 225 | v=config.vocab_size, # vocab size 226 | i=config.intermediate_size, # ff intermediate hidden size 227 | heads=config.num_attention_heads, # heads/head size 228 | ).get_infer_flops() # 1M steps with batch size 2048 229 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 140 | 141 | 142 | 143 | 144 | Learning to Grow Pretrained Models for Efficient Transformer Training 145 | 146 | 147 | 148 | 149 | 150 |
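flop_computation.py above estimates transformer training and inference FLOPs from the model shape alone: TransformerHparams does the counting, and get_train_flops_by_config / get_infer_flops_by_config wrap it for a HuggingFace config object. A minimal usage sketch follows; the model shape matches configs/bert-12L-768H.json, but the batch size, step count, and sequence length are illustrative values, not the paper's training recipe.

from flop_computation import TransformerHparams  # assumes the repository root is on sys.path

# BERT-base-like target model: 12 layers, 768 hidden, 3072 FFN, 12 heads, 30522 vocab.
bert_base = TransformerHparams(768, 12, s=128, v=30522, i=3072, heads=12)

# Pre-training cost for 100k steps at batch size 256; the 2x factor for the backward
# pass is applied inside get_train_flops, and matrix multiplies count as 2*m*n FLOPs.
train_flops = bert_base.get_train_flops(batch_size=256, train_steps=100_000, seq_len=128)
print(f"estimated pre-training FLOPs: {train_flops:.3e}")

# Per-sequence inference cost with a classification head, at the default sequence length.
print(f"estimated inference FLOPs: {bert_base.get_infer_flops():.3e}")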
Learning to Grow Pretrained Models for Efficient Transformer Training

Peihao Wang (1), Rameswar Panda (2), Lucas Torroba Hennigen (4), Philip Greengard (3), Leonid Karlinsky (2), Rogerio Feris (2), David Cox (2), Atlas Wang (1), Yoon Kim (4)

1 University of Texas at Austin   2 MIT-IBM Watson AI Lab   3 Columbia University   4 MIT

ICLR 2023

Abstract

Scaling transformers has led to significant breakthroughs in many domains, leading to a paradigm in which larger versions of existing models are trained and released on a periodic basis. New instances of such models are typically trained completely from scratch, despite the fact that they are often just scaled-up versions of their smaller counterparts. How can we use the implicit knowledge in the parameters of smaller, extant models to enable faster training of newer, larger models? This paper describes an approach for accelerating transformer training by learning to grow pretrained transformers, where we learn to linearly map the parameters of the smaller model to initialize the larger model. For tractable learning, we factorize the linear transformation as a composition of (linear) width- and depth-growth operators, and further employ a Kronecker factorization of these growth operators to encode architectural knowledge. Extensive experiments across both language and vision transformers demonstrate that our learned Linear Growth Operator (LiGO) can save up to 50% computational cost of training from scratch, while also consistently outperforming strong baselines that also reuse smaller pretrained models to initialize larger models.

Qualitative Results

Qualitative examples showing that LiGO can accelerate BERT training time by ~40% with ~45% FLOPs saving.

Paper & Code

Peihao Wang, Rameswar Panda, Lucas Torroba Hennigen, Philip Greengard, Leonid Karlinsky, Rogerio Feris, David Cox, Atlas Wang, Yoon Kim
Learning to Grow Pretrained Models for Efficient Transformer Training
International Conference on Learning Representations (ICLR), 2023
[PDF] [Code]
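The abstract above describes LiGO as a learned linear map from the small model's parameters to an initialization of the large model, factorized into width- and depth-growth operators. The repository implements this in ligo.py (used through create_ligo_from_model in run_grow_distributed.py and initialize_model_with_ligo in run_lm_distributed.py); the snippet below is only a schematic of the width-growth idea for a single weight matrix, with hypothetical names and shapes, not the actual ligo.py code.

import torch

d_small, d_large = 512, 768   # e.g. growing a 512-hidden source model toward a 768-hidden target

# Hypothetical learnable width-growth factors; the 0.03 scale mirrors the --fuse_init_noise default.
A = torch.nn.Parameter(0.03 * torch.randn(d_large, d_small))   # expands the output dimension
B = torch.nn.Parameter(0.03 * torch.randn(d_large, d_small))   # expands the input dimension

W_small = torch.randn(d_small, d_small)   # a pretrained weight matrix from the source model
W_large_init = A @ W_small @ B.t()        # shape (768, 768): an initialization that is linear in W_small

Depth growth is handled analogously, with each new layer's weights expressed as a learned combination of the source layers' weights; run_grow_distributed.py trains these operators, and run_lm_distributed.py can then consume the result through --pretrained_ligo_path.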
246 | 249 | 250 | 251 | -------------------------------------------------------------------------------- /run_grow_distributed.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). 18 | GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned 19 | using a masked language modeling (MLM) loss. 20 | """ 21 | 22 | 23 | import argparse 24 | import glob 25 | import json 26 | import logging 27 | import os 28 | import pickle 29 | import random 30 | import re 31 | import shutil 32 | import sys 33 | from typing import Dict, List, Tuple 34 | from datetime import datetime 35 | import time 36 | 37 | import numpy as np 38 | import torch 39 | from torch.nn.utils.rnn import pad_sequence 40 | import torch.multiprocessing as mp 41 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 42 | from torch.utils.data.distributed import DistributedSampler 43 | from tqdm import tqdm, trange 44 | from transformers import ( 45 | WEIGHTS_NAME, 46 | AdamW, 47 | PreTrainedModel, 48 | PreTrainedTokenizer, 49 | BertConfig, 50 | BertForMaskedLM, 51 | BertTokenizer, 52 | RobertaConfig, 53 | RobertaForMaskedLM, 54 | RobertaTokenizer, 55 | get_linear_schedule_with_warmup, 56 | ) 57 | 58 | sys.path.append( 59 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 60 | ) 61 | from param import process_args 62 | from model import SimpleBertForMaskedLM, SimpleRobertaForMaskedLM 63 | 64 | try: 65 | from torch.utils.tensorboard import SummaryWriter 66 | except ImportError: 67 | from tensorboardX import SummaryWriter 68 | 69 | from ligo import create_ligo_from_model 70 | 71 | from run_lm_distributed import TextDataset, LineByLineTextDataset, load_and_cache_examples, \ 72 | set_seed, mask_tokens, is_port_in_use 73 | 74 | logger = logging.getLogger(__name__) 75 | 76 | 77 | MODEL_CLASSES = { 78 | "bert": (BertConfig, SimpleBertForMaskedLM, BertTokenizer), 79 | "roberta": (RobertaConfig, SimpleRobertaForMaskedLM, RobertaTokenizer) 80 | } 81 | 82 | def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]: 83 | set_seed(args) # Added here for reproducibility 84 | 85 | """ Train the model """ 86 | if args.gpu == 0: 87 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 88 | tb_writer = SummaryWriter(args.output_dir + '/runs/' + current_time) 89 | 90 | args.train_batch_size = args.per_gpu_train_batch_size 91 | 92 | def collate(examples: List[torch.Tensor]): 93 | if tokenizer._pad_token is None: 94 | return pad_sequence(examples, batch_first=True) 95 | return pad_sequence(examples, batch_first=True, 
padding_value=tokenizer.pad_token_id) 96 | 97 | if args.shuffle: 98 | logger.info(f"Shuffle the dataset in training," 99 | f"GPU: {args.gpu}," 100 | f"Rank: {args.rank}," 101 | f"Total: {args.world_size}") 102 | train_sampler = DistributedSampler( 103 | train_dataset, 104 | num_replicas=args.world_size, 105 | rank=args.rank, 106 | shuffle=args.shuffle, 107 | ) 108 | train_dataloader = DataLoader( 109 | train_dataset, sampler=train_sampler, shuffle=False, num_workers=0, 110 | batch_size=args.train_batch_size, collate_fn=collate, pin_memory=True 111 | ) 112 | 113 | 114 | t_total = args.max_steps 115 | 116 | # Prepare optimizer and schedule (linear warmup and decay) 117 | no_decay = [".bias", "LayerNorm.weight"] 118 | residual_weights = [".residual_weight", ".residual_bias"] 119 | optimizer_grouped_parameters = [ 120 | { 121 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay + residual_weights)], 122 | "weight_decay": args.weight_decay, 123 | }, 124 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 125 | { 126 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in residual_weights)], 127 | "lr": args.learning_rate_res if args.learning_rate_res is not None else args.learning_rate, 128 | "weight_decay": args.weight_decay_res if args.weight_decay_res is not None else args.weight_decay, 129 | }, 130 | ] 131 | optimizer = AdamW(optimizer_grouped_parameters, 132 | # betas=(0.9, 0.98), 133 | lr=args.learning_rate, 134 | eps=args.adam_epsilon) 135 | if args.warmup_ratio > 0.: 136 | assert args.warmup_steps == 0 137 | args.warmup_steps = int(t_total * args.warmup_ratio) 138 | if args.gpu == 0: 139 | print("Optimized with lr %f, steps %d, warmup steps %d, and use beta, epsilon %0.8f." % ( 140 | args.learning_rate, t_total, args.warmup_steps, optimizer.defaults['eps'] 141 | ), optimizer.defaults['betas']) 142 | scheduler = get_linear_schedule_with_warmup( 143 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 144 | ) 145 | 146 | # Check if saved optimizer or scheduler states exist 147 | if ( 148 | args.model_name_or_path 149 | and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) 150 | and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt")) 151 | ): 152 | # Load in optimizer and scheduler states 153 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 154 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 155 | 156 | if args.fp16: 157 | try: 158 | from apex import amp 159 | except ImportError: 160 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 161 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level, 162 | verbosity=0) 163 | from apex.parallel import DistributedDataParallel as DDP 164 | model = DDP(model) 165 | else: 166 | model = torch.nn.parallel.DistributedDataParallel( 167 | model, device_ids=[args.gpu], find_unused_parameters=True 168 | ) 169 | 170 | # Train! 171 | logger.info("***** Running training *****") 172 | logger.info(" Num examples = %d", len(train_dataset)) 173 | # logger.info(" Num Epochs = %d", args.num_train_epochs) 174 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 175 | logger.info( 176 | " Total train batch size (w. 
distributed & accumulation) = %d", 177 | args.train_batch_size 178 | * args.gradient_accumulation_steps 179 | * args.world_size 180 | ) 181 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 182 | logger.info(" Total optimization steps = %d", t_total) 183 | 184 | global_step = 0 185 | epochs_trained = 0 186 | # Check if continuing training from a checkpoint 187 | if args.model_name_or_path and os.path.exists(args.model_name_or_path): 188 | try: 189 | # set global_step to gobal_step of last saved checkpoint from model path 190 | checkpoint_name = os.path.basename(args.model_name_or_path) 191 | global_step = int(checkpoint_name.split("-")[-1]) 192 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) 193 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) 194 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 195 | logger.info(" Continuing training from iter %d, epoch %d" % (global_step, epochs_trained)) 196 | except ValueError: 197 | logger.info(" Do not load model from %s, restart training" % args.model_name_or_path) 198 | 199 | model.zero_grad() 200 | 201 | # IMPORTANT: save the initialization 202 | if args.gpu == 0 and global_step == 0: 203 | checkpoint_name = f"checkpoint-{global_step:08d}" 204 | ckpt_dir = os.path.join(args.output_dir, 'checkpoints') 205 | os.makedirs(ckpt_dir, exist_ok=True) 206 | save_model(args, ckpt_dir, checkpoint_name, model, tokenizer, optimizer, scheduler) 207 | 208 | while True: 209 | epoch_iterator = tqdm(train_dataloader, desc=f"Epoch: {epochs_trained:03d}", disable=args.gpu != 0) 210 | tr_loss, tr_lm_loss = 0.0, 0.0 211 | t_start = time.time() 212 | model.zero_grad() # Support of accumulating gradients 213 | for step, batch in enumerate(epoch_iterator): 214 | inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) 215 | inputs = inputs.to(args.device) 216 | labels = labels.to(args.device) 217 | # If some of the input is padded, then the attention mask is needed 218 | attention_mask = (inputs != tokenizer.pad_token_id) # word_tokens --> 1, pad_token --> 0 219 | if attention_mask.all(): 220 | attention_mask = None 221 | 222 | model.train() 223 | outputs = model(inputs, 224 | attention_mask=attention_mask, 225 | masked_lm_labels=labels, 226 | current_step=global_step) if args.mlm else model(inputs, labels=labels, current_step=global_step) 227 | loss = outputs['loss'] # model outputs are always tuple in transformers (see doc) 228 | 229 | if args.gradient_accumulation_steps > 1: 230 | loss = loss / args.gradient_accumulation_steps 231 | 232 | if args.fp16: 233 | with amp.scale_loss(loss, optimizer) as scaled_loss: 234 | scaled_loss.backward() 235 | else: 236 | loss.backward() 237 | 238 | tr_loss += loss.item() 239 | tr_lm_loss += outputs['lm_loss'].item() 240 | if (step + 1) % args.gradient_accumulation_steps == 0: 241 | if args.max_grad_norm > 0.: 242 | if args.fp16: 243 | total_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 244 | else: 245 | total_norm =torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 246 | optimizer.step() 247 | scheduler.step() # Update learning rate schedule 248 | model.zero_grad() 249 | global_step += 1 250 | 251 | if args.gpu == 0 and args.logging_steps > 0 and global_step % args.logging_steps == 0: 252 | t_elapse = time.time() - t_start 253 | 254 | # Log metrics 255 | 
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) 256 | if args.fp16: 257 | try: 258 | from apex.amp import _amp_state 259 | tb_writer.add_scalar("loss_scale", _amp_state.loss_scalers[0]._loss_scale, global_step) 260 | tb_writer.add_scalar("scaled_loss", scaled_loss.item(), global_step) 261 | except ImportError: 262 | logger.warning("Cannot import apex.amp._amp_state, " 263 | "would not state the loss_scale in the log") 264 | if args.max_grad_norm > 0.: # Only clip the grad when it is valid 265 | tb_writer.add_scalar("grad_norm", total_norm, global_step) 266 | train_loss = tr_loss / args.logging_steps 267 | train_ppl = torch.exp(torch.tensor(tr_lm_loss / args.logging_steps)).item() 268 | tb_writer.add_scalar("loss", train_loss, global_step) 269 | tb_writer.add_scalar("train_ppl", train_ppl, global_step) 270 | tr_loss = tr_lm_loss = 0. 271 | 272 | # also evaluate on valid set for ppl 273 | logger.info(" Evaluation Results of step %d: " % global_step) 274 | results = evaluate(args, model.module, tokenizer) 275 | for key, value in results.items(): 276 | tb_writer.add_scalar("eval_{}".format(key), value, global_step) 277 | logger.info("\t %s: %0.4f" % (key, value)) 278 | 279 | output_log_file = os.path.join(args.output_dir, "train_log.txt") 280 | with open(output_log_file, 'a') as f: 281 | eval_ppl = results['perplexity'] 282 | print(f"train_step={global_step}, train_time={t_elapse}, lr={scheduler.get_lr()[0]}, train_loss={train_loss}," 283 | f"train_ppl={train_ppl}, eval_ppl={eval_ppl}", file=f) 284 | 285 | t_start = time.time() 286 | 287 | if args.gpu == 0 and args.ckpt_steps > 0 and global_step % args.ckpt_steps == 0: 288 | checkpoint_name = f"checkpoint-{global_step:08d}" 289 | ckpt_dir = os.path.join(args.output_dir, 'checkpoints') 290 | os.makedirs(ckpt_dir, exist_ok=True) 291 | save_model(args, ckpt_dir, checkpoint_name, model, tokenizer, optimizer, scheduler) 292 | 293 | if args.max_steps > 0 and global_step >= args.max_steps: 294 | break 295 | 296 | if args.max_steps > 0 and global_step >= args.max_steps: 297 | epoch_iterator.close() 298 | break 299 | 300 | epochs_trained += 1 301 | 302 | # consider during the last evaluation, the GPU 0 is still working while others have exited. 303 | # when GPU 0 call torch.no_grad, it will wait for the response from other processes 304 | # however, a deadlock will be caused if other processes just exit 305 | torch.distributed.barrier() 306 | 307 | if args.gpu == 0: 308 | tb_writer.close() 309 | 310 | def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict: 311 | # Loop to handle MNLI double evaluation (matched, mis-matched) 312 | eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True) 313 | 314 | args.eval_batch_size = args.per_gpu_eval_batch_size 315 | # Note that DistributedSampler samples randomly 316 | 317 | def collate(examples: List[torch.Tensor]): 318 | if tokenizer._pad_token is None: 319 | return pad_sequence(examples, batch_first=True) 320 | return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) 321 | 322 | eval_sampler = SequentialSampler(eval_dataset) 323 | eval_dataloader = DataLoader( 324 | eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate 325 | ) 326 | 327 | # Eval! 
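# The loop below accumulates the masked-LM loss over the whole validation set; the reported
# metric is perplexity = exp(mean eval loss), e.g. a mean eval loss of 2.0 gives exp(2.0) ~= 7.39.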
328 | logger.info("***** Running evaluation {} *****".format(prefix)) 329 | logger.info(" Num examples = %d", len(eval_dataset)) 330 | logger.info(" Batch size = %d", args.eval_batch_size) 331 | eval_loss = 0.0 332 | nb_eval_steps = 0 333 | model.eval() 334 | 335 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 336 | inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) 337 | inputs = inputs.to(args.device) 338 | labels = labels.to(args.device) 339 | # If some of the input is padded, then the attention mask is needed 340 | attention_mask = (inputs != tokenizer.pad_token_id) # word_tokens --> 1, pad_token --> 0 341 | if attention_mask.all(): 342 | attention_mask = None 343 | 344 | with torch.no_grad(): 345 | outputs = model(inputs, attention_mask=attention_mask, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels) 346 | lm_loss = outputs['lm_loss'] 347 | eval_loss += lm_loss.mean().item() 348 | nb_eval_steps += 1 349 | 350 | eval_loss = eval_loss / nb_eval_steps 351 | perplexity = torch.exp(torch.tensor(eval_loss)).item() 352 | 353 | result = {"perplexity": perplexity} 354 | 355 | return result 356 | 357 | 358 | def save_model(args, ckpt_dir, name, model, tokenizer, optimizer, scheduler): 359 | # Save model checkpoint 360 | output_dir = os.path.join(ckpt_dir, name) 361 | os.makedirs(output_dir, exist_ok=True) 362 | model_to_save = ( 363 | model.module if hasattr(model, "module") else model 364 | ) # Take care of distributed/parallel training 365 | model_to_save.save_pretrained(output_dir) 366 | tokenizer.save_pretrained(output_dir) 367 | 368 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 369 | logger.info("Saving model checkpoint to %s", output_dir) 370 | 371 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 372 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 373 | logger.info("Saving optimizer and scheduler states to %s", output_dir) 374 | 375 | 376 | def main(): 377 | parser = process_args() 378 | args = parser.parse_args() 379 | 380 | os.environ['MASTER_ADDR'] = '127.0.0.1' 381 | port = 9595 382 | while is_port_in_use(port): 383 | port += 1 384 | print("Use port", port) 385 | os.environ['MASTER_PORT'] = str(port) 386 | 387 | # Using all available gpus for multi-processing distributed 388 | args.gpus = torch.cuda.device_count() 389 | print("Use gpus ", list(range(args.gpus))) 390 | args.world_size = args.gpus * args.nodes 391 | mp.spawn(setup, nprocs=args.gpus, args=(args,)) 392 | 393 | 394 | def setup(gpu, args): 395 | if args.should_continue: 396 | ckpt_dir = os.path.join(args.output_dir, 'checkpoints') 397 | checkpoint_names = [] 398 | if os.path.isdir(ckpt_dir): 399 | checkpoint_names = [fn for fn in os.listdir(ckpt_dir) if fn.startswith('checkpoint-')] 400 | if len(checkpoint_names) > 0: 401 | checkpoint_names = sorted(checkpoint_names, key=lambda p: int(p.split('-')[-1])) 402 | args.model_name_or_path = os.path.join(ckpt_dir, checkpoint_names[-1]) 403 | else: 404 | logger.warning('No checkpoint detected: %s', ckpt_dir) 405 | args.model_name_or_path = None 406 | 407 | # Setup CUDA, GPU & distributed training 408 | torch.cuda.set_device(gpu) 409 | device = torch.device("cuda", gpu) 410 | args.gpu = gpu # Local device id. 411 | args.device = device # Local device object. 412 | args.rank = args.nr * args.gpus + gpu # The gpu id in the world. 
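# Example of the rank arithmetic above: with --nodes=2 and 4 visible GPUs per node
# (args.gpus == 4), the process for gpu index 2 on the node launched with --nr=1 gets
# args.rank = 1 * 4 + 2 = 6 out of a world_size of args.gpus * args.nodes = 8.
# Logging and checkpointing in train() are gated on args.gpu == 0, i.e. the first GPU
# of each node rather than global rank 0.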
413 | torch.distributed.init_process_group( 414 | backend="nccl", 415 | init_method='env://', 416 | world_size=args.world_size, 417 | rank=args.rank 418 | ) 419 | 420 | # Setup logging 421 | logging.basicConfig( 422 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 423 | datefmt="%m/%d/%Y %H:%M:%S", 424 | level=logging.INFO if args.gpu == 0 else logging.WARN, 425 | ) 426 | logger.warning( 427 | "Process GPU: %s, num_of_total_GPUs: %s, distributed training: True, 16-bits training: %s", 428 | args.gpu, args.gpus, args.fp16, 429 | ) 430 | 431 | # Set seed 432 | set_seed(args) 433 | 434 | # Load pretrained model and token 435 | # Barrier to make sure only the first process in distributed training 436 | # download model & vocabizer 437 | if gpu != 0: 438 | torch.distributed.barrier() 439 | 440 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 441 | 442 | # Get Config 443 | if args.config_name: 444 | config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir) 445 | elif args.model_name_or_path: 446 | config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) 447 | else: 448 | raise ValueError( 449 | "Why do you want the default config?? Please use --config_name or --model_name_or_path" 450 | ) 451 | 452 | # Get Tokenizer 453 | if args.tokenizer_name: 454 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir) 455 | # BERT always needs lower cased tokens. 456 | if 'uncased' in args.model_type: 457 | assert tokenizer.init_kwargs.get("do_lower_case", False) 458 | elif args.model_name_or_path: 459 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) 460 | else: 461 | raise ValueError( 462 | "You are instantiating a new {} tokenizer. 
This is not supported, " 463 | "but you can do it from another script, save it," 464 | "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__) 465 | ) 466 | 467 | assert args.block_size <= tokenizer.model_max_length 468 | 469 | args.fuse_init_scheme_depth = args.fuse_init_scheme_width = args.fuse_init_scheme[0] 470 | if len(args.fuse_init_scheme) >= 2: 471 | args.fuse_init_scheme_width = args.fuse_init_scheme[1] 472 | args.fuse_init_noise_depth = args.fuse_init_noise_width = args.fuse_init_noise[0] 473 | if len(args.fuse_init_noise) >= 2: 474 | args.fuse_init_noise_width = args.fuse_init_noise[1] 475 | 476 | model = model_class(config=config, args=args) 477 | model = create_ligo_from_model(model, args) 478 | 479 | if args.model_name_or_path: 480 | state_dict = torch.load(os.path.join(args.model_name_or_path, 'pytorch_model.bin'), map_location=torch.device('cpu')) 481 | model.load_state_dict(state_dict) 482 | 483 | model.to(args.device) 484 | 485 | # End of barrier to make sure only the first process waiting other processes 486 | if gpu == 0: 487 | torch.distributed.barrier() 488 | 489 | logger.info("Training/evaluation parameters %s", args) 490 | 491 | # Training 492 | if args.do_train: 493 | # Barrier to make sure only the first process in distributed training process the dataset, 494 | # and the others will use the cache 495 | if gpu != 0: 496 | torch.distributed.barrier() 497 | train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False) 498 | if gpu == 0: 499 | torch.distributed.barrier() 500 | 501 | train(args, train_dataset, model, tokenizer) 502 | 503 | # Evaluation 504 | if args.do_eval and gpu == 0: 505 | result = evaluate(args, model, tokenizer) 506 | 507 | 508 | if __name__ == "__main__": 509 | main() 510 | -------------------------------------------------------------------------------- /run_lm_distributed.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). 18 | GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned 19 | using a masked language modeling (MLM) loss. 
20 | """ 21 | 22 | 23 | import argparse 24 | import glob 25 | import json 26 | import logging 27 | import os 28 | import pickle 29 | import random 30 | import re 31 | import shutil 32 | import sys 33 | from typing import Dict, List, Tuple 34 | from datetime import datetime 35 | import time 36 | 37 | import numpy as np 38 | import torch 39 | from torch.nn.utils.rnn import pad_sequence 40 | import torch.multiprocessing as mp 41 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 42 | from torch.utils.data.distributed import DistributedSampler 43 | from tqdm import tqdm, trange 44 | from transformers import ( 45 | WEIGHTS_NAME, 46 | AdamW, 47 | PreTrainedModel, 48 | PreTrainedTokenizer, 49 | BertConfig, 50 | BertForMaskedLM, 51 | BertTokenizer, 52 | RobertaConfig, 53 | RobertaForMaskedLM, 54 | RobertaTokenizer, 55 | get_linear_schedule_with_warmup, 56 | get_cosine_schedule_with_warmup, 57 | get_polynomial_decay_schedule_with_warmup, 58 | ) 59 | 60 | sys.path.append( 61 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 62 | ) 63 | from data import CoLDataset 64 | from param import process_args 65 | from model import SimpleBertForMaskedLM, SimpleRobertaForMaskedLM 66 | 67 | try: 68 | from torch.utils.tensorboard import SummaryWriter 69 | except ImportError: 70 | from tensorboardX import SummaryWriter 71 | 72 | from ligo import initialize_model_with_ligo 73 | 74 | logger = logging.getLogger(__name__) 75 | 76 | 77 | MODEL_CLASSES = { 78 | "bert": (BertConfig, SimpleBertForMaskedLM, BertTokenizer), 79 | "roberta": (RobertaConfig, SimpleRobertaForMaskedLM, RobertaTokenizer), 80 | } 81 | 82 | 83 | class TextDataset(Dataset): 84 | def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512): 85 | assert os.path.isfile(file_path) 86 | 87 | block_size = block_size - (tokenizer.max_len - tokenizer.max_len_single_sentence) 88 | 89 | directory, filename = os.path.split(file_path) 90 | cached_features_file = os.path.join( 91 | directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename 92 | ) 93 | 94 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 95 | logger.info("Loading features from cached file %s", cached_features_file) 96 | with open(cached_features_file, "rb") as handle: 97 | self.examples = pickle.load(handle) 98 | else: 99 | logger.info("Creating features from dataset file at %s", directory) 100 | 101 | self.examples = [] 102 | with open(file_path, encoding="utf-8") as f: 103 | text = f.read() 104 | 105 | tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) 106 | 107 | for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size 108 | self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])) 109 | # Note that we are loosing the last truncated example here for the sake of simplicity (no padding) 110 | # If your dataset is small, first you should loook for a bigger one :-) and second you 111 | # can change this behavior by adding (model specific) padding. 
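# Worked example of the chunking above: with a BERT tokenizer, block_size 512 becomes
# 512 - 2 = 510 content tokens, since two positions are reserved for the special tokens
# added by build_inputs_with_special_tokens ([CLS] and [SEP]). A corpus of 10,000 tokens
# then yields floor((10000 - 510) / 510) + 1 = 19 examples and drops the trailing 310 tokens.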
112 | 113 | logger.info("Saving features into cached file %s", cached_features_file) 114 | with open(cached_features_file, "wb") as handle: 115 | pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) 116 | 117 | def __len__(self): 118 | return len(self.examples) 119 | 120 | def __getitem__(self, item): 121 | return torch.tensor(self.examples[item], dtype=torch.long) 122 | 123 | 124 | class LineByLineTextDataset(Dataset): 125 | def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512): 126 | assert os.path.isfile(file_path) 127 | # Here, we do not cache the features, operating under the assumption 128 | # that we will soon use fast multithreaded tokenizers from the 129 | # `tokenizers` repo everywhere =) 130 | logger.info("Creating features from dataset file at %s", file_path) 131 | 132 | with open(file_path, encoding="utf-8") as f: 133 | lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())] 134 | 135 | self.examples = tokenizer.batch_encode_plus(lines, add_special_tokens=True, max_length=block_size)["input_ids"] 136 | 137 | def __len__(self): 138 | return len(self.examples) 139 | 140 | def __getitem__(self, i): 141 | return torch.tensor(self.examples[i], dtype=torch.long) 142 | 143 | 144 | def load_and_cache_examples(args, tokenizer, evaluate=False): 145 | file_path = args.eval_data_file if evaluate else args.train_data_file 146 | if args.col_data: 147 | return CoLDataset(file_path, args.tokenizer_name, tokenizer, args.block_size, 148 | split_sent=args.split_sent, 149 | verbose=(args.gpu == 0)) 150 | elif args.line_by_line: 151 | return LineByLineTextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size) 152 | else: 153 | return TextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size) 154 | 155 | 156 | def set_seed(args): 157 | random.seed(args.seed) 158 | np.random.seed(args.seed) 159 | torch.manual_seed(args.seed) 160 | 161 | 162 | def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]: 163 | """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """ 164 | 165 | if tokenizer.mask_token is None: 166 | raise ValueError( 167 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer." 
168 | ) 169 | 170 | labels = inputs.clone() 171 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) 172 | probability_matrix = torch.full(labels.shape, args.mlm_probability) 173 | special_tokens_mask = [ 174 | tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() 175 | ] 176 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) 177 | if tokenizer._pad_token is not None: 178 | padding_mask = labels.eq(tokenizer.pad_token_id) 179 | probability_matrix.masked_fill_(padding_mask, value=0.0) 180 | masked_indices = torch.bernoulli(probability_matrix).bool() 181 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 182 | 183 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 184 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 185 | inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) 186 | 187 | # 10% of the time, we replace masked input tokens with random word 188 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 189 | random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long) 190 | inputs[indices_random] = random_words[indices_random] 191 | 192 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 193 | return inputs, labels 194 | 195 | 196 | def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]: 197 | set_seed(args) # Added here for reproducibility 198 | 199 | """ Train the model """ 200 | if args.gpu == 0: 201 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 202 | tb_writer = SummaryWriter(args.output_dir + '/runs/' + current_time) 203 | 204 | args.train_batch_size = args.per_gpu_train_batch_size 205 | 206 | def collate(examples: List[torch.Tensor]): 207 | if tokenizer._pad_token is None: 208 | return pad_sequence(examples, batch_first=True) 209 | return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) 210 | 211 | if args.shuffle: 212 | logger.info(f"Shuffle the dataset in training," 213 | f"GPU: {args.gpu}," 214 | f"Rank: {args.rank}," 215 | f"Total: {args.world_size}") 216 | train_sampler = DistributedSampler( 217 | train_dataset, 218 | num_replicas=args.world_size, 219 | rank=args.rank, 220 | shuffle=args.shuffle, 221 | ) 222 | train_dataloader = DataLoader( 223 | train_dataset, sampler=train_sampler, shuffle=False, num_workers=0, 224 | batch_size=args.train_batch_size, collate_fn=collate, pin_memory=True 225 | ) 226 | 227 | t_total = args.max_steps 228 | 229 | # Prepare optimizer and schedule (linear warmup and decay) 230 | no_decay = ["bias", "LayerNorm.weight"] 231 | optimizer_grouped_parameters = [ 232 | { 233 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 234 | "weight_decay": args.weight_decay, 235 | }, 236 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 237 | ] 238 | optimizer = AdamW(optimizer_grouped_parameters, 239 | # betas=(0.9, 0.98), 240 | lr=args.learning_rate, 241 | eps=args.adam_epsilon) 242 | if args.warmup_ratio > 0.: 243 | assert args.warmup_steps == 0 244 | args.warmup_steps = int(t_total * args.warmup_ratio) 245 | if args.gpu == 0: 246 | print("Optimized with lr 
%f, steps %d, warmup steps %d, and use beta, epsilon %0.8f." % ( 247 | args.learning_rate, t_total, args.warmup_steps, optimizer.defaults['eps'] 248 | ), optimizer.defaults['betas']) 249 | if args.scheduler_type == 'linear': 250 | scheduler = get_linear_schedule_with_warmup( 251 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 252 | ) 253 | elif args.scheduler_type == 'cosine': 254 | scheduler = get_cosine_schedule_with_warmup( 255 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total, 256 | num_cycles=args.scheduler_cosine_cycles 257 | ) 258 | elif args.scheduler_type == 'poly': 259 | scheduler = get_polynomial_decay_schedule_with_warmup( 260 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total, 261 | power=args.scheduler_poly_power 262 | ) 263 | else: 264 | raise ValueError(f"Unknow lr scheduler: {args.scheduler_type}") 265 | 266 | # Check if saved optimizer or scheduler states exist 267 | if ( 268 | args.model_name_or_path 269 | and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) 270 | and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt")) 271 | ): 272 | # Load in optimizer and scheduler states 273 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"), map_location=torch.device('cpu'))) 274 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"), map_location=torch.device('cpu'))) 275 | 276 | if args.fp16: 277 | try: 278 | from apex import amp 279 | except ImportError: 280 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 281 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level, 282 | verbosity=0) 283 | from apex.parallel import DistributedDataParallel as DDP 284 | model = DDP(model) 285 | else: 286 | model = torch.nn.parallel.DistributedDataParallel( 287 | model, device_ids=[args.gpu], find_unused_parameters=True 288 | ) 289 | 290 | # Train! 291 | logger.info("***** Running training *****") 292 | logger.info(" Num examples = %d", len(train_dataset)) 293 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 294 | logger.info( 295 | " Total train batch size (w. 
distributed & accumulation) = %d", 296 | args.train_batch_size 297 | * args.gradient_accumulation_steps 298 | * args.world_size 299 | ) 300 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 301 | logger.info(" Total optimization steps = %d", t_total) 302 | 303 | global_step = 0 304 | epochs_trained = 0 305 | # Check if continuing training from a checkpoint 306 | if args.model_name_or_path and os.path.exists(args.model_name_or_path): 307 | try: 308 | # set global_step to gobal_step of last saved checkpoint from model path 309 | checkpoint_name = os.path.basename(args.model_name_or_path) 310 | global_step = int(checkpoint_name.split("-")[-1]) 311 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) 312 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) 313 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 314 | logger.info(" Continuing training from iter %d, epoch %d" % (global_step, epochs_trained)) 315 | except ValueError: 316 | logger.info(" Do not load model from %s, restart training" % args.model_name_or_path) 317 | 318 | model.zero_grad() 319 | 320 | # IMPORTANT: save the initialization 321 | if args.gpu == 0 and global_step == 0: 322 | checkpoint_name = f"checkpoint-{global_step:08d}" 323 | ckpt_dir = os.path.join(args.output_dir, 'checkpoints') 324 | os.makedirs(ckpt_dir, exist_ok=True) 325 | save_model(args, ckpt_dir, checkpoint_name, model, tokenizer, optimizer, scheduler) 326 | 327 | while True: 328 | epoch_iterator = tqdm(train_dataloader, desc=f"Epoch: {epochs_trained:03d}", disable=args.gpu != 0) 329 | tr_loss, tr_lm_loss = 0.0, 0.0 330 | t_start = time.time() 331 | model.zero_grad() # Support of accumulating gradients 332 | for step, batch in enumerate(epoch_iterator): 333 | inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) 334 | inputs = inputs.to(args.device) 335 | labels = labels.to(args.device) 336 | # If some of the input is padded, then the attention mask is needed 337 | attention_mask = (inputs != tokenizer.pad_token_id) # word_tokens --> 1, pad_token --> 0 338 | if attention_mask.all(): 339 | attention_mask = None 340 | 341 | model.train() 342 | outputs = model(inputs, 343 | attention_mask=attention_mask, 344 | masked_lm_labels=labels, 345 | current_step=global_step) if args.mlm else model(inputs, labels=labels, current_step=global_step) 346 | loss = outputs['loss'] # model outputs are always tuple in transformers (see doc) 347 | 348 | if args.gradient_accumulation_steps > 1: 349 | loss = loss / args.gradient_accumulation_steps 350 | 351 | if args.fp16: 352 | with amp.scale_loss(loss, optimizer) as scaled_loss: 353 | scaled_loss.backward() 354 | else: 355 | loss.backward() 356 | 357 | tr_loss += loss.item() 358 | tr_lm_loss += outputs['lm_loss'].item() 359 | if (step + 1) % args.gradient_accumulation_steps == 0: 360 | if args.max_grad_norm > 0.: 361 | if args.fp16: 362 | total_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 363 | else: 364 | total_norm =torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 365 | optimizer.step() 366 | scheduler.step() # Update learning rate schedule 367 | model.zero_grad() 368 | global_step += 1 369 | 370 | if args.gpu == 0 and args.logging_steps > 0 and global_step % args.logging_steps == 0: 371 | t_elapse = time.time() - t_start 372 | 373 | # Log metrics 374 | 
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) 375 | if args.fp16: 376 | try: 377 | from apex.amp import _amp_state 378 | tb_writer.add_scalar("loss_scale", _amp_state.loss_scalers[0]._loss_scale, global_step) 379 | tb_writer.add_scalar("scaled_loss", scaled_loss.item(), global_step) 380 | except ImportError: 381 | logger.warning("Cannot import apex.amp._amp_state, " 382 | "would not state the loss_scale in the log") 383 | if args.max_grad_norm > 0.: # Only clip the grad when it is valid 384 | tb_writer.add_scalar("grad_norm", total_norm, global_step) 385 | train_loss = tr_loss / args.logging_steps 386 | train_ppl = torch.exp(torch.tensor(tr_lm_loss / args.logging_steps)).item() 387 | tb_writer.add_scalar("loss", train_loss, global_step) 388 | tb_writer.add_scalar("train_ppl", train_ppl, global_step) 389 | tr_loss = tr_lm_loss = 0. 390 | 391 | # also evaluate on valid set for ppl 392 | logger.info(" Evaluation Results of step %d: " % global_step) 393 | results = evaluate(args, model.module, tokenizer) 394 | for key, value in results.items(): 395 | tb_writer.add_scalar("eval_{}".format(key), value, global_step) 396 | logger.info("\t %s: %0.4f" % (key, value)) 397 | 398 | output_log_file = os.path.join(args.output_dir, "train_log.txt") 399 | with open(output_log_file, 'a') as f: 400 | eval_ppl = results['perplexity'] 401 | print(f"train_step={global_step}, train_time={t_elapse}, lr={scheduler.get_lr()[0]}, train_loss={train_loss}," 402 | f"train_ppl={train_ppl}, eval_ppl={eval_ppl}", file=f) 403 | 404 | t_start = time.time() 405 | 406 | if args.gpu == 0 and args.ckpt_steps > 0 and global_step % args.ckpt_steps == 0: 407 | checkpoint_name = f"checkpoint-{global_step:08d}" 408 | ckpt_dir = os.path.join(args.output_dir, 'checkpoints') 409 | os.makedirs(ckpt_dir, exist_ok=True) 410 | save_model(args, ckpt_dir, checkpoint_name, model, tokenizer, optimizer, scheduler) 411 | 412 | if args.max_steps > 0 and global_step >= args.max_steps: 413 | break 414 | 415 | if args.max_steps > 0 and global_step >= args.max_steps: 416 | epoch_iterator.close() 417 | break 418 | 419 | epochs_trained += 1 420 | 421 | # consider during the last evaluation, the GPU 0 is still working while others have exited. 
422 | # when GPU 0 call torch.no_grad, it will wait for the response from other processes 423 | # however, a deadlock will be caused if other processes just exit 424 | # torch.distributed.barrier() 425 | 426 | if args.gpu == 0: 427 | tb_writer.close() 428 | 429 | 430 | def save_model(args, ckpt_dir, name, model, tokenizer, optimizer, scheduler): 431 | # Save model checkpoint 432 | output_dir = os.path.join(ckpt_dir, name) 433 | os.makedirs(output_dir, exist_ok=True) 434 | model_to_save = ( 435 | model.module if hasattr(model, "module") else model 436 | ) # Take care of distributed/parallel training 437 | model_to_save.save_pretrained(output_dir) 438 | tokenizer.save_pretrained(output_dir) 439 | 440 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 441 | logger.info("Saving model checkpoint to %s", output_dir) 442 | 443 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 444 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 445 | logger.info("Saving optimizer and scheduler states to %s", output_dir) 446 | 447 | 448 | def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict: 449 | # Loop to handle MNLI double evaluation (matched, mis-matched) 450 | eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True) 451 | 452 | args.eval_batch_size = args.per_gpu_eval_batch_size 453 | # Note that DistributedSampler samples randomly 454 | 455 | def collate(examples: List[torch.Tensor]): 456 | if tokenizer._pad_token is None: 457 | return pad_sequence(examples, batch_first=True) 458 | return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) 459 | 460 | eval_sampler = SequentialSampler(eval_dataset) 461 | eval_dataloader = DataLoader( 462 | eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate 463 | ) 464 | 465 | # Eval! 
466 | logger.info("***** Running evaluation {} *****".format(prefix)) 467 | logger.info(" Num examples = %d", len(eval_dataset)) 468 | logger.info(" Batch size = %d", args.eval_batch_size) 469 | eval_loss = 0.0 470 | nb_eval_steps = 0 471 | model.eval() 472 | 473 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 474 | inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) 475 | inputs = inputs.to(args.device) 476 | labels = labels.to(args.device) 477 | # If some of the input is padded, then the attention mask is needed 478 | attention_mask = (inputs != tokenizer.pad_token_id) # word_tokens --> 1, pad_token --> 0 479 | if attention_mask.all(): 480 | attention_mask = None 481 | 482 | with torch.no_grad(): 483 | outputs = model(inputs, attention_mask=attention_mask, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels) 484 | lm_loss = outputs['lm_loss'] 485 | eval_loss += lm_loss.mean().item() 486 | nb_eval_steps += 1 487 | 488 | eval_loss = eval_loss / nb_eval_steps 489 | perplexity = torch.exp(torch.tensor(eval_loss)).item() 490 | 491 | result = {"perplexity": perplexity} 492 | 493 | return result 494 | 495 | 496 | def is_port_in_use(port): 497 | import socket 498 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 499 | return s.connect_ex(('localhost', port)) == 0 500 | 501 | 502 | def main(): 503 | parser = process_args() 504 | args = parser.parse_args() 505 | 506 | os.environ['MASTER_ADDR'] = '127.0.0.1' 507 | port = 9595 508 | while is_port_in_use(port): 509 | port += 1 510 | print("Use port", port) 511 | os.environ['MASTER_PORT'] = str(port) 512 | 513 | # Using all available gpus for multi-processing distributed 514 | args.gpus = torch.cuda.device_count() 515 | print("Use gpus ", list(range(args.gpus))) 516 | args.world_size = args.gpus * args.nodes 517 | mp.spawn(setup, nprocs=args.gpus, args=(args,)) 518 | 519 | 520 | def setup(gpu, args): 521 | if args.should_continue: 522 | ckpt_dir = os.path.join(args.output_dir, 'checkpoints') 523 | checkpoint_names = [] 524 | if os.path.isdir(ckpt_dir): 525 | checkpoint_names = [fn for fn in os.listdir(ckpt_dir) if fn.startswith('checkpoint-')] 526 | if len(checkpoint_names) > 0: 527 | checkpoint_names = sorted(checkpoint_names, key=lambda p: int(p.split('-')[-1])) 528 | args.model_name_or_path = os.path.join(ckpt_dir, checkpoint_names[-1]) 529 | else: 530 | logger.warning('No checkpoint detected: %s', ckpt_dir) 531 | args.model_name_or_path = None 532 | 533 | # Setup CUDA, GPU & distributed training 534 | torch.cuda.set_device(gpu) 535 | device = torch.device("cuda", gpu) 536 | args.gpu = gpu # Local device id. 537 | args.device = device # Local device object. 538 | args.rank = args.nr * args.gpus + gpu # The gpu id in the world. 
539 | torch.distributed.init_process_group( 540 | backend="nccl", 541 | init_method='env://', 542 | world_size=args.world_size, 543 | rank=args.rank 544 | ) 545 | 546 | # Setup logging 547 | logging.basicConfig( 548 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 549 | datefmt="%m/%d/%Y %H:%M:%S", 550 | level=logging.INFO if args.gpu == 0 else logging.WARN, 551 | ) 552 | logger.warning( 553 | "Process GPU: %s, num_of_total_GPUs: %s, distributed training: True, 16-bits training: %s", 554 | args.gpu, args.gpus, args.fp16, 555 | ) 556 | 557 | # Set seed 558 | set_seed(args) 559 | 560 | # Load pretrained model and token 561 | # Barrier to make sure only the first process in distributed training 562 | # download model & vocabizer 563 | if gpu != 0: 564 | torch.distributed.barrier() 565 | 566 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 567 | 568 | # Get Config 569 | if args.config_name: 570 | config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir) 571 | elif args.model_name_or_path: 572 | config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) 573 | else: 574 | raise ValueError( 575 | "Why do you want the default config?? Please use --config_name or --model_name_or_path" 576 | ) 577 | 578 | # Get Tokenizer 579 | if args.tokenizer_name: 580 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir) 581 | # BERT always needs lower cased tokens. 582 | if 'uncased' in args.model_type: 583 | assert tokenizer.init_kwargs.get("do_lower_case", False) 584 | elif args.model_name_or_path: 585 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) 586 | else: 587 | raise ValueError( 588 | "You are instantiating a new {} tokenizer. This is not supported, " 589 | "but you can do it from another script, save it," 590 | "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__) 591 | ) 592 | 593 | assert args.block_size <= tokenizer.model_max_length 594 | 595 | if args.model_name_or_path: 596 | model = model_class.from_pretrained( 597 | args.model_name_or_path, 598 | from_tf=bool(".ckpt" in args.model_name_or_path), 599 | config=config, 600 | cache_dir=args.cache_dir, 601 | args=args 602 | ) 603 | elif args.source_model_path or args.pretrained_ligo_path: 604 | logger.info("Growing [%s] new model from: %s", args.grow_scheme, args.source_model_path) 605 | 606 | model = model_class(config=config, args=args) 607 | 608 | if args.grow_scheme == 'none': 609 | logger.info("No initialization scheme applied. 
Training new model with random initialization ...") 610 | elif args.grow_scheme == 'ligo': 611 | ckpt_dir = os.path.join(args.pretrained_ligo_path, 'checkpoints') 612 | checkpoint_names = [fn for fn in os.listdir(ckpt_dir) if fn.startswith('checkpoint-')] 613 | checkpoint_names = sorted(checkpoint_names, key=lambda p: int(p.split('-')[-1])) 614 | args.pretrained_ligo_path = os.path.join(ckpt_dir, checkpoint_names[-1]) 615 | 616 | args.fuse_init_scheme_depth = args.fuse_init_scheme_width = args.fuse_init_scheme[0] 617 | if len(args.fuse_init_scheme) >= 2: 618 | args.fuse_init_scheme_width = args.fuse_init_scheme[1] 619 | args.fuse_init_noise_depth = args.fuse_init_noise_width = args.fuse_init_noise[0] 620 | if len(args.fuse_init_noise) >= 2: 621 | args.fuse_init_noise_width = args.fuse_init_noise[1] 622 | 623 | model = initialize_model_with_ligo(model, args) 624 | else: 625 | raise NotImplementedError(f'Grow method [{args.grow_scheme}] not implemented yet!') 626 | 627 | else: 628 | logger.info("Training new model from scratch") 629 | model = model_class(config=config, args=args) 630 | 631 | model.to(args.device) 632 | 633 | # End of barrier to make sure only the first process waiting other processes 634 | if gpu == 0: 635 | torch.distributed.barrier() 636 | 637 | logger.info("Training/evaluation parameters %s", args) 638 | 639 | # Training 640 | if args.do_train: 641 | # Barrier to make sure only the first process in distributed training process the dataset, 642 | # and the others will use the cache 643 | if gpu != 0: 644 | torch.distributed.barrier() 645 | train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False) 646 | if gpu == 0: 647 | torch.distributed.barrier() 648 | 649 | train(args, train_dataset, model, tokenizer) 650 | 651 | # Evaluation 652 | if args.do_eval and gpu == 0: 653 | result = evaluate(args, model, tokenizer) 654 | 655 | 656 | if __name__ == "__main__": 657 | main() 658 | -------------------------------------------------------------------------------- /ligo.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import glob 4 | import json 5 | import logging 6 | import os 7 | import pickle 8 | import random 9 | import math 10 | import re 11 | import shutil 12 | import sys 13 | from typing import Dict, List, Tuple 14 | from datetime import datetime 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | from torch.nn.utils.rnn import pad_sequence 21 | import torch.multiprocessing as mp 22 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 23 | from torch.utils.data.distributed import DistributedSampler 24 | from tqdm import tqdm, trange 25 | from transformers import ( 26 | AutoConfig, 27 | PreTrainedModel, 28 | PreTrainedTokenizer, 29 | BertConfig, 30 | BertForMaskedLM, 31 | BertTokenizer, 32 | RobertaConfig, 33 | RobertaForMaskedLM, 34 | RobertaTokenizer, 35 | ) 36 | 37 | sys.path.append( 38 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 39 | ) 40 | from model import SimpleBertForMaskedLM, SimpleRobertaForMaskedLM 41 | 42 | def is_bert(model): 43 | return isinstance(model, SimpleBertForMaskedLM) 44 | 45 | def is_roberta(model): 46 | return isinstance(model, SimpleRobertaForMaskedLM) 47 | 48 | def num_layer_of(model): 49 | if is_bert(model): 50 | return len(model.bert.encoder.layer) 51 | elif is_roberta(model): 52 | return len(model.roberta.encoder.layer) 53 | else: 54 | raise 
NotImplementedError 55 | 56 | def is_encoder_layer(name): 57 | return name.startswith('bert.encoder.layer') or name.startswith('roberta.encoder.layer') 58 | 59 | K2N = lambda s: '__sm_' + '_'.join(s.split('.')) 60 | N2K = lambda s: '.'.join(s[5:].split('_')) 61 | 62 | def normalized_uniform_init(w, init_scheme): 63 | init_weight = torch.rand_like(w) 64 | # nn.init.uniform_(init_weight, 0.0, 1.0) 65 | if 'softmax' in init_scheme: 66 | init_weight = F.softmax(init_weight, -1) # softmax normalize 67 | else: 68 | init_weight = init_weight / torch.sum(init_weight, -1, keepdim=True) # normalize 69 | w.copy_(init_weight) 70 | 71 | def stackbert_init(w, layer_index, init_scheme, init_noise=0.03): 72 | init_weight = torch.zeros_like(w) 73 | if 'noisy' in init_scheme: 74 | init_weight.uniform_(0.0, init_noise) 75 | init_weight[layer_index % len(init_weight)] = 1. 76 | 77 | init_weight = init_weight / torch.sum(init_weight) # normalize 78 | w.copy_(init_weight) 79 | 80 | def interlace_init(w, layer_index, init_scheme, init_noise=0.03): 81 | init_weight = torch.zeros_like(w) 82 | if 'noisy' in init_scheme: 83 | init_weight.uniform_(0.0, init_noise) 84 | 85 | init_weight[layer_index // 2] = 1.0 86 | 87 | init_weight = init_weight / torch.sum(init_weight) # normalize 88 | w.copy_(init_weight) 89 | 90 | class FusedDepthParams(nn.Module): 91 | 92 | def __init__(self, layer_index, num_layers, bias=False, learnable=True, init_scheme='rand', init_noise=0.03): 93 | 94 | super(FusedDepthParams, self).__init__() 95 | 96 | assert init_scheme in ['rand', 'rand_softmax', 'stackbert', 'interlace', 'stackbert_noisy', 'interlace_noisy'] 97 | 98 | self.layer_index = layer_index 99 | self.init_scheme = init_scheme 100 | self.init_noise = init_noise 101 | 102 | if learnable: 103 | self.coeffs_weight = nn.Parameter(torch.zeros(num_layers)) 104 | if bias: 105 | self.coeffs_bias = nn.Parameter(torch.zeros(num_layers)) 106 | else: 107 | self.coeffs_bias = None 108 | else: 109 | self.register_buffer('coeffs_weight', torch.zeros(num_layers), persistent=True) 110 | if bias: 111 | self.register_buffer('coeffs_bias', torch.zeros(num_layers), persistent=True) 112 | else: 113 | self.coeffs_bias = None 114 | 115 | self.reset_parameters() 116 | 117 | def reset_parameters(self): 118 | # init depth 119 | if self.init_scheme in ['rand', 'rand_softmax']: 120 | normalized_uniform_init(self.coeffs_weight, self.init_scheme) 121 | if self.coeffs_bias is not None: 122 | normalized_uniform_init(self.coeffs_bias, self.init_scheme) 123 | elif self.init_scheme in ['stackbert', 'stackbert_noisy']: 124 | stackbert_init(self.coeffs_weight, self.layer_index, self.init_scheme, self.init_noise) 125 | if self.coeffs_bias is not None: 126 | stackbert_init(self.coeffs_bias, self.layer_index, self.init_scheme, self.init_noise) 127 | elif self.init_scheme in ['interlace', 'interlace_noisy']: 128 | interlace_init(self.coeffs_weight, self.layer_index, self.init_scheme, self.init_noise) 129 | if self.coeffs_bias is not None: 130 | interlace_init(self.coeffs_bias, self.layer_index, self.init_scheme, self.init_noise) 131 | 132 | class FusedWidthParams(nn.Module): 133 | 134 | def __init__(self, small_dim, large_dim, learnable=False, init_scheme='rand', init_noise=0.03): 135 | 136 | super(FusedWidthParams, self).__init__() 137 | 138 | assert init_scheme in ['rand', 'rand_softmax', 'sel', 'sel_noisy'] 139 | 140 | self.init_scheme = init_scheme 141 | self.init_noise = init_noise 142 | 143 | if large_dim - small_dim > 0: 144 | if learnable: 145 | self.coeffs_weight 
= nn.Parameter(torch.zeros(large_dim - small_dim, small_dim)) 146 | else: 147 | self.register_buffer('coeffs_weight', torch.zeros(large_dim - small_dim, small_dim)) 148 | else: 149 | self.coeffs_weight = None 150 | 151 | self.reset_parameters() 152 | 153 | def reset_parameters(self): 154 | if self.coeffs_weight is not None: 155 | if self.init_scheme in ['rand', 'rand_softmax']: 156 | normalized_uniform_init(self.coeffs_weight, self.init_scheme) 157 | elif self.init_scheme in ['sel', 'sel_noisy']: 158 | sel = torch.randint(0, self.coeffs_weight.shape[1], (self.coeffs_weight.shape[0],)) 159 | init_weight = torch.zeros_like(self.coeffs_weight, dtype=torch.float32) 160 | if 'noisy' in self.init_scheme: 161 | init_weight.uniform_(0.0, self.init_noise) 162 | init_weight[torch.arange(self.coeffs_weight.shape[0]), sel] = 1. 163 | self.coeffs_weight.copy_(init_weight) 164 | 165 | class FusedLinear(nn.Module): 166 | 167 | def __init__(self, model, module_name, in_features, out_features, layer_index=-1, init_scheme_depth='rand', init_noise_depth=0.03, learn_depth=True, 168 | init_scheme_width='rand', init_noise_width=0.03, learn_width=True, residual=False, depth_tie=None, width_in_tie=None, width_out_tie=None, residual_noise=0.01): 169 | 170 | super(FusedLinear, self).__init__() 171 | 172 | self.in_features = in_features 173 | self.out_features = out_features 174 | 175 | # weights for attention layers if depth expansion 176 | self.get_weights = lambda: getattr(model, K2N(module_name) + '_weight') 177 | self.get_bias = lambda: getattr(model, K2N(module_name) + '_bias', None) 178 | 179 | self.bias = (self.get_bias() is not None) 180 | 181 | self.residual = residual 182 | self.residual_noise = residual_noise 183 | 184 | if residual: 185 | self.residual_weight = nn.Parameter(torch.empty((self.out_features, self.in_features))) 186 | if self.bias: 187 | self.residual_bias = nn.Parameter(torch.empty(self.out_features)) 188 | else: 189 | self.register_parameter('residual_bias', None) 190 | 191 | 192 | # for embedding or classifier layer, specify layer_index to -1 193 | if layer_index >= 0: 194 | if depth_tie is None: 195 | num_layers_small = self.get_weights().shape[-1] 196 | self.fuse_depth_coeffs = FusedDepthParams(layer_index, num_layers_small, bias=self.bias, 197 | learnable=learn_depth, init_scheme=init_scheme_depth, init_noise=init_noise_depth) 198 | depth_tie = self.fuse_depth_coeffs 199 | 200 | self.get_depth_coeffs = lambda: (depth_tie.coeffs_weight, getattr(depth_tie, 'coeffs_bias', depth_tie.coeffs_weight)) 201 | 202 | if width_in_tie is None: 203 | hidden_dim_small = self.get_weights().shape[:-1] if layer_index >= 0 else self.get_weights().shape 204 | self.fuse_width_in = FusedWidthParams(hidden_dim_small[1], self.in_features, 205 | learnable=learn_width, init_scheme=init_scheme_width, init_noise=init_noise_width) 206 | width_in_tie = self.fuse_width_in 207 | 208 | 209 | if width_out_tie is None: 210 | hidden_dim_small = self.get_weights().shape[:-1] if layer_index >= 0 else self.get_weights().shape 211 | self.fuse_width_out = FusedWidthParams(hidden_dim_small[0], self.out_features, 212 | learnable=learn_width, init_scheme=init_scheme_width, init_noise=init_noise_width) 213 | width_out_tie = self.fuse_width_out 214 | 215 | self.get_width_coeffs = lambda: (width_in_tie.coeffs_weight, width_out_tie.coeffs_weight) 216 | 217 | self.reset_parameters() 218 | 219 | def reset_parameters(self): 220 | 221 | if self.residual: 222 | nn.init.uniform_(self.residual_weight, -self.residual_noise, 
self.residual_noise) 223 | if self.bias: 224 | nn.init.uniform_(self.residual_bias, -self.residual_noise, self.residual_noise) 225 | 226 | if hasattr(self, 'fuse_depth_coeffs'): 227 | self.fuse_depth_coeffs.reset_parameters() 228 | if hasattr(self, 'fuse_width_in'): 229 | self.fuse_width_in.reset_parameters() 230 | if hasattr(self, 'fuse_width_out'): 231 | self.fuse_width_out.reset_parameters() 232 | 233 | def get_params(self): 234 | bias = None 235 | 236 | if hasattr(self, 'get_depth_coeffs'): 237 | coeffs_weights, coeffs_bias = self.get_depth_coeffs() 238 | weight = torch.sum(self.get_weights() * coeffs_weights, -1) 239 | if self.bias: 240 | bias = torch.sum(self.get_bias() * coeffs_bias, -1) 241 | else: 242 | weight = self.get_weights() 243 | if self.bias: 244 | bias = self.get_bias() 245 | 246 | in_dim_expand, out_dim_expand = self.get_width_coeffs() 247 | if in_dim_expand is not None: 248 | in_dim_expand = torch.transpose(in_dim_expand, 0, 1) 249 | weight = torch.cat([weight, torch.matmul(weight, in_dim_expand)], 1) # expand in dimension 250 | if out_dim_expand is not None: 251 | weight = torch.cat([weight, torch.matmul(out_dim_expand, weight)], 0) # expand out dimension 252 | if self.bias: 253 | bias = torch.cat([bias, torch.matmul(out_dim_expand, bias)], 0) # expand out dimension 254 | 255 | if self.residual: 256 | weight = weight + self.residual_weight 257 | if self.bias: 258 | bias = bias + self.residual_bias 259 | 260 | return weight, bias 261 | 262 | def forward(self, input): 263 | weight, bias = self.get_params() 264 | return F.linear(input, weight, bias) 265 | 266 | def extra_repr(self) -> str: 267 | return 'in_features={}, out_features={}, bias={}'.format( 268 | self.in_features, self.out_features, self.bias is not None 269 | ) 270 | 271 | 272 | class FusedLayeredNorm(nn.Module): 273 | 274 | def __init__(self, model, module_name, normalized_shape, layer_index=-1, init_scheme_depth='rand', init_noise_depth=0.03, learn_depth=True, 275 | init_scheme_width='rand', init_noise_width=0.03, learn_width=True, eps=1e-5, residual=False, depth_tie=None, width_out_tie=None, residual_noise=0.01): 276 | 277 | super(FusedLayeredNorm, self).__init__() 278 | 279 | self.get_weights = lambda: getattr(model, K2N(module_name) + '_weight') 280 | self.get_bias = lambda: getattr(model, K2N(module_name) + '_bias', None) 281 | 282 | self.normalized_shape = normalized_shape 283 | self.elementwise_affine = True # only support elementwise_affine 284 | self.eps = eps 285 | self.residual = residual 286 | self.residual_noise = residual_noise 287 | 288 | if residual: 289 | self.residual_weight = nn.Parameter(torch.empty(self.normalized_shape)) 290 | self.residual_bias = nn.Parameter(torch.empty(self.normalized_shape)) 291 | 292 | # for embedding or classifier layer, specify layer_index to -1 293 | if layer_index >= 0: 294 | if depth_tie is None: 295 | num_layers_small = self.get_weights().shape[-1] 296 | self.fuse_depth_coeffs = FusedDepthParams(layer_index, num_layers_small, bias=True, 297 | learnable=learn_depth, init_scheme=init_scheme_depth, init_noise=init_noise_depth) 298 | depth_tie = self.fuse_depth_coeffs 299 | 300 | self.get_depth_coeffs = lambda: (depth_tie.coeffs_weight, getattr(depth_tie, 'coeffs_bias', depth_tie.coeffs_weight)) 301 | 302 | if width_out_tie is None: 303 | hidden_dim_small = self.get_weights().shape[:-1] if layer_index >= 0 else self.get_weights().shape 304 | self.fuse_width_out = FusedWidthParams(hidden_dim_small[0], self.normalized_shape[0], 305 | learnable=learn_width, 
init_scheme=init_scheme_width, init_noise=init_noise_width) 306 | width_out_tie = self.fuse_width_out 307 | 308 | self.get_width_coeffs = lambda: width_out_tie.coeffs_weight 309 | 310 | def reset_parameters(self): 311 | 312 | if self.residual: 313 | nn.init.uniform_(self.residual_weight, -self.residual_noise, self.residual_noise) 314 | nn.init.uniform_(self.residual_bias, -self.residual_noise, self.residual_noise) 315 | 316 | if hasattr(self, 'fuse_depth_coeffs'): 317 | self.fuse_depth_coeffs.reset_parameters() 318 | if hasattr(self, 'fuse_width_out'): 319 | self.fuse_width_out.reset_parameters() 320 | 321 | def get_params(self): 322 | 323 | if hasattr(self, 'get_depth_coeffs'): 324 | coeffs_weights, coeffs_bias = self.get_depth_coeffs() 325 | weight = torch.sum(self.get_weights() * coeffs_weights, -1) 326 | bias = torch.sum(self.get_bias() * coeffs_bias, -1) 327 | 328 | else: 329 | weight = self.get_weights() 330 | bias = self.get_bias() 331 | 332 | out_dim_expand = self.get_width_coeffs() 333 | if out_dim_expand is not None: 334 | weight = torch.cat([weight, torch.matmul(out_dim_expand, weight)], 0) # expand out dimension 335 | bias = torch.cat([bias, torch.matmul(out_dim_expand, bias)], 0) # expand out dimension 336 | 337 | if self.residual: 338 | weight = weight + self.residual_weight 339 | bias = bias + self.residual_bias 340 | 341 | return weight, bias 342 | 343 | def forward(self, input): 344 | weight, bias = self.get_params() 345 | return F.layer_norm(input, self.normalized_shape, weight, bias, self.eps) 346 | 347 | def extra_repr(self): 348 | return '{}, eps={}, elementwise_affine={}'.format(self.normalized_shape, self.eps, self.elementwise_affine) 349 | 350 | class FusedEmbedding(nn.Module): 351 | 352 | __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm', 353 | 'norm_type', 'scale_grad_by_freq', 'sparse'] 354 | 355 | def __init__(self, model, module_name, num_embeddings, embedding_dim, padding_idx = None, 356 | max_norm=None, norm_type=2., scale_grad_by_freq=False, sparse=False, 357 | init_scheme_width='rand', init_noise_width=0.03, learn_width=True, 358 | residual=False, width_out_tie=None, residual_noise=0.01): 359 | 360 | super(FusedEmbedding, self).__init__() 361 | 362 | self.get_weights = lambda: getattr(model, K2N(module_name) + '_weight') 363 | 364 | self.num_embeddings = num_embeddings 365 | self.embedding_dim = embedding_dim 366 | if padding_idx is not None: 367 | if padding_idx > 0: 368 | assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' 369 | elif padding_idx < 0: 370 | assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' 371 | padding_idx = self.num_embeddings + padding_idx 372 | self.padding_idx = padding_idx 373 | self.max_norm = max_norm 374 | self.norm_type = norm_type 375 | self.scale_grad_by_freq = scale_grad_by_freq 376 | self.residual = residual 377 | self.residual_noise = residual_noise 378 | 379 | if residual: 380 | self.residual_weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim))) 381 | 382 | if width_out_tie is None: 383 | hidden_dim_small = self.get_weights().shape 384 | self.fuse_width_out = FusedWidthParams(hidden_dim_small[1], self.embedding_dim, 385 | learnable=learn_width, init_scheme=init_scheme_width, init_noise=init_noise_width) 386 | width_out_tie = self.fuse_width_out 387 | 388 | self.get_width_coeffs = lambda: width_out_tie.coeffs_weight 389 | 390 | self.sparse = sparse 391 | 392 | def reset_parameters(self): 393 | if self.residual: 
394 | nn.init.uniform_(self.residual_weight, -self.residual_noise, self.residual_noise) 395 | 396 | if hasattr(self, 'fuse_width_out'): 397 | self.fuse_width_out.reset_parameters() 398 | 399 | def get_params(self): 400 | 401 | weight = self.get_weights() 402 | 403 | out_dim_expand = self.get_width_coeffs() 404 | if out_dim_expand is not None: 405 | out_dim_expand = torch.transpose(out_dim_expand, 0, 1) 406 | weight = torch.cat([weight, torch.matmul(weight, out_dim_expand)], 1) # expand out dimension 407 | 408 | if self.residual: 409 | weight = weight + self.residual_weight 410 | 411 | return weight 412 | 413 | def forward(self, input): 414 | weight = self.get_params() 415 | return F.embedding( 416 | input, weight, self.padding_idx, self.max_norm, 417 | self.norm_type, self.scale_grad_by_freq, self.sparse) 418 | 419 | def extra_repr(self): 420 | s = '{num_embeddings}, {embedding_dim}' 421 | if self.padding_idx is not None: 422 | s += ', padding_idx={padding_idx}' 423 | if self.max_norm is not None: 424 | s += ', max_norm={max_norm}' 425 | if self.norm_type != 2: 426 | s += ', norm_type={norm_type}' 427 | if self.scale_grad_by_freq is not False: 428 | s += ', scale_grad_by_freq={scale_grad_by_freq}' 429 | if self.sparse is not False: 430 | s += ', sparse=True' 431 | return s.format(**self.__dict__) 432 | 433 | @torch.no_grad() 434 | def create_ligo_from_model(model_large, args, source_model=None): 435 | 436 | # load source model from the setting 437 | if source_model is None: 438 | assert len(args.source_model_path) == 1, 'Not support multiple model.' 439 | 440 | source_model_path = args.source_model_path[0] 441 | # Small model 442 | if source_model_path: 443 | small_config = AutoConfig.from_pretrained(source_model_path, cache_dir=args.cache_dir) 444 | else: 445 | raise ValueError("No config for small model is specified.") 446 | model_small = model_large.__class__.from_pretrained( 447 | source_model_path, 448 | from_tf=bool(".ckpt" in source_model_path), 449 | config=small_config, 450 | cache_dir=args.cache_dir, 451 | ) 452 | 453 | # directly use given model 454 | else: 455 | model_small = source_model 456 | 457 | dict_model_small = model_small.state_dict() 458 | 459 | # save map from module to name 460 | dict_M2N = {} 461 | for name, module in model_large.named_modules(): 462 | # if not name.startswith('bert.encoder.layer'): 463 | if not is_encoder_layer(name): 464 | dict_M2N[id(module)] = name 465 | else: 466 | dict_M2N[id(module)] = '.'.join(name.split('.')[4:]) 467 | M2N = lambda m: dict_M2N[id(m)] 468 | 469 | # extract parameters of embedding and classifier 470 | # NOTE use state_dict can automatically break the tie between embedding and classifier 471 | for name, param in dict_model_small.items(): 472 | if not is_encoder_layer(name): 473 | if args.tune_small_model: 474 | model_large.register_parameter(K2N(name), nn.Parameter(param, requires_grad=True)) 475 | else: 476 | model_large.register_buffer(K2N(name), param, persistent=True) 477 | 478 | # extract parameters of same module at different layers 479 | if is_bert(model_small): 480 | enc_layers = model_small.bert.encoder.layer 481 | template_key = 'bert.encoder.layer' 482 | elif is_roberta(model_small): 483 | enc_layers = model_small.roberta.encoder.layer 484 | template_key = 'roberta.encoder.layer' 485 | else: 486 | raise NotImplementedError 487 | for name, param in enc_layers[0].named_parameters(): 488 | weight_list = [] 489 | for l, _ in enumerate(enc_layers): 490 | k = f'{template_key}.{l}.{name}' 491 | 
weight_list.append(dict_model_small[k]) 492 | w = torch.stack(weight_list, -1) 493 | if args.tune_small_model: 494 | model_large.register_parameter(K2N(name), nn.Parameter(w, requires_grad=True)) 495 | else: 496 | model_large.register_buffer(K2N(name), w, persistent=True) 497 | 498 | def create_embed_layer(module_large, args, width_out_tie=None): 499 | return FusedEmbedding(model_large, M2N(module_large), module_large.num_embeddings, module_large.embedding_dim, 500 | padding_idx=module_large.padding_idx, max_norm=module_large.max_norm, norm_type=module_large.norm_type, 501 | scale_grad_by_freq=module_large.scale_grad_by_freq, sparse=module_large.sparse, 502 | init_scheme_width=args.fuse_init_scheme_width, init_noise_width=args.fuse_init_noise_width, learn_width=args.tune_width, 503 | residual=args.tune_residual, residual_noise=args.tune_residual_noise, width_out_tie=width_out_tie 504 | ) 505 | 506 | def create_lin_layer(module_large, args, layer_index=-1, depth_tie=None, width_in_tie=None, width_out_tie=None): 507 | return FusedLinear(model_large, M2N(module_large), module_large.in_features, module_large.out_features, 508 | layer_index=layer_index, init_scheme_depth=args.fuse_init_scheme_depth, init_noise_depth=args.fuse_init_noise_depth, learn_depth=args.tune_depth, 509 | init_scheme_width=args.fuse_init_scheme_width, init_noise_width=args.fuse_init_noise_width, learn_width=args.tune_width, 510 | residual=args.tune_residual, residual_noise=args.tune_residual_noise, depth_tie=depth_tie, width_in_tie=width_in_tie, width_out_tie=width_out_tie 511 | ) 512 | 513 | def create_ln_layer(module_large, args, layer_index=-1, depth_tie=None, width_out_tie=None): 514 | return FusedLayeredNorm(model_large, M2N(module_large), module_large.normalized_shape, eps=module_large.eps, 515 | layer_index=layer_index, init_scheme_depth=args.fuse_init_scheme_depth, init_noise_depth=args.fuse_init_noise_depth, learn_depth=args.tune_depth, 516 | init_scheme_width=args.fuse_init_scheme_width, init_noise_width=args.fuse_init_noise_width, learn_width=args.tune_width, 517 | residual=args.tune_residual, residual_noise=args.tune_residual_noise, depth_tie=depth_tie, width_out_tie=width_out_tie 518 | ) 519 | 520 | #### Bert2Bert style coefficient tying 521 | 522 | kwargs_depth_param = dict(learnable=args.tune_depth, init_scheme=args.fuse_init_scheme_depth, init_noise=args.fuse_init_noise_depth) 523 | kwargs_width_param = dict(learnable=args.tune_width, init_scheme=args.fuse_init_scheme_width, init_noise=args.fuse_init_noise_width) 524 | 525 | # Embedding module 526 | if is_bert(model_small) and is_bert(model_large): 527 | emb_small, emb_large = model_small.bert.embeddings, model_large.bert.embeddings 528 | elif is_roberta(model_small) and is_roberta(model_large): 529 | emb_small, emb_large = model_small.roberta.embeddings, model_large.roberta.embeddings 530 | else: 531 | raise NotImplementedError 532 | 533 | if args.fuse_tie_param: 534 | setattr(emb_large, 'fuse_width_emb', FusedWidthParams(emb_small.word_embeddings.weight.shape[-1], emb_large.word_embeddings.weight.shape[-1], **kwargs_width_param)) 535 | else: 536 | setattr(emb_large, 'fuse_width_emb', None) 537 | 538 | g_e = getattr(emb_large, 'fuse_width_emb') 539 | setattr(emb_large, 'word_embeddings', create_embed_layer(emb_large.word_embeddings, args, width_out_tie=g_e)) 540 | setattr(emb_large, 'position_embeddings', create_embed_layer(emb_large.position_embeddings, args, width_out_tie=g_e)) 541 | setattr(emb_large, 'token_type_embeddings', 
create_embed_layer(emb_large.token_type_embeddings, args, width_out_tie=g_e)) 542 | setattr(emb_large, 'LayerNorm', create_ln_layer(emb_large.LayerNorm, args, layer_index=-1, width_out_tie=g_e)) 543 | 544 | # Encoder layers 545 | gs = [] # index of selected columns 546 | if is_bert(model_small) and is_bert(model_large): 547 | small_layers, large_layers = model_small.bert.encoder.layer, model_large.bert.encoder.layer 548 | elif is_roberta(model_small) and is_roberta(model_large): 549 | small_layers, large_layers = model_small.roberta.encoder.layer, model_large.roberta.encoder.layer 550 | else: 551 | raise NotImplementedError 552 | 553 | for i, l_large in enumerate(large_layers): 554 | if args.fuse_tie_param: 555 | setattr(l_large, 'fuse_width_key', FusedWidthParams(small_layers[0].attention.self.key.weight.shape[0], l_large.attention.self.key.weight.shape[0], **kwargs_width_param)) 556 | setattr(l_large, 'fuse_width_query', FusedWidthParams(small_layers[0].attention.self.query.weight.shape[0], l_large.attention.self.query.weight.shape[0], **kwargs_width_param)) 557 | setattr(l_large, 'fuse_width_value', FusedWidthParams(small_layers[0].attention.self.value.weight.shape[0], l_large.attention.self.value.weight.shape[0], **kwargs_width_param)) 558 | setattr(l_large, 'fuse_width_ffn', FusedWidthParams(small_layers[0].intermediate.dense.weight.shape[0], l_large.intermediate.dense.weight.shape[0], **kwargs_width_param)) 559 | else: 560 | setattr(l_large, 'fuse_width_key', None) 561 | setattr(l_large, 'fuse_width_query', None) 562 | setattr(l_large, 'fuse_width_value', None) 563 | setattr(l_large, 'fuse_width_ffn', None) 564 | 565 | # MHA - Attention 566 | attn_large = l_large.attention.self 567 | for name in ['query', 'key', 'value']: 568 | setattr(attn_large, name, create_lin_layer(getattr(attn_large, name), args, layer_index=i, width_in_tie=g_e, width_out_tie=getattr(l_large, f'fuse_width_{name}'))) 569 | 570 | # MHA - W_o 571 | setattr(l_large.attention.output, 'dense', create_lin_layer(l_large.attention.output.dense, args, layer_index=i, width_in_tie=getattr(l_large, f'fuse_width_value'), width_out_tie=g_e)) 572 | 573 | # MHA - LayerNorm 574 | setattr(l_large.attention.output, 'LayerNorm', create_ln_layer(l_large.attention.output.LayerNorm, args, layer_index=i, width_out_tie=g_e)) 575 | 576 | # FFN - Layer 1 577 | setattr(l_large.intermediate, 'dense', create_lin_layer(l_large.intermediate.dense, args, layer_index=i, width_in_tie=g_e, width_out_tie=getattr(l_large, 'fuse_width_ffn'))) 578 | 579 | # FFN - Layer 2 580 | setattr(l_large.output, 'dense', create_lin_layer(l_large.output.dense, args, layer_index=i, width_in_tie=getattr(l_large, 'fuse_width_ffn'), width_out_tie=g_e)) 581 | 582 | # FFN LayerNorm 583 | setattr(l_large.output, 'LayerNorm', create_ln_layer(l_large.output.LayerNorm, args, layer_index=i, width_out_tie=g_e)) 584 | 585 | # Classifier 586 | if is_bert(model_small) and is_bert(model_large): 587 | cls_small, cls_large = model_small.cls.predictions, model_large.cls.predictions 588 | if args.fuse_tie_param: 589 | setattr(cls_large, 'fuse_width_cls', FusedWidthParams(cls_small.transform.dense.weight.shape[0], cls_large.transform.dense.weight.shape[0], **kwargs_width_param)) 590 | else: 591 | setattr(cls_large, 'fuse_width_cls', None) 592 | 593 | setattr(cls_large.transform, 'dense', create_lin_layer(cls_large.transform.dense, args, layer_index=-1, width_in_tie=g_e, width_out_tie=getattr(cls_large, 'fuse_width_cls'))) 594 | setattr(cls_large.transform, 'LayerNorm', 
create_ln_layer(cls_large.transform.LayerNorm, args, layer_index=-1, width_out_tie=getattr(cls_large, 'fuse_width_cls'))) 595 | setattr(cls_large, 'decoder', create_lin_layer(cls_large.decoder, args, layer_index=-1, width_in_tie=getattr(cls_large, 'fuse_width_cls'), width_out_tie=None)) 596 | 597 | elif is_roberta(model_small) and is_roberta(model_large): 598 | cls_small, cls_large = model_small.lm_head, model_large.lm_head 599 | if args.fuse_tie_param: 600 | setattr(cls_large, 'fuse_width_cls', FusedWidthParams(cls_small.dense.weight.shape[0], cls_large.dense.weight.shape[0], **kwargs_width_param)) 601 | else: 602 | setattr(cls_large, 'fuse_width_cls', None) 603 | 604 | setattr(cls_large, 'dense', create_lin_layer(cls_large.dense, args, layer_index=-1, width_in_tie=g_e, width_out_tie=getattr(cls_large, 'fuse_width_cls'))) 605 | setattr(cls_large, 'layer_norm', create_ln_layer(cls_large.layer_norm, args, layer_index=-1, width_out_tie=getattr(cls_large, 'fuse_width_cls'))) 606 | setattr(cls_large, 'decoder', create_lin_layer(cls_large.decoder, args, layer_index=-1, width_in_tie=getattr(cls_large, 'fuse_width_cls'), width_out_tie=None)) 607 | 608 | else: 609 | raise NotImplementedError 610 | 611 | return model_large 612 | 613 | 614 | @torch.no_grad() 615 | def initialize_model_with_ligo(model_large, args): 616 | 617 | # load coefficient model 618 | coeff_model_path = args.pretrained_ligo_path 619 | dict_model_coeff = torch.load(os.path.join(coeff_model_path, 'pytorch_model.bin'), map_location=torch.device('cpu')) 620 | model_coeff = model_large.__class__(config=model_large.config, args=args) 621 | model_coeff = create_ligo_from_model(model_coeff, args) 622 | model_coeff.load_state_dict(dict_model_coeff) 623 | 624 | modules_coeff = {name:module for name, module in model_coeff.named_modules()} 625 | for name, module in model_large.named_modules(): 626 | if isinstance(module, (nn.Linear, nn.LayerNorm)): 627 | module.weight.copy_(modules_coeff[name].get_params()[0]) 628 | if hasattr(module, 'bias'): 629 | module.bias.copy_(modules_coeff[name].get_params()[1]) 630 | elif isinstance(module, nn.Embedding): 631 | module.weight.copy_(modules_coeff[name].get_params()) 632 | 633 | return model_large 634 | -------------------------------------------------------------------------------- /bert_model.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import os 4 | import warnings 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | from torch import nn 11 | from torch.nn import CrossEntropyLoss, MSELoss 12 | 13 | from transformers.activations import ACT2FN 14 | from transformers.file_utils import ( 15 | ModelOutput, 16 | add_code_sample_docstrings, 17 | add_start_docstrings, 18 | add_start_docstrings_to_model_forward, 19 | replace_return_docstrings, 20 | ) 21 | from transformers.modeling_outputs import ( 22 | BaseModelOutputWithPastAndCrossAttentions, 23 | BaseModelOutputWithPoolingAndCrossAttentions, 24 | CausalLMOutputWithCrossAttentions, 25 | MaskedLMOutput, 26 | MultipleChoiceModelOutput, 27 | NextSentencePredictorOutput, 28 | QuestionAnsweringModelOutput, 29 | SequenceClassifierOutput, 30 | TokenClassifierOutput, 31 | ) 32 | from transformers.modeling_utils import ( 33 | PreTrainedModel, 34 | apply_chunking_to_forward, 35 | find_pruneable_heads_and_indices, 36 | prune_linear_layer, 37 | ) 38 | from transformers.utils import logging 39 | from transformers import 
BertConfig 40 | from flop_computation import get_flops_computer 41 | 42 | 43 | logger = logging.get_logger(__name__) 44 | 45 | _CONFIG_FOR_DOC = "BertConfig" 46 | _TOKENIZER_FOR_DOC = "BertTokenizer" 47 | 48 | BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ 49 | "bert-base-uncased", 50 | "bert-large-uncased", 51 | "bert-base-cased", 52 | "bert-large-cased", 53 | "bert-base-multilingual-uncased", 54 | "bert-base-multilingual-cased", 55 | "bert-base-chinese", 56 | "bert-base-german-cased", 57 | "bert-large-uncased-whole-word-masking", 58 | "bert-large-cased-whole-word-masking", 59 | "bert-large-uncased-whole-word-masking-finetuned-squad", 60 | "bert-large-cased-whole-word-masking-finetuned-squad", 61 | "bert-base-cased-finetuned-mrpc", 62 | "bert-base-german-dbmdz-cased", 63 | "bert-base-german-dbmdz-uncased", 64 | "cl-tohoku/bert-base-japanese", 65 | "cl-tohoku/bert-base-japanese-whole-word-masking", 66 | "cl-tohoku/bert-base-japanese-char", 67 | "cl-tohoku/bert-base-japanese-char-whole-word-masking", 68 | "TurkuNLP/bert-base-finnish-cased-v1", 69 | "TurkuNLP/bert-base-finnish-uncased-v1", 70 | "wietsedv/bert-base-dutch-cased", 71 | # See all BERT models at https://huggingface.co/models?filter=bert 72 | ] 73 | 74 | 75 | def load_tf_weights_in_bert(model, config, tf_checkpoint_path): 76 | """Load tf checkpoints in a pytorch model.""" 77 | try: 78 | import re 79 | 80 | import numpy as np 81 | import tensorflow as tf 82 | except ImportError: 83 | logger.error( 84 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 85 | "https://www.tensorflow.org/install/ for installation instructions." 86 | ) 87 | raise 88 | tf_path = os.path.abspath(tf_checkpoint_path) 89 | logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) 90 | # Load weights from TF model 91 | init_vars = tf.train.list_variables(tf_path) 92 | names = [] 93 | arrays = [] 94 | for name, shape in init_vars: 95 | logger.info("Loading TF weight {} with shape {}".format(name, shape)) 96 | array = tf.train.load_variable(tf_path, name) 97 | names.append(name) 98 | arrays.append(array) 99 | 100 | for name, array in zip(names, arrays): 101 | name = name.split("/") 102 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 103 | # which are not required for using pretrained model 104 | if any( 105 | n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] 106 | for n in name 107 | ): 108 | logger.info("Skipping {}".format("/".join(name))) 109 | continue 110 | pointer = model 111 | for m_name in name: 112 | if re.fullmatch(r"[A-Za-z]+_\d+", m_name): 113 | scope_names = re.split(r"_(\d+)", m_name) 114 | else: 115 | scope_names = [m_name] 116 | if scope_names[0] == "kernel" or scope_names[0] == "gamma": 117 | pointer = getattr(pointer, "weight") 118 | elif scope_names[0] == "output_bias" or scope_names[0] == "beta": 119 | pointer = getattr(pointer, "bias") 120 | elif scope_names[0] == "output_weights": 121 | pointer = getattr(pointer, "weight") 122 | elif scope_names[0] == "squad": 123 | pointer = getattr(pointer, "classifier") 124 | else: 125 | try: 126 | pointer = getattr(pointer, scope_names[0]) 127 | except AttributeError: 128 | logger.info("Skipping {}".format("/".join(name))) 129 | continue 130 | if len(scope_names) >= 2: 131 | num = int(scope_names[1]) 132 | pointer = pointer[num] 133 | if m_name[-11:] == "_embeddings": 134 | pointer = getattr(pointer, "weight") 135 | elif m_name == "kernel": 136 | array = 
np.transpose(array) 137 | try: 138 | assert ( 139 | pointer.shape == array.shape 140 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 141 | except AssertionError as e: 142 | e.args += (pointer.shape, array.shape) 143 | raise 144 | logger.info("Initialize PyTorch weight {}".format(name)) 145 | pointer.data = torch.from_numpy(array) 146 | return model 147 | 148 | 149 | class BertEmbeddings(nn.Module): 150 | """Construct the embeddings from word, position and token_type embeddings.""" 151 | 152 | def __init__(self, config): 153 | super().__init__() 154 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 155 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 156 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 157 | 158 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 159 | # any TensorFlow checkpoint file 160 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 161 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 162 | 163 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 164 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 165 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 166 | 167 | def forward( 168 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 169 | ): 170 | if input_ids is not None: 171 | input_shape = input_ids.size() 172 | else: 173 | input_shape = inputs_embeds.size()[:-1] 174 | 175 | seq_length = input_shape[1] 176 | 177 | if position_ids is None: 178 | position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] 179 | 180 | if token_type_ids is None: 181 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 182 | 183 | if inputs_embeds is None: 184 | inputs_embeds = self.word_embeddings(input_ids) 185 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 186 | 187 | embeddings = inputs_embeds + token_type_embeddings 188 | if self.position_embedding_type == "absolute": 189 | position_embeddings = self.position_embeddings(position_ids) 190 | embeddings += position_embeddings 191 | embeddings = self.LayerNorm(embeddings) 192 | embeddings = self.dropout(embeddings) 193 | return embeddings 194 | 195 | 196 | class BertSelfAttention(nn.Module): 197 | def __init__(self, config): 198 | super().__init__() 199 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 200 | raise ValueError( 201 | "The hidden size (%d) is not a multiple of the number of attention " 202 | "heads (%d)" % (config.hidden_size, config.num_attention_heads) 203 | ) 204 | 205 | self.num_attention_heads = config.num_attention_heads 206 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 207 | self.all_head_size = self.num_attention_heads * self.attention_head_size 208 | 209 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 210 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 211 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 212 | 213 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 214 | self.position_embedding_type = getattr(config, 
"position_embedding_type", "absolute") 215 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 216 | self.max_position_embeddings = config.max_position_embeddings 217 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 218 | 219 | self.is_decoder = config.is_decoder 220 | 221 | def transpose_for_scores(self, x): 222 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 223 | x = x.view(*new_x_shape) 224 | return x.permute(0, 2, 1, 3) 225 | 226 | def forward( 227 | self, 228 | hidden_states, 229 | attention_mask=None, 230 | head_mask=None, 231 | encoder_hidden_states=None, 232 | encoder_attention_mask=None, 233 | past_key_value=None, 234 | output_attentions=False, 235 | ): 236 | mixed_query_layer = self.query(hidden_states) 237 | 238 | # If this is instantiated as a cross-attention module, the keys 239 | # and values come from an encoder; the attention mask needs to be 240 | # such that the encoder's padding tokens are not attended to. 241 | is_cross_attention = encoder_hidden_states is not None 242 | 243 | if is_cross_attention and past_key_value is not None: 244 | # reuse k,v, cross_attentions 245 | key_layer = past_key_value[0] 246 | value_layer = past_key_value[1] 247 | attention_mask = encoder_attention_mask 248 | elif is_cross_attention: 249 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 250 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 251 | attention_mask = encoder_attention_mask 252 | elif past_key_value is not None: 253 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 254 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 255 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 256 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 257 | else: 258 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 259 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 260 | 261 | query_layer = self.transpose_for_scores(mixed_query_layer) 262 | 263 | if self.is_decoder: 264 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 265 | # Further calls to cross_attention layer can then reuse all cross-attention 266 | # key/value_states (first "if" case) 267 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 268 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 269 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 270 | # if encoder bi-directional self-attention `past_key_value` is always `None` 271 | past_key_value = (key_layer, value_layer) 272 | 273 | # Take the dot product between "query" and "key" to get the raw attention scores. 
274 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 275 | 276 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 277 | seq_length = hidden_states.size()[1] 278 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 279 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 280 | distance = position_ids_l - position_ids_r 281 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 282 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 283 | 284 | if self.position_embedding_type == "relative_key": 285 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 286 | attention_scores = attention_scores + relative_position_scores 287 | elif self.position_embedding_type == "relative_key_query": 288 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 289 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 290 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 291 | 292 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 293 | if attention_mask is not None: 294 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 295 | attention_scores = attention_scores + attention_mask 296 | 297 | # Normalize the attention scores to probabilities. 298 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 299 | 300 | # This is actually dropping out entire tokens to attend to, which might 301 | # seem a bit unusual, but is taken from the original Transformer paper. 
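# nn.Dropout on the attention probabilities zeroes each entry with probability
# attention_probs_dropout_prob during training and rescales the survivors by
# 1 / (1 - p), so a query can temporarily ignore some key/value positions;
# at eval time this is a no-op.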
302 | attention_probs = self.dropout(attention_probs) 303 | 304 | # Mask heads if we want to 305 | if head_mask is not None: 306 | attention_probs = attention_probs * head_mask 307 | 308 | context_layer = torch.matmul(attention_probs, value_layer) 309 | 310 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 311 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 312 | context_layer = context_layer.view(*new_context_layer_shape) 313 | 314 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 315 | 316 | if self.is_decoder: 317 | outputs = outputs + (past_key_value,) 318 | return outputs 319 | 320 | 321 | class BertSelfOutput(nn.Module): 322 | def __init__(self, config): 323 | super().__init__() 324 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 325 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 326 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 327 | 328 | def forward(self, hidden_states, input_tensor): 329 | hidden_states = self.dense(hidden_states) 330 | hidden_states = self.dropout(hidden_states) 331 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 332 | return hidden_states 333 | 334 | 335 | class BertAttention(nn.Module): 336 | def __init__(self, config): 337 | super().__init__() 338 | self.self = BertSelfAttention(config) 339 | self.output = BertSelfOutput(config) 340 | self.pruned_heads = set() 341 | 342 | def prune_heads(self, heads): 343 | if len(heads) == 0: 344 | return 345 | heads, index = find_pruneable_heads_and_indices( 346 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 347 | ) 348 | 349 | # Prune linear layers 350 | self.self.query = prune_linear_layer(self.self.query, index) 351 | self.self.key = prune_linear_layer(self.self.key, index) 352 | self.self.value = prune_linear_layer(self.self.value, index) 353 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 354 | 355 | # Update hyper params and store pruned heads 356 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 357 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 358 | self.pruned_heads = self.pruned_heads.union(heads) 359 | 360 | def forward( 361 | self, 362 | hidden_states, 363 | attention_mask=None, 364 | head_mask=None, 365 | encoder_hidden_states=None, 366 | encoder_attention_mask=None, 367 | past_key_value=None, 368 | output_attentions=False, 369 | ): 370 | self_outputs = self.self( 371 | hidden_states, 372 | attention_mask, 373 | head_mask, 374 | encoder_hidden_states, 375 | encoder_attention_mask, 376 | past_key_value, 377 | output_attentions, 378 | ) 379 | attention_output = self.output(self_outputs[0], hidden_states) 380 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 381 | return outputs 382 | 383 | 384 | class BertIntermediate(nn.Module): 385 | def __init__(self, config): 386 | super().__init__() 387 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 388 | if isinstance(config.hidden_act, str): 389 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 390 | else: 391 | self.intermediate_act_fn = config.hidden_act 392 | 393 | def forward(self, hidden_states): 394 | hidden_states = self.dense(hidden_states) 395 | hidden_states = self.intermediate_act_fn(hidden_states) 396 | return hidden_states 397 | 398 | 399 | class BertOutput(nn.Module): 400 | def __init__(self, config): 401 
| super().__init__() 402 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 403 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 404 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 405 | 406 | def forward(self, hidden_states, input_tensor): 407 | hidden_states = self.dense(hidden_states) 408 | hidden_states = self.dropout(hidden_states) 409 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 410 | return hidden_states 411 | 412 | 413 | class BertLayer(nn.Module): 414 | def __init__(self, config): 415 | super().__init__() 416 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 417 | self.seq_len_dim = 1 418 | self.attention = BertAttention(config) 419 | self.is_decoder = config.is_decoder 420 | self.add_cross_attention = config.add_cross_attention 421 | if self.add_cross_attention: 422 | assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added" 423 | self.crossattention = BertAttention(config) 424 | self.intermediate = BertIntermediate(config) 425 | self.output = BertOutput(config) 426 | 427 | def forward( 428 | self, 429 | hidden_states, 430 | attention_mask=None, 431 | head_mask=None, 432 | encoder_hidden_states=None, 433 | encoder_attention_mask=None, 434 | past_key_value=None, 435 | output_attentions=False, 436 | ): 437 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 438 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 439 | self_attention_outputs = self.attention( 440 | hidden_states, 441 | attention_mask, 442 | head_mask, 443 | output_attentions=output_attentions, 444 | past_key_value=self_attn_past_key_value, 445 | ) 446 | attention_output = self_attention_outputs[0] 447 | 448 | # if decoder, the last output is tuple of self-attn cache 449 | if self.is_decoder: 450 | outputs = self_attention_outputs[1:-1] 451 | present_key_value = self_attention_outputs[-1] 452 | else: 453 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 454 | 455 | cross_attn_present_key_value = None 456 | if self.is_decoder and encoder_hidden_states is not None: 457 | assert hasattr( 458 | self, "crossattention" 459 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" 460 | 461 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple 462 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 463 | cross_attention_outputs = self.crossattention( 464 | attention_output, 465 | attention_mask, 466 | head_mask, 467 | encoder_hidden_states, 468 | encoder_attention_mask, 469 | cross_attn_past_key_value, 470 | output_attentions, 471 | ) 472 | attention_output = cross_attention_outputs[0] 473 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights 474 | 475 | # add cross-attn cache to positions 3,4 of present_key_value tuple 476 | cross_attn_present_key_value = cross_attention_outputs[-1] 477 | present_key_value = present_key_value + cross_attn_present_key_value 478 | 479 | layer_output = apply_chunking_to_forward( 480 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output 481 | ) 482 | outputs = (layer_output,) + outputs 483 | 484 | # if decoder, return the attn key/values as the last output 485 | if self.is_decoder: 486 | outputs = outputs + 
(present_key_value,) 487 | 488 | return outputs 489 | 490 | def feed_forward_chunk(self, attention_output): 491 | intermediate_output = self.intermediate(attention_output) 492 | layer_output = self.output(intermediate_output, attention_output) 493 | return layer_output 494 | 495 | 496 | class BertEncoder(nn.Module): 497 | def __init__(self, config, args=None): 498 | super().__init__() 499 | self.config = config 500 | self.args = args 501 | self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) 502 | 503 | self.layer_drop = getattr(args, 'layer_drop') if args is not None else False 504 | self.layer_drop_rate = getattr(args, 'layer_drop_rate') if args is not None else 0.0 505 | self.layer_drop_lin_decay = getattr(args, 'layer_drop_lin_decay') if args is not None else False 506 | 507 | self.token_drop = getattr(args, 'token_drop') if args is not None else False 508 | self.token_drop_rate = getattr(args, 'token_drop_rate') if args is not None else 0. 509 | self.token_drop_start = getattr(args, 'token_drop_start') if args is not None else -1 510 | self.token_drop_end = getattr(args, 'token_drop_end') if args is not None else -1 511 | 512 | def forward( 513 | self, 514 | hidden_states, 515 | attention_mask=None, 516 | head_mask=None, 517 | encoder_hidden_states=None, 518 | encoder_attention_mask=None, 519 | past_key_values=None, 520 | use_cache=None, 521 | output_attentions=False, 522 | output_hidden_states=False, 523 | return_dict=True, 524 | ): 525 | all_hidden_states = () if output_hidden_states else None 526 | all_self_attentions = () if output_attentions else None 527 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 528 | 529 | token_dropped_indices = None 530 | if self.token_drop: 531 | num_tokens = hidden_states.shape[1] 532 | token_mask = (torch.rand(num_tokens) <= self.token_drop_rate) 533 | indices_kept, indices_dropped = torch.where(token_mask), torch.where(~token_mask) 534 | _, indices_inverse = torch.sort(torch.cat([indices_kept, indices_dropped], 0)) 535 | 536 | 537 | last_layer_states = None 538 | next_decoder_cache = () if use_cache else None 539 | for i, layer_module in enumerate(self.layer): 540 | if output_hidden_states: 541 | all_hidden_states = all_hidden_states + (hidden_states,) 542 | 543 | if self.token_drop and i == self.token_drop_start: 544 | hidden_states, tokens_dropped = hidden_states[:, indices_kept, :], hidden_states[:, indices_dropped, :] 545 | if self.token_drop and i == self.token_drop_end: 546 | hidden_states = torch.cat([hidden_states, tokens_dropped], 1)[indices_inverse] 547 | 548 | layer_head_mask = head_mask[i] if head_mask is not None else None 549 | past_key_value = past_key_values[i] if past_key_values is not None else None 550 | if getattr(self.config, "gradient_checkpointing", False): 551 | 552 | def create_custom_forward(module): 553 | def custom_forward(*inputs): 554 | return module(*inputs, past_key_value, output_attentions) 555 | 556 | return custom_forward 557 | 558 | layer_outputs = torch.utils.checkpoint.checkpoint( 559 | create_custom_forward(layer_module), 560 | hidden_states, 561 | attention_mask, 562 | layer_head_mask, 563 | encoder_hidden_states, 564 | encoder_attention_mask, 565 | ) 566 | else: 567 | layer_outputs = layer_module( 568 | hidden_states, 569 | attention_mask, 570 | layer_head_mask, 571 | encoder_hidden_states, 572 | encoder_attention_mask, 573 | past_key_value, 574 | output_attentions, 575 | ) 576 | 577 | hidden_states = layer_outputs[0] 578 | 579 
| if self.layer_drop and i != 0: # no dropping for the first layer 580 | drop_rate = self.layer_drop_rate 581 | if self.layer_drop_lin_decay: 582 | drop_rate *= i / len(self.layer) 583 | 584 | if self.training: 585 | mask = torch.bernoulli(torch.full((hidden_states.shape[0],), drop_rate, device=hidden_states.device)) # [bs] 586 | num_drops = (mask == 1).sum().item() 587 | mask = mask[:, None, None].expand(*hidden_states.shape) # [bs, seq_len, hidden_dim] 588 | hidden_states = hidden_states * (1. - mask) + last_layer_states * mask # dropped examples reuse the previous layer's output 589 | else: 590 | hidden_states = hidden_states * (1. - drop_rate) + last_layer_states * drop_rate # inference: interpolate with the expected drop probability 591 | 592 | if use_cache: 593 | next_decoder_cache += (layer_outputs[-1],) 594 | if output_attentions: 595 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 596 | if self.config.add_cross_attention: 597 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],) 598 | 599 | # record last layer output 600 | last_layer_states = hidden_states 601 | 602 | if output_hidden_states: 603 | all_hidden_states = all_hidden_states + (hidden_states,) 604 | 605 | if not return_dict: 606 | return tuple( 607 | v 608 | for v in [ 609 | hidden_states, 610 | next_decoder_cache, 611 | all_hidden_states, 612 | all_self_attentions, 613 | all_cross_attentions 614 | ] 615 | if v is not None 616 | ) 617 | return BaseModelOutputWithPastAndCrossAttentions( 618 | last_hidden_state=hidden_states, 619 | past_key_values=next_decoder_cache, 620 | hidden_states=all_hidden_states, 621 | attentions=all_self_attentions, 622 | cross_attentions=all_cross_attentions 623 | ) 624 | 625 | 626 | class BertPooler(nn.Module): 627 | def __init__(self, config): 628 | super().__init__() 629 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 630 | self.activation = nn.Tanh() 631 | 632 | def forward(self, hidden_states): 633 | # We "pool" the model by simply taking the hidden state corresponding 634 | # to the first token. 635 | first_token_tensor = hidden_states[:, 0] 636 | pooled_output = self.dense(first_token_tensor) 637 | pooled_output = self.activation(pooled_output) 638 | return pooled_output 639 | 640 | 641 | class BertPredictionHeadTransform(nn.Module): 642 | def __init__(self, config): 643 | super().__init__() 644 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 645 | if isinstance(config.hidden_act, str): 646 | self.transform_act_fn = ACT2FN[config.hidden_act] 647 | else: 648 | self.transform_act_fn = config.hidden_act 649 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 650 | 651 | def forward(self, hidden_states): 652 | hidden_states = self.dense(hidden_states) 653 | hidden_states = self.transform_act_fn(hidden_states) 654 | hidden_states = self.LayerNorm(hidden_states) 655 | return hidden_states 656 | 657 | 658 | class BertLMPredictionHead(nn.Module): 659 | def __init__(self, config): 660 | super().__init__() 661 | self.transform = BertPredictionHeadTransform(config) 662 | 663 | # The output weights are the same as the input embeddings, but there is 664 | # an output-only bias for each token.
665 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 666 | 667 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 668 | 669 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 670 | self.decoder.bias = self.bias 671 | 672 | def forward(self, hidden_states): 673 | hidden_states = self.transform(hidden_states) 674 | hidden_states = self.decoder(hidden_states) 675 | return hidden_states 676 | 677 | 678 | class BertOnlyMLMHead(nn.Module): 679 | def __init__(self, config): 680 | super().__init__() 681 | self.predictions = BertLMPredictionHead(config) 682 | 683 | def forward(self, sequence_output): 684 | prediction_scores = self.predictions(sequence_output) 685 | return prediction_scores 686 | 687 | 688 | class BertPreTrainedModel(PreTrainedModel): 689 | """ 690 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 691 | models. 692 | """ 693 | 694 | config_class = BertConfig 695 | load_tf_weights = load_tf_weights_in_bert 696 | base_model_prefix = "bert" 697 | _keys_to_ignore_on_load_missing = [r"position_ids"] 698 | 699 | def _init_weights(self, module): 700 | """ Initialize the weights """ 701 | if isinstance(module, (nn.Linear, nn.Embedding)): 702 | # Slightly different from the TF version which uses truncated_normal for initialization 703 | # cf https://github.com/pytorch/pytorch/pull/5617 704 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 705 | elif isinstance(module, nn.LayerNorm): 706 | module.bias.data.zero_() 707 | module.weight.data.fill_(1.0) 708 | if isinstance(module, nn.Linear) and module.bias is not None: 709 | module.bias.data.zero_() 710 | 711 | 712 | class BertModel(BertPreTrainedModel): 713 | """ 714 | 715 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of 716 | cross-attention is added between the self-attention layers, following the architecture described in `Attention is 717 | all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, 718 | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 719 | 720 | To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration 721 | set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both :obj:`is_decoder` 722 | argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an 723 | input to the forward pass. 724 | """ 725 | 726 | def __init__(self, config, args=None, add_pooling_layer=True): 727 | super().__init__(config) 728 | self.config = config 729 | 730 | self.embeddings = BertEmbeddings(config) 731 | self.encoder = BertEncoder(config, args) 732 | 733 | self.pooler = BertPooler(config) if add_pooling_layer else None 734 | 735 | self.init_weights() 736 | 737 | def get_input_embeddings(self): 738 | return self.embeddings.word_embeddings 739 | 740 | def set_input_embeddings(self, value): 741 | self.embeddings.word_embeddings = value 742 | 743 | def _prune_heads(self, heads_to_prune): 744 | """ 745 | Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 746 | class PreTrainedModel 747 | """ 748 | for layer, heads in heads_to_prune.items(): 749 | self.encoder.layer[layer].attention.prune_heads(heads) 750 | 751 | def forward( 752 | self, 753 | input_ids=None, 754 | attention_mask=None, 755 | token_type_ids=None, 756 | position_ids=None, 757 | head_mask=None, 758 | inputs_embeds=None, 759 | encoder_hidden_states=None, 760 | encoder_attention_mask=None, 761 | past_key_values=None, 762 | use_cache=None, 763 | output_attentions=None, 764 | output_hidden_states=None, 765 | return_dict=None, 766 | ): 767 | r""" 768 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 769 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 770 | the model is configured as a decoder. 771 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 772 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 773 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 774 | 775 | - 1 for tokens that are **not masked**, 776 | - 0 for tokens that are **masked**. 777 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 778 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 779 | 780 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 781 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 782 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. 783 | use_cache (:obj:`bool`, `optional`): 784 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 785 | decoding (see :obj:`past_key_values`). 
786 | """ 787 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 788 | output_hidden_states = ( 789 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 790 | ) 791 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 792 | 793 | if self.config.is_decoder: 794 | use_cache = use_cache if use_cache is not None else self.config.use_cache 795 | else: 796 | use_cache = False 797 | 798 | if input_ids is not None and inputs_embeds is not None: 799 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 800 | elif input_ids is not None: 801 | input_shape = input_ids.size() 802 | batch_size, seq_length = input_shape 803 | elif inputs_embeds is not None: 804 | input_shape = inputs_embeds.size()[:-1] 805 | batch_size, seq_length = input_shape 806 | else: 807 | raise ValueError("You have to specify either input_ids or inputs_embeds") 808 | 809 | device = input_ids.device if input_ids is not None else inputs_embeds.device 810 | 811 | # past_key_values_length 812 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 813 | 814 | if attention_mask is None: 815 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 816 | if token_type_ids is None: 817 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 818 | 819 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 820 | # ourselves in which case we just need to make it broadcastable to all heads. 821 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) 822 | 823 | # If a 2D or 3D attention mask is provided for the cross-attention 824 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 825 | if self.config.is_decoder and encoder_hidden_states is not None: 826 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 827 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 828 | if encoder_attention_mask is None: 829 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 830 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 831 | else: 832 | encoder_extended_attention_mask = None 833 | 834 | # Prepare head mask if needed 835 | # 1.0 in head_mask indicate we keep the head 836 | # attention_probs has shape bsz x n_heads x N x N 837 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 838 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 839 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 840 | 841 | embedding_output = self.embeddings( 842 | input_ids=input_ids, 843 | position_ids=position_ids, 844 | token_type_ids=token_type_ids, 845 | inputs_embeds=inputs_embeds, 846 | past_key_values_length=past_key_values_length, 847 | ) 848 | encoder_outputs = self.encoder( 849 | embedding_output, 850 | attention_mask=extended_attention_mask, 851 | head_mask=head_mask, 852 | encoder_hidden_states=encoder_hidden_states, 853 | encoder_attention_mask=encoder_extended_attention_mask, 854 | past_key_values=past_key_values, 855 | use_cache=use_cache, 856 | output_attentions=output_attentions, 857 | output_hidden_states=output_hidden_states, 858 | 
return_dict=return_dict, 859 | ) 860 | sequence_output = encoder_outputs[0] 861 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None 862 | 863 | if not return_dict: 864 | return (sequence_output, pooled_output) + encoder_outputs[1:] 865 | 866 | return BaseModelOutputWithPoolingAndCrossAttentions( 867 | last_hidden_state=sequence_output, 868 | pooler_output=pooled_output, 869 | past_key_values=encoder_outputs.past_key_values, 870 | hidden_states=encoder_outputs.hidden_states, 871 | attentions=encoder_outputs.attentions, 872 | cross_attentions=encoder_outputs.cross_attentions, 873 | ) 874 | 875 | class BertForMaskedLM(BertPreTrainedModel): 876 | 877 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 878 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] 879 | 880 | def __init__(self, config, args=None): 881 | super().__init__(config) 882 | if config.is_decoder: 883 | logger.warning( 884 | "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " 885 | "bi-directional self-attention." 886 | ) 887 | 888 | self.bert = BertModel(config, args=args, add_pooling_layer=False) 889 | self.cls = BertOnlyMLMHead(config) 890 | 891 | self.init_weights() 892 | 893 | def get_output_embeddings(self): 894 | return self.cls.predictions.decoder 895 | 896 | def set_output_embeddings(self, new_embeddings): 897 | self.cls.predictions.decoder = new_embeddings 898 | 899 | def forward( 900 | self, 901 | input_ids=None, 902 | attention_mask=None, 903 | token_type_ids=None, 904 | position_ids=None, 905 | head_mask=None, 906 | inputs_embeds=None, 907 | encoder_hidden_states=None, 908 | encoder_attention_mask=None, 909 | labels=None, 910 | output_attentions=None, 911 | output_hidden_states=None, 912 | return_dict=None, 913 | ): 914 | r""" 915 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 916 | Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ..., 917 | config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored 918 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` 919 | """ 920 | 921 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 922 | 923 | outputs = self.bert( 924 | input_ids, 925 | attention_mask=attention_mask, 926 | token_type_ids=token_type_ids, 927 | position_ids=position_ids, 928 | head_mask=head_mask, 929 | inputs_embeds=inputs_embeds, 930 | encoder_hidden_states=encoder_hidden_states, 931 | encoder_attention_mask=encoder_attention_mask, 932 | output_attentions=output_attentions, 933 | output_hidden_states=output_hidden_states, 934 | return_dict=return_dict, 935 | ) 936 | 937 | sequence_output = outputs[0] 938 | prediction_scores = self.cls(sequence_output) 939 | 940 | masked_lm_loss = None 941 | if labels is not None: 942 | loss_fct = CrossEntropyLoss() # -100 index = padding token 943 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 944 | 945 | if not return_dict: 946 | output = (prediction_scores,) + outputs[2:] 947 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 948 | 949 | return MaskedLMOutput( 950 | loss=masked_lm_loss, 951 | logits=prediction_scores, 952 | hidden_states=outputs.hidden_states, 953 | attentions=outputs.attentions, 954 | ) 955 | 956 | 957 | def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): 958 | input_shape = input_ids.shape 959 | effective_batch_size = input_shape[0] 960 | 961 | # add a dummy token 962 | assert self.config.pad_token_id is not None, "The PAD token should be defined for generation" 963 | attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) 964 | dummy_token = torch.full( 965 | (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device 966 | ) 967 | input_ids = torch.cat([input_ids, dummy_token], dim=1) 968 | 969 | return {"input_ids": input_ids, "attention_mask": attention_mask} 970 | 971 | --------------------------------------------------------------------------------
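Note on the token-drop bookkeeping in BertEncoder.forward above: after the sequence is split into kept and dropped tokens, sorting the concatenated index list yields the inverse permutation that puts every token back in its original position once the dropped tokens are re-appended. A minimal, self-contained sketch of that trick (illustration only; the tensor and keep mask below are made up, not taken from the repository):

import torch

seq = torch.arange(6).unsqueeze(0)                       # stand-in for hidden states, shape [1, 6]
keep_mask = torch.tensor([True, False, True, True, False, True])

kept, dropped = torch.where(keep_mask)[0], torch.where(~keep_mask)[0]
_, inverse = torch.sort(torch.cat([kept, dropped], 0))   # inverse permutation of [kept | dropped]

reduced, skipped = seq[:, kept], seq[:, dropped]         # only `reduced` would flow through the middle layers
restored = torch.cat([reduced, skipped], 1)[:, inverse]  # re-assemble in the original token order

assert torch.equal(restored, seq)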
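For context, a hypothetical end-to-end usage sketch of the classes defined in bert_model.py (not part of the repository): it assumes the repository root is importable, that BertConfig is the Hugging Face config class used by the file, and that an argparse-style namespace supplies the layer-drop/token-drop flags read by BertEncoder.__init__. Labels use -100 at positions the masked-LM loss should ignore, matching the CrossEntropyLoss default noted in BertForMaskedLM.forward.

from argparse import Namespace

import torch
from transformers import BertConfig

from bert_model import BertForMaskedLM  # assumes the repository root is on PYTHONPATH

# Small illustrative config; a real run would typically load a full JSON config,
# e.g. via BertConfig.from_json_file(...).
config = BertConfig(hidden_size=512, num_hidden_layers=6, num_attention_heads=8, intermediate_size=2048)

# Only the attributes read by BertEncoder.__init__ are set here; the names mirror the code above.
args = Namespace(layer_drop=True, layer_drop_rate=0.1, layer_drop_lin_decay=False,
                 token_drop=False, token_drop_rate=0.0, token_drop_start=-1, token_drop_end=-1)

model = BertForMaskedLM(config, args=args)

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # [batch, seq_len]
labels = input_ids.clone()
labels[torch.rand(labels.shape) > 0.15] = -100            # ignore ~85% of positions in the loss

outputs = model(input_ids=input_ids, labels=labels)
print(outputs.loss.item(), outputs.logits.shape)          # scalar MLM loss, logits of shape [2, 16, vocab_size]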