├── models ├── __init__.py ├── util │ ├── __init__.py │ ├── efficiency_metrics_bert_models.py │ ├── benchmark_original_models.py │ └── efficiency_metrics_roberta.py ├── big_bird │ ├── __init__.py │ ├── modeling_big_bird.py │ └── tokenization_big_bird.py ├── longformer │ ├── __init__.py │ ├── convert_roberta_to_lf.py │ ├── convert_bert_to_lf.py │ └── tokenization_longformer.py └── hat │ ├── __init__.py │ ├── convert_roberta_to_htf.py │ ├── convert_bert_to_htf.py │ ├── configuration_hat.py │ └── tokenization_hat.py ├── data ├── PLMs │ └── placeholder.txt ├── __init__.py ├── .DS_Store ├── roberta │ ├── .DS_Store │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── config.json ├── figures │ ├── hat_encoder.png │ └── hat_layouts.png ├── hi-transformer │ └── .DS_Store ├── quality-dataset │ ├── read_quality.py │ └── quality-dataset.py ├── hat │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── config.json ├── check-data.py ├── check_training_arguments.py ├── c4-dataset │ ├── transform-data.py │ └── c4-dataset.py ├── mimic-dataset │ └── mimic-dataset.py ├── ecthr-arguments-dataset │ ├── create_dataset.py │ └── ecthr-arguments-dataset.py └── contractnli-dataset │ └── contractnli-dataset.py ├── requirements.txt ├── running_scripts ├── fine_tuning │ ├── train_sentence_order_prediction.sh │ ├── train_masked_sentence_mcqa.sh │ ├── train_document_classification.sh │ ├── train_ecthr_classification.sh │ ├── train_mimic_classification.sh │ ├── train_contract_nli.sh │ ├── train_ecthr_arguments_classification.sh │ └── train_quality_mcqa.sh └── pre_training │ ├── train_longformer.sh │ ├── train_longformer_base.sh │ ├── train_hat.sh │ └── train_hat_base.sh ├── LICENSE ├── evaluation ├── multi_label_utils.py ├── data_collator.py └── run_sequential_sentence_classification.py ├── language_modelling ├── prepare_dataset.py ├── xla_spawn.py ├── train_tokenizer.py └── text_featurization.py ├── .gitignore └── README.md /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/PLMs/placeholder.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | DATA_DIR = os.path.dirname(os.path.realpath(__file__)) -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coastalcph/hierarchical-transformers/HEAD/data/.DS_Store -------------------------------------------------------------------------------- /data/roberta/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coastalcph/hierarchical-transformers/HEAD/data/roberta/.DS_Store -------------------------------------------------------------------------------- /data/figures/hat_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coastalcph/hierarchical-transformers/HEAD/data/figures/hat_encoder.png 
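Note: data/__init__.py above only defines DATA_DIR, the absolute path of the data/ package; the conversion scripts later in this listing (e.g. models/hat/convert_roberta_to_htf.py) import it via "from data import DATA_DIR". A minimal sketch of how such a constant is typically consumed is shown below; the specific sub-paths joined here are illustrative assumptions, not paths hard-coded by the repository:

    import os
    from data import DATA_DIR  # absolute path of the repository's data/ folder (see data/__init__.py)

    # Assumed example paths: a pre-trained checkpoint under data/PLMs and a figure shipped with the repo.
    plm_dir = os.path.join(DATA_DIR, 'PLMs', 'hat-p1-roberta-mlm')
    figure_path = os.path.join(DATA_DIR, 'figures', 'hat_encoder.png')
    print(plm_dir, figure_path)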
-------------------------------------------------------------------------------- /data/figures/hat_layouts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coastalcph/hierarchical-transformers/HEAD/data/figures/hat_layouts.png -------------------------------------------------------------------------------- /data/hi-transformer/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coastalcph/hierarchical-transformers/HEAD/data/hi-transformer/.DS_Store -------------------------------------------------------------------------------- /data/quality-dataset/read_quality.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | datasets = load_dataset('../contractnli-dataset') 4 | 5 | print() -------------------------------------------------------------------------------- /models/big_bird/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_big_bird import BigBirdModelForSentenceClassification 2 | from .tokenization_big_bird import BigbirdTokenizer 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.11.0 2 | transformers>=4.18.0 3 | datasets>=2.0.0 4 | tokenizers>=0.11.0 5 | scikit-learn>=1.0.0 6 | accelerate>=0.19.0 7 | tqdm>=4.62.0 8 | nltk>=3.7.0 -------------------------------------------------------------------------------- /data/hat/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>"} -------------------------------------------------------------------------------- /data/roberta/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>"} -------------------------------------------------------------------------------- /data/roberta/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"model_max_length": 512, "bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>", "tokenizer_class": "PreTrainedTokenizerFast"} -------------------------------------------------------------------------------- /models/longformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_longformer import LongformerModelForSentenceClassification, LongformerModelForPreTraining, \ 2 | LongformerModelForSequenceClassification, LongformerForMaskedLM, LongformerForMultipleChoice 3 | from .tokenization_longformer import LongformerTokenizer 4 | -------------------------------------------------------------------------------- /data/check-data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | dataset = load_dataset('data/wikipedia-dataset', '20200501.en', data_dir='data/wikipedia-dataset') 4 | 5 | print(f'Train subset: {len(dataset["train"])}') 6 | print(f'Validation subset: {len(dataset["validation"])}') 7 | 
print(f'Test subset: {len(dataset["test"])}') 8 | -------------------------------------------------------------------------------- /data/hat/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"model_max_length": 4096, "bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>", "tokenizer_class": "PreTrainedTokenizerFast", "auto_map": {"AutoTokenizer": ["tokenization_hat.HATTokenizer", "tokenization_hat.HATTokenizer"]}} -------------------------------------------------------------------------------- /models/hat/__init__.py: -------------------------------------------------------------------------------- 1 | from .modelling_hat import HATForMaskedLM, HATForSequenceClassification, \ 2 | HATModelForBoWPreTraining, HATModelForVICRegPreTraining, HATModelForSimCLRPreTraining, \ 3 | HATModelForSequentialSentenceClassification, HATForMultipleChoice 4 | from .tokenization_hat import HATTokenizer 5 | from .configuration_hat import HATConfig 6 | -------------------------------------------------------------------------------- /data/check_training_arguments.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | 5 | def check_args(): 6 | parser = argparse.ArgumentParser() 7 | 8 | # Required arguments 9 | parser.add_argument('--path', default=None) 10 | config = parser.parse_args() 11 | 12 | training_args = torch.load(config.path) 13 | print(training_args) 14 | 15 | 16 | if __name__ == '__main__': 17 | check_args() 18 | -------------------------------------------------------------------------------- /data/roberta/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 512, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 0, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50000 21 | } -------------------------------------------------------------------------------- /data/c4-dataset/transform-data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | import tqdm 5 | from datasets import load_dataset 6 | 7 | 8 | def create_eval_streamable_wiki(): 9 | for subset, size in zip(['train', 'validation'], [60000, 20000]): 10 | dataset = load_dataset('c4', 'en', split=subset, streaming=True) 11 | shuffled_dataset = dataset.shuffle(seed=42, buffer_size=10_000) 12 | valid_samples = [] 13 | count = 0 14 | for sample in tqdm.tqdm(iter(shuffled_dataset)): 15 | if 4000 > len(sample['text'].split()) > 1000: 16 | valid_samples.append(sample['text']) 17 | count += 1 18 | if count >= 2000000: 19 | break 20 | 21 | print(f'TOTAL SAMPLES MEET CRITERIA: {len(valid_samples)}') 22 | samples = random.sample(valid_samples, k=size) 23 | 24 | with open(f'evaluation_{subset}.jsonl', 'w') as file: 25 | for sample in samples: 26 | file.write(json.dumps({'text': sample})+'\n') 27 | 28 | 29 | if __name__ == '__main__': 30 | create_eval_streamable_wiki() 31 | 
-------------------------------------------------------------------------------- /running_scripts/fine_tuning/train_sentence_order_prediction.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-eval" 2 | export PYTHONPATH=. 3 | 4 | MODEL_NAME='hi-transformer-p1-roberta-mlm' 5 | MODEL_MAX_LENGTH=4096 6 | MAX_SENTENCES=32 7 | 8 | python evaluation/run_sentence_order.py \ 9 | --model_name_or_path data/PLMs/${MODEL_NAME} \ 10 | --dataset_name ./data/wikipedia-dataset \ 11 | --dataset_config_name eval.en \ 12 | --do_train \ 13 | --do_eval \ 14 | --do_predict \ 15 | --output_dir data/PLMs/${MODEL_NAME}-sop \ 16 | --overwrite_output_dir \ 17 | --evaluation_strategy epoch \ 18 | --save_strategy epoch \ 19 | --num_train_epochs 20 \ 20 | --load_best_model_at_end \ 21 | --metric_for_best_model accuracy_score \ 22 | --greater_is_better True \ 23 | --save_total_limit 5 \ 24 | --learning_rate 1e-5 \ 25 | --per_device_train_batch_size 32 \ 26 | --per_device_eval_batch_size 32 \ 27 | --lr_scheduler_type linear \ 28 | --warmup_ratio 0.05 \ 29 | --max_seq_length ${MODEL_MAX_LENGTH} \ 30 | --max_sentences ${MAX_SENTENCES} \ 31 | --pad_to_max_length -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ilias Chalkidis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /running_scripts/fine_tuning/train_masked_sentence_mcqa.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-eval" 2 | export PYTHONPATH=. 
3 | 4 | MODEL_NAME='hat-p1-roberta-mlm' 5 | MODEL_MAX_LENGTH=4096 6 | MAX_SENTENCES=32 7 | 8 | python evaluation/run_masked_sentence_prediction.py \ 9 | --model_name_or_path data/PLMs/${MODEL_NAME} \ 10 | --dataset_name ./data/wikipedia-dataset \ 11 | --dataset_config_name eval.en \ 12 | --do_train \ 13 | --do_eval \ 14 | --do_predict \ 15 | --output_dir data/PLMs/${MODEL_NAME}-mcqa-sbert \ 16 | --overwrite_output_dir \ 17 | --evaluation_strategy epoch \ 18 | --save_strategy epoch \ 19 | --num_train_epochs 20 \ 20 | --load_best_model_at_end \ 21 | --metric_for_best_model accuracy_score \ 22 | --greater_is_better True \ 23 | --save_total_limit 5 \ 24 | --learning_rate 1e-5 \ 25 | --per_device_train_batch_size 32 \ 26 | --per_device_eval_batch_size 32 \ 27 | --lr_scheduler_type linear \ 28 | --warmup_ratio 0.05 \ 29 | --max_seq_length ${MODEL_MAX_LENGTH} \ 30 | --max_sentences ${MAX_SENTENCES} \ 31 | --pad_to_max_length \ 32 | --sentence_bert_path all-MiniLM-L6-v2 -------------------------------------------------------------------------------- /running_scripts/fine_tuning/train_document_classification.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-eval" 2 | export PYTHONPATH=. 3 | 4 | MODEL_NAME='hat-s1-grouped-mlm' 5 | POOLING_METHOD='max' 6 | MODEL_MAX_LENGTH=1024 7 | MAX_SENTENCES=8 8 | 9 | python evaluation/run_document_classification.py \ 10 | --model_name_or_path data/PLMs/${MODEL_NAME} \ 11 | --pooling ${POOLING_METHOD} \ 12 | --dataset_name multi_eurlex \ 13 | --dataset_config_name en \ 14 | --do_train \ 15 | --do_eval \ 16 | --do_predict \ 17 | --output_dir data/PLMs/${MODEL_NAME}-${POOLING_METHOD}-dc \ 18 | --overwrite_output_dir \ 19 | --evaluation_strategy epoch \ 20 | --save_strategy epoch \ 21 | --num_train_epochs 20 \ 22 | --load_best_model_at_end \ 23 | --metric_for_best_model micro_f1 \ 24 | --greater_is_better True \ 25 | --save_total_limit 5 \ 26 | --learning_rate 1e-5 \ 27 | --per_device_train_batch_size 32 \ 28 | --per_device_eval_batch_size 32 \ 29 | --lr_scheduler_type linear \ 30 | --warmup_ratio 0.05 \ 31 | --max_seq_length ${MODEL_MAX_LENGTH} \ 32 | --max_sentences ${MAX_SENTENCES} \ 33 | --pad_to_max_length -------------------------------------------------------------------------------- /running_scripts/fine_tuning/train_ecthr_classification.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-eval" 2 | export PYTHONPATH=. 
3 | 4 | MODEL_NAME='allenai/longformer-base-4096' 5 | POOLING_METHOD='max' 6 | MODEL_MAX_LENGTH=4096 7 | MAX_SENTENCES=32 8 | 9 | python evaluation/run_document_classification.py \ 10 | --model_name_or_path ${MODEL_NAME} \ 11 | --pooling ${POOLING_METHOD} \ 12 | --dataset_name lex_glue \ 13 | --dataset_config_name ecthr_b \ 14 | --do_train \ 15 | --do_eval \ 16 | --do_predict \ 17 | --output_dir data/PLMs/${MODEL_NAME}-${POOLING_METHOD}-ecthr \ 18 | --overwrite_output_dir \ 19 | --evaluation_strategy epoch \ 20 | --save_strategy epoch \ 21 | --num_train_epochs 20 \ 22 | --load_best_model_at_end \ 23 | --metric_for_best_model micro_f1 \ 24 | --greater_is_better True \ 25 | --save_total_limit 5 \ 26 | --learning_rate 1e-5 \ 27 | --per_device_train_batch_size 32 \ 28 | --per_device_eval_batch_size 32 \ 29 | --lr_scheduler_type linear \ 30 | --warmup_ratio 0.05 \ 31 | --max_seq_length ${MODEL_MAX_LENGTH} \ 32 | --max_sentences ${MAX_SENTENCES} \ 33 | --pad_to_max_length -------------------------------------------------------------------------------- /running_scripts/fine_tuning/train_mimic_classification.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-eval" 2 | export PYTHONPATH=. 3 | 4 | MODEL_NAME='allenai/longformer-base-4096' 5 | POOLING_METHOD='max' 6 | MODEL_MAX_LENGTH=4096 7 | MAX_SENTENCES=32 8 | 9 | python evaluation/run_document_classification.py \ 10 | --model_name_or_path ${MODEL_NAME} \ 11 | --pooling ${POOLING_METHOD} \ 12 | --dataset_name data/mimic-dataset \ 13 | --dataset_config_name mimic \ 14 | --do_train \ 15 | --do_eval \ 16 | --do_predict \ 17 | --output_dir data/PLMs/${MODEL_NAME}-${POOLING_METHOD}-mimic \ 18 | --overwrite_output_dir \ 19 | --evaluation_strategy epoch \ 20 | --save_strategy epoch \ 21 | --num_train_epochs 20 \ 22 | --load_best_model_at_end \ 23 | --metric_for_best_model micro_f1 \ 24 | --greater_is_better True \ 25 | --save_total_limit 5 \ 26 | --learning_rate 1e-5 \ 27 | --per_device_train_batch_size 8 \ 28 | --per_device_eval_batch_size 8 \ 29 | --lr_scheduler_type linear \ 30 | --warmup_ratio 0.05 \ 31 | --max_seq_length ${MODEL_MAX_LENGTH} \ 32 | --max_sentences ${MAX_SENTENCES} \ 33 | --pad_to_max_length -------------------------------------------------------------------------------- /running_scripts/fine_tuning/train_contract_nli.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-eval" 2 | export PYTHONPATH=. 
3 | 4 | MODEL_NAME='allenai/longformer-base-4096' 5 | POOLING_METHOD='last' 6 | MODEL_MAX_LENGTH=4096 7 | MAX_SENTENCES=32 8 | 9 | python evaluation/run_document_nli.py \ 10 | --model_name_or_path ${MODEL_NAME} \ 11 | --pooling ${POOLING_METHOD} \ 12 | --dataset_name data/contractnli-dataset \ 13 | --dataset_config_name contractnli \ 14 | --do_train \ 15 | --do_eval \ 16 | --do_predict \ 17 | --output_dir data/PLMs/${MODEL_NAME}-${POOLING_METHOD}-contractnli \ 18 | --overwrite_output_dir \ 19 | --evaluation_strategy epoch \ 20 | --save_strategy epoch \ 21 | --num_train_epochs 20 \ 22 | --load_best_model_at_end \ 23 | --metric_for_best_model micro_f1 \ 24 | --greater_is_better True \ 25 | --save_total_limit 5 \ 26 | --learning_rate 1e-5 \ 27 | --per_device_train_batch_size 32 \ 28 | --per_device_eval_batch_size 32 \ 29 | --lr_scheduler_type linear \ 30 | --warmup_ratio 0.05 \ 31 | --max_seq_length ${MODEL_MAX_LENGTH} \ 32 | --max_sentences ${MAX_SENTENCES} \ 33 | --pad_to_max_length -------------------------------------------------------------------------------- /running_scripts/fine_tuning/train_ecthr_arguments_classification.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-eval" 2 | export PYTHONPATH=. 3 | export CUDA_VISIBLE_DEVICES=1,2,3,4 4 | 5 | MODEL_NAME='google/bigbird-roberta-base' 6 | MODEL_MAX_LENGTH=4096 7 | MAX_SENTENCES=32 8 | 9 | python evaluation/run_sequential_sentence_classification.py \ 10 | --dataset_name ../data/ecthr-arguments-dataset \ 11 | --dataset_config_name ecthr-arguments-dataset \ 12 | --model_name_or_path ${MODEL_NAME} \ 13 | --do_train \ 14 | --do_eval \ 15 | --do_predict \ 16 | --output_dir data/PLMs/${MODEL_NAME}-ecthr-args \ 17 | --overwrite_output_dir \ 18 | --evaluation_strategy epoch \ 19 | --save_strategy epoch \ 20 | --num_train_epochs 20 \ 21 | --load_best_model_at_end \ 22 | --metric_for_best_model micro-f1 \ 23 | --greater_is_better True \ 24 | --save_total_limit 5 \ 25 | --learning_rate 3e-5 \ 26 | --per_device_train_batch_size 2 \ 27 | --per_device_eval_batch_size 2 \ 28 | --lr_scheduler_type linear \ 29 | --warmup_ratio 0.05 \ 30 | --max_seq_length ${MODEL_MAX_LENGTH} \ 31 | --max_sentences ${MAX_SENTENCES} \ 32 | --pad_to_max_length 33 | -------------------------------------------------------------------------------- /running_scripts/pre_training/train_longformer.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-pretrain" 2 | export PYTHONPATH=. 
3 | 4 | MODEL_MAX_LENGTH=1024 5 | MAX_SENTENCES=8 6 | 7 | python models/longformer/convert_bert_to_lf.py --max_sentences ${MAX_SENTENCES} 8 | 9 | python language_modelling/run_mlm_stream.py \ 10 | --model_name_or_path data/PLMs/longformer \ 11 | --dataset_name ./data/wikipedia-dataset \ 12 | --dataset_config_name 20200501.en \ 13 | --do_train \ 14 | --do_eval \ 15 | --output_dir data/PLMs/longformer-global-mlm \ 16 | --overwrite_output_dir \ 17 | --logging_steps 500 \ 18 | --evaluation_strategy steps \ 19 | --eval_steps 10000 \ 20 | --save_strategy steps \ 21 | --save_steps 10000 \ 22 | --save_total_limit 5 \ 23 | --max_steps 50000 \ 24 | --learning_rate 1e-4 \ 25 | --per_device_train_batch_size 32 \ 26 | --per_device_eval_batch_size 32 \ 27 | --gradient_accumulation_steps 4 \ 28 | --eval_accumulation_steps 4 \ 29 | --lr_scheduler_type linear \ 30 | --warmup_ratio 0.10 \ 31 | --weight_decay 0.01 \ 32 | --mlm_probability 0.15 \ 33 | --max_seq_length ${MODEL_MAX_LENGTH} \ 34 | --line_by_line \ 35 | --pad_to_max_length -------------------------------------------------------------------------------- /running_scripts/fine_tuning/train_quality_mcqa.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-eval" 2 | export PYTHONPATH=. 3 | 4 | MODEL_NAME='allenai/longformer-base-4096' 5 | POOLING_METHOD='last' 6 | MODEL_MAX_LENGTH=4096 7 | MAX_SENTENCES=32 8 | 9 | python evaluation/run_quality_mcqa.py \ 10 | --model_name_or_path ${MODEL_NAME} \ 11 | --pooling ${POOLING_METHOD} \ 12 | --dataset_name data/quality-dataset \ 13 | --dataset_config_name quality \ 14 | --do_train \ 15 | --do_eval \ 16 | --do_predict \ 17 | --output_dir data/PLMs/${MODEL_NAME}-${POOLING_METHOD}-quality \ 18 | --overwrite_output_dir \ 19 | --evaluation_strategy epoch \ 20 | --save_strategy epoch \ 21 | --num_train_epochs 20 \ 22 | --load_best_model_at_end \ 23 | --metric_for_best_model accuracy_score \ 24 | --greater_is_better True \ 25 | --save_total_limit 5 \ 26 | --learning_rate 1e-5 \ 27 | --per_device_train_batch_size 2 \ 28 | --per_device_eval_batch_size 2 \ 29 | --gradient_accumulation_steps 4 \ 30 | --eval_accumulation_steps 4 \ 31 | --lr_scheduler_type linear \ 32 | --warmup_ratio 0.05 \ 33 | --max_seq_length ${MODEL_MAX_LENGTH} \ 34 | --max_sentences ${MAX_SENTENCES} \ 35 | --pad_to_max_length -------------------------------------------------------------------------------- /running_scripts/pre_training/train_longformer_base.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-pretrain" 2 | export PYTHONPATH=. 
3 | 4 | MODEL_MAX_LENGTH=4096 5 | MAX_SENTENCES=32 6 | 7 | python models/longformer/convert_roberta_to_lf.py --max_sentences ${MAX_SENTENCES} --num_hidden_layers 12 8 | 9 | python language_modelling/run_mlm_stream.py \ 10 | --model_name_or_path data/PLMs/longformer-roberta \ 11 | --dataset_name c4 \ 12 | --dataset_config_name en \ 13 | --do_train \ 14 | --do_eval \ 15 | --output_dir data/PLMs/longformer-roberta-mlm \ 16 | --overwrite_output_dir \ 17 | --logging_steps 500 \ 18 | --evaluation_strategy steps \ 19 | --eval_steps 10000 \ 20 | --save_strategy steps \ 21 | --save_steps 10000 \ 22 | --save_total_limit 5 \ 23 | --max_steps 50000 \ 24 | --learning_rate 1e-4 \ 25 | --per_device_train_batch_size 8 \ 26 | --per_device_eval_batch_size 8 \ 27 | --gradient_accumulation_steps 16 \ 28 | --eval_accumulation_steps 16 \ 29 | --lr_scheduler_type linear \ 30 | --warmup_ratio 0.10 \ 31 | --weight_decay 0.01 \ 32 | --mlm_probability 0.15 \ 33 | --max_seq_length ${MODEL_MAX_LENGTH} \ 34 | --max_sentences ${MAX_SENTENCES} \ 35 | --min_sequence_length 1024 \ 36 | --pad_to_max_length \ 37 | --max_eval_samples 100000 -------------------------------------------------------------------------------- /evaluation/multi_label_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import expit 3 | 4 | 5 | def fix_multi_label_scores(original_predictions, original_label_ids, unpad_sequences=False, flatten_sequences=False): 6 | 7 | if flatten_sequences: 8 | predictions = original_predictions.reshape((-1, original_predictions.shape[-1])) 9 | label_ids = original_label_ids.reshape((-1, original_predictions.shape[-1])) 10 | else: 11 | predictions = original_predictions 12 | label_ids = original_label_ids 13 | 14 | if unpad_sequences: 15 | predictions = np.asarray([pred for pred, label in zip(predictions, label_ids) if label[0] != -1]) 16 | label_ids = np.asarray([label for label in label_ids if label[0] != -1]) 17 | 18 | # Fix gold labels 19 | y_true = np.zeros((len(label_ids), len(label_ids[0]) + 1), dtype=np.int32) 20 | y_true[:, :-1] = label_ids 21 | y_true[:, -1] = (np.sum(label_ids, axis=1) == 0).astype('int32') 22 | # Fix predictions 23 | logits = predictions[0] if isinstance(predictions, tuple) else predictions 24 | preds = (expit(logits) > 0.5).astype('int32') 25 | y_pred = np.zeros((len(label_ids), len(label_ids[0]) + 1), dtype=np.int32) 26 | y_pred[:, :-1] = preds 27 | y_pred[:, -1] = (np.sum(preds, axis=1) == 0).astype('int32') 28 | 29 | return y_true, y_pred 30 | -------------------------------------------------------------------------------- /running_scripts/pre_training/train_hat.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-pretrain" 2 | export XRT_TPU_CONFIG="localservice;0;localhost:51011" 3 | export PYTHONPATH=. 
4 | 5 | LAYOUT='s1' 6 | MODEL_WARMUP_STRATEGY='grouped' 7 | MODEL_MAX_LENGTH=1024 8 | MAX_SENTENCES=8 9 | 10 | python3 models/hat/convert_bert_to_htf.py --layout ${LAYOUT} --max_sentences ${MAX_SENTENCES} 11 | 12 | python3 language_modelling/xla_spawn.py --num_cores=8 language_modelling/run_mlm_stream.py \ 13 | --model_name_or_path data/PLMs/hi-transformer-${LAYOUT}-${MODEL_WARMUP_STRATEGY} \ 14 | --dataset_name c4 \ 15 | --dataset_config_name en \ 16 | --do_train \ 17 | --do_eval \ 18 | --output_dir data/PLMs/hi-transformer-${LAYOUT}-${MODEL_WARMUP_STRATEGY}-mlm \ 19 | --overwrite_output_dir \ 20 | --logging_steps 500 \ 21 | --evaluation_strategy steps \ 22 | --eval_steps 10000 \ 23 | --save_strategy steps \ 24 | --save_steps 10000 \ 25 | --save_total_limit 5 \ 26 | --max_steps 50000 \ 27 | --learning_rate 1e-4 \ 28 | --per_device_train_batch_size 4 \ 29 | --per_device_eval_batch_size 4 \ 30 | --gradient_accumulation_steps 4 \ 31 | --eval_accumulation_steps 4 \ 32 | --lr_scheduler_type linear \ 33 | --warmup_ratio 0.10 \ 34 | --weight_decay 0.01 \ 35 | --mlm_probability 0.15 \ 36 | --max_seq_length ${MODEL_MAX_LENGTH} \ 37 | --line_by_line \ 38 | --pad_to_max_length -------------------------------------------------------------------------------- /running_scripts/pre_training/train_hat_base.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="HATs-pretrain" 2 | export XRT_TPU_CONFIG="localservice;0;localhost:51011" 3 | export PYTHONPATH=. 4 | 5 | LAYOUT='p1' 6 | MODEL_MAX_LENGTH=4096 7 | MAX_SENTENCES=32 8 | 9 | python3 models/hat/convert_roberta_to_htf.py --layout ${LAYOUT} --max_sentences ${MAX_SENTENCES} 10 | 11 | python3 language_modelling/xla_spawn.py --num_cores=8 language_modelling/run_mlm_stream.py \ 12 | --model_name_or_path data/PLMs/hat-${LAYOUT}-roberta \ 13 | --dataset_name c4 \ 14 | --dataset_config_name en \ 15 | --do_train \ 16 | --do_eval \ 17 | --output_dir data/PLMs/hi-transformer-${LAYOUT}-roberta-mlm \ 18 | --overwrite_output_dir \ 19 | --logging_steps 500 \ 20 | --evaluation_strategy steps \ 21 | --eval_steps 10000 \ 22 | --save_strategy steps \ 23 | --save_steps 10000 \ 24 | --save_total_limit 5 \ 25 | --max_steps 50000 \ 26 | --learning_rate 1e-4 \ 27 | --per_device_train_batch_size 4 \ 28 | --per_device_eval_batch_size 4 \ 29 | --gradient_accumulation_steps 4 \ 30 | --eval_accumulation_steps 4 \ 31 | --lr_scheduler_type linear \ 32 | --warmup_ratio 0.10 \ 33 | --weight_decay 0.01 \ 34 | --mlm_probability 0.15 \ 35 | --max_seq_length ${MODEL_MAX_LENGTH} \ 36 | --max_sentences ${MAX_SENTENCES} \ 37 | --min_sequence_length 1024 \ 38 | --pad_to_max_length \ 39 | --max_eval_samples 100000 -------------------------------------------------------------------------------- /language_modelling/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tqdm 3 | import matplotlib.pyplot as plt 4 | 5 | from datasets import load_dataset 6 | import argparse 7 | 8 | 9 | def main(): 10 | """ set default hyperparams in default_hyperparams.py """ 11 | parser = argparse.ArgumentParser() 12 | 13 | # Required arguments 14 | parser.add_argument('--dataset_name', default='wikipedia') 15 | parser.add_argument('--dataset_config', default='20200501.en') 16 | config = parser.parse_args() 17 | 18 | # load datasets 19 | dataset = load_dataset(config.dataset_name, config.dataset_config, split='train') 20 | 21 | text_length = [] 22 | for text in 
tqdm.tqdm(dataset['text']): 23 | text_length.append(len(text.split())) 24 | 25 | # reduce to truncated size, max 4096 26 | text_length = [x if x <= 4000 else 4096 for x in text_length] 27 | 28 | print(f'AVG: {np.mean(text_length):.1f} ± {np.std(text_length):.1f}, MAX: {np.max(text_length):.1f}') 29 | 30 | # print stats in percentiles 31 | for min_size in [512, 1024, 2048]: 32 | n_docs = len([1 for x in text_length if x >= min_size]) 33 | perc = (n_docs * 100) / len(text_length) 34 | print(f'No of documents over {min_size} words: {n_docs}/{len(text_length)} ({perc:.1f}%)') 35 | 36 | # plot document length histogram 37 | plt.hist(text_length, range=(500, 4096), bins=50) 38 | plt.savefig(f'{config.dataset_name}_hist.png') 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /evaluation/data_collator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 3 | 4 | @dataclass 5 | class DataCollatorForMultiLabelClassification: 6 | tokenizer: PreTrainedTokenizerBase 7 | 8 | def __call__(self, features): 9 | import torch 10 | first = features[0] 11 | batch = {} 12 | 13 | # Special handling for labels. 14 | # Ensure that tensor is created with the correct type 15 | # (it should be automatically the case, but let's make sure of it.) 16 | if "label_ids" in first and first["label_ids"] is not None: 17 | if isinstance(first["label_ids"], torch.Tensor): 18 | batch["labels"] = torch.stack([f["label_ids"] for f in features]) 19 | else: 20 | dtype = torch.long if type(first["label_ids"][0]) is int else torch.float 21 | batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype) 22 | 23 | # Handling of all other possible keys. 24 | # Again, we will use the first element to figure out which key/values are not None for this model. 
25 | for k, v in first.items(): 26 | if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str): 27 | if isinstance(v, torch.Tensor): 28 | batch[k] = torch.stack([f[k] for f in features]) 29 | else: 30 | batch[k] = torch.tensor([f[k] for f in features]) 31 | 32 | return batch 33 | -------------------------------------------------------------------------------- /data/hat/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "HATForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 1, 7 | "do_sample": false, 8 | "eos_token_id": 2, 9 | "finetuning_task": null, 10 | "hidden_act": "gelu", 11 | "hidden_dropout_prob": 0.1, 12 | "hidden_size": 256, 13 | "id2label": { 14 | "0": "LABEL_0", 15 | "1": "LABEL_1" 16 | }, 17 | "initializer_range": 0.02, 18 | "intermediate_size": 512, 19 | "is_decoder": false, 20 | "label2id": { 21 | "LABEL_0": 0, 22 | "LABEL_1": 1 23 | }, 24 | "layer_norm_eps": 1e-12, 25 | "length_penalty": 1.0, 26 | "max_sentence_length": 128, 27 | "max_sentences": 64, 28 | "model_max_length": 8192, 29 | "max_position_embeddings": 128, 30 | "model_type": "hierarchical-transformer", 31 | "num_attention_heads": 8, 32 | "num_beams": 1, 33 | "num_hidden_layers": 6, 34 | "encoder_layout": { 35 | "0": {"sentence_encoder": true, "document_encoder": false}, 36 | "1": {"sentence_encoder": true, "document_encoder": false}, 37 | "2": {"sentence_encoder": true, "document_encoder": true}, 38 | "3": {"sentence_encoder": true, "document_encoder": false}, 39 | "4": {"sentence_encoder": true, "document_encoder": false}, 40 | "5": {"sentence_encoder": true, "document_encoder": true}}, 41 | "num_labels": 2, 42 | "num_return_sequences": 1, 43 | "output_attentions": false, 44 | "output_hidden_states": false, 45 | "output_past": true, 46 | "pad_token_id": 0, 47 | "pruned_heads": {}, 48 | "repetition_penalty": 1.0, 49 | "temperature": 1.0, 50 | "top_k": 50, 51 | "top_p": 1.0, 52 | "torchscript": false, 53 | "type_vocab_size": 2, 54 | "use_bfloat16": false, 55 | "vocab_size": 50000, 56 | "parameters": 136350720 57 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # PyCharm files 132 | .idea 133 | 134 | # Log files 135 | logs/ 136 | -------------------------------------------------------------------------------- /language_modelling/xla_spawn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """ 15 | A simple launcher script for TPU training 16 | 17 | Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py 18 | 19 | :: 20 | >>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE 21 | YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other 22 | arguments of your training script) 23 | 24 | """ 25 | 26 | 27 | import importlib 28 | import sys 29 | from argparse import REMAINDER, ArgumentParser 30 | from pathlib import Path 31 | 32 | import torch_xla.distributed.xla_multiprocessing as xmp 33 | 34 | 35 | def parse_args(): 36 | """ 37 | Helper function parsing the command line options 38 | @retval ArgumentParser 39 | """ 40 | parser = ArgumentParser( 41 | description=( 42 | "PyTorch TPU distributed training launch " 43 | "helper utility that will spawn up " 44 | "multiple distributed processes" 45 | ) 46 | ) 47 | 48 | # Optional arguments for the launch helper 49 | parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).") 50 | 51 | # positional 52 | parser.add_argument( 53 | "training_script", 54 | type=str, 55 | help=( 56 | "The full path to the single TPU training " 57 | "program/script to be launched in parallel, " 58 | "followed by all the arguments for the " 59 | "training script" 60 | ), 61 | ) 62 | 63 | # rest from the training program 64 | parser.add_argument("training_script_args", nargs=REMAINDER) 65 | 66 | return parser.parse_args() 67 | 68 | 69 | def main(): 70 | args = parse_args() 71 | 72 | # Import training_script as a module. 73 | script_fpath = Path(args.training_script) 74 | sys.path.append(str(script_fpath.parent.resolve())) 75 | mod_name = script_fpath.stem 76 | mod = importlib.import_module(mod_name) 77 | 78 | # Patch sys.argv 79 | sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)] 80 | 81 | xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() -------------------------------------------------------------------------------- /language_modelling/train_tokenizer.py: -------------------------------------------------------------------------------- 1 | from tokenizers import models, normalizers, pre_tokenizers, decoders, processors, trainers 2 | from tokenizers import Tokenizer 3 | from datasets import load_dataset 4 | from transformers import PreTrainedTokenizerFast, AutoTokenizer 5 | import argparse 6 | 7 | CUSTOM_TOK_FOLDER = '../data/custom-tokenizer' 8 | hat_FOLDER = '../data/hi-transformer' 9 | ROBERTA_FOLDER = '../data/roberta' 10 | 11 | 12 | def batch_iterator(dataset): 13 | for example in dataset['text']: 14 | yield example 15 | 16 | 17 | def main(): 18 | """ set default hyperparams in default_hyperparams.py """ 19 | parser = argparse.ArgumentParser() 20 | 21 | # Required arguments 22 | parser.add_argument('--vocab_size', default=50000) 23 | config = parser.parse_args() 24 | 25 | # configure tokenizer 26 | backend_tokenizer = Tokenizer(models.BPE(unk_token="")) 27 | backend_tokenizer.normalizer = normalizers.Lowercase() 28 | backend_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True) 29 | backend_tokenizer.decoder = decoders.ByteLevel() 30 | backend_tokenizer.post_processor = processors.RobertaProcessing(sep=("", 2), cls=("", 1), 31 | add_prefix_space=True, trim_offsets=True) 32 | 33 | trainer = trainers.BpeTrainer( 34 | vocab_size=config.vocab_size, 35 | min_frequency=2, 36 | special_tokens=["", "", "", "", ""], 37 | show_progress=True 38 | ) 39 | 40 
| # load datasets 41 | dataset = load_dataset("multi_eurlex", "en", split='train') 42 | 43 | # train tokenizer 44 | backend_tokenizer.train_from_iterator(trainer=trainer, iterator=batch_iterator(dataset)) 45 | 46 | # test tokenizer 47 | tokens = backend_tokenizer.encode('dog ' * 5, add_special_tokens=False) 48 | print('Original Tokenizer: ', tokens.tokens) 49 | 50 | # save tokenizer 51 | new_roberta_tokenizer = PreTrainedTokenizerFast( 52 | tokenizer_object=backend_tokenizer, 53 | model_max_length=512, 54 | # padding_side="Set me if you want", 55 | # truncation_side="Set me if you want", 56 | # model_input_names="Set me if you want", 57 | bos_token='<s>', 58 | eos_token='</s>', 59 | unk_token='<unk>', 60 | sep_token='</s>', 61 | pad_token='<pad>', 62 | cls_token='<s>', 63 | mask_token='<mask>', 64 | ) 65 | 66 | new_hat_tokenizer = PreTrainedTokenizerFast( 67 | tokenizer_object=backend_tokenizer, 68 | model_max_length=8192, 69 | # padding_side="Set me if you want", 70 | # truncation_side="Set me if you want", 71 | # model_input_names="Set me if you want", 72 | bos_token='<s>', 73 | eos_token='</s>', 74 | unk_token='<unk>', 75 | sep_token='</s>', 76 | pad_token='<pad>', 77 | cls_token='<s>', 78 | mask_token='<mask>', 79 | ) 80 | 81 | new_roberta_tokenizer.save_pretrained(CUSTOM_TOK_FOLDER) 82 | new_roberta_tokenizer.save_pretrained(ROBERTA_FOLDER) 83 | new_hat_tokenizer.save_pretrained(hat_FOLDER) 84 | 85 | # re-load tokenizer and test 86 | reloaded_tokenizer = AutoTokenizer.from_pretrained(CUSTOM_TOK_FOLDER) 87 | tokens = reloaded_tokenizer.tokenize('dog ' * 5, add_special_tokens=False) 88 | print('Reloaded Tokenizer: ', tokens) 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /language_modelling/text_featurization.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | 4 | import nltk 5 | import tqdm 6 | from sklearn.feature_extraction.text import TfidfVectorizer 7 | from sklearn.decomposition import PCA 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | 12 | def train_text_featurizer(documents, tokenizer_path='google/bert_uncased_L-6_H-256_A-4', hidden_units=768): 13 | 14 | def tokenize(document: str): 15 | return tokenizer.tokenize(document, 16 | padding=False, 17 | truncation=True, 18 | max_length=1024) 19 | 20 | # init tokenizer 21 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, model_max_length=1024) 22 | 23 | # init tfidf vectorizer 24 | vocab = [(key, value) for (key, value) in tokenizer.vocab.items()] 25 | vocab = sorted(vocab, key=lambda tup: tup[1]) 26 | vocab = [key for (key, value) in vocab] 27 | tfidf_vectorizer = TfidfVectorizer(lowercase=False, preprocessor=None, tokenizer=tokenize, 28 | vocabulary=vocab) 29 | pca_solver = PCA(n_components=hidden_units) 30 | 31 | tfidf_scores = tfidf_vectorizer.fit_transform(documents) 32 | print('TFIDF-VECTORIZER DONE!') 33 | 34 | pca_solver.fit(tfidf_scores.toarray()) 35 | print('PCA SOLVER DONE!') 36 | 37 | with open('./data/wikipedia-dataset/tifidf_vectorizer.pkl', 'wb') as fin: 38 | pickle.dump(tfidf_vectorizer, fin) 39 | print('TFIDF-VECTORIZER SAVED!') 40 | 41 | with open('./data/wikipedia-dataset/pca_solver.pkl', 'wb') as fin: 42 | pickle.dump(pca_solver, fin) 43 | print('PCA SOLVER SAVED!') 44 | 45 | 46 | def learn_idfs(documents, tokenizer_path='google/bert_uncased_L-6_H-256_A-4'): 47 | 48 | def tokenize(document: str): 49 | return tokenizer.tokenize(document) 50 | 51 | # init tokenizer 52 | tokenizer = 
AutoTokenizer.from_pretrained(tokenizer_path, model_max_length=1024) 53 | 54 | # init tfidf vectorizer 55 | vocab = [(key, value) for (key, value) in tokenizer.vocab.items()] 56 | vocab = sorted(vocab, key=lambda tup: tup[1]) 57 | vocab = [key for (key, value) in vocab] 58 | tfidf_vectorizer = TfidfVectorizer(lowercase=False, preprocessor=None, tokenizer=tokenize, 59 | vocabulary=vocab) 60 | 61 | tfidf_vectorizer.fit(documents) 62 | 63 | with open('./data/wikipedia-dataset/idf_scores.pkl', 'wb') as file: 64 | pickle.dump(tfidf_vectorizer.idf_, file) 65 | 66 | 67 | def embed_sentences(documents, model_path='all-MiniLM-L6-v2'): 68 | from sentence_transformers import SentenceTransformer 69 | 70 | # Define the model 71 | model = SentenceTransformer(model_path) 72 | 73 | # Start the multi-process pool on all available CUDA devices 74 | pool = model.start_multi_process_pool() 75 | 76 | # Sub-sample sentences 77 | grouped_sentences = [] 78 | for document in documents: 79 | doc_sentences = nltk.sent_tokenize(' '.join(document.split()[:1024])) 80 | # Build grouped sentences up to 100 words 81 | temp_sentence = '' 82 | for doc_sentence in doc_sentences: 83 | if len(temp_sentence.split()) + len(doc_sentence.split()) <= 100: 84 | temp_sentence += ' ' + doc_sentence 85 | else: 86 | if len(temp_sentence): 87 | grouped_sentences.append(temp_sentence) 88 | temp_sentence = doc_sentence 89 | if len(temp_sentence): 90 | grouped_sentences.append(temp_sentence) 91 | del documents 92 | 93 | # Compute the embeddings using the multi-process pool 94 | sentence_embeddings = model.encode_multi_process(grouped_sentences, pool) 95 | print("Embeddings computed. Shape:", sentence_embeddings.shape) 96 | 97 | # Optional: Stop the proccesses in the pool 98 | model.stop_multi_process_pool(pool) 99 | 100 | with open('../data/wikipedia-dataset/sentence_embeddings.pkl', 'wb') as file: 101 | pickle.dump(sentence_embeddings, file) 102 | 103 | print('SENTENCE EMBEDDINGS DONE!') 104 | 105 | 106 | if __name__ == '__main__': 107 | from transformers import AutoTokenizer 108 | from datasets import load_dataset 109 | 110 | # load dataset 111 | dataset = load_dataset("lex_glue", "eurlex", split='train') 112 | dataset = dataset['text'] 113 | subset = random.sample(range(len(dataset)), k=100) 114 | dataset_small = [] 115 | for i in tqdm.tqdm(subset): 116 | dataset_small.append(dataset[i]) 117 | 118 | # re-load tokenizer and test 119 | embed_sentences(documents=dataset_small) 120 | 121 | -------------------------------------------------------------------------------- /data/mimic-dataset/mimic-dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """MIMIC-III.""" 16 | 17 | 18 | import json 19 | 20 | import datasets 21 | 22 | 23 | _CITATION = """@article{johnson-mit-2016-mimic-iii, 24 | author={Johnson, Alistair E W and Pollard, Tom J and Shen, Lu and Li-Wei, H Lehman and Feng, Mengling and Ghassemi, 25 | Mohammad and Moody, Benjamin and Szolovits, Peter and Celi, Leo Anthony and Mark, Roger G}, 26 | year={2016}, 27 | title={{MIMIC-III, a freely accessible critical care database}}, 28 | journal={Sci. Data}, 29 | volume={3}, 30 | url={https://www.nature.com/articles/sdata201635.pdf} 31 | }""" 32 | 33 | _DESCRIPTION = """MIMIC""" 34 | 35 | _HOMEPAGE = "https://physionet.org/content/mimiciii/1.4/" 36 | 37 | _LICENSE = "CC BY-SA (Creative Commons / Attribution-ShareAlike)" 38 | 39 | _LABEL_NAMES = ['001-139', '140-239', '240-279', '280-289', '290-319', '320-389', '390-459', '460-519', 40 | '520-579', '580-629', '630-679', '680-709', '710-739', '740-759', '760-779', '780-799', 41 | '800-999', 'V01-V91', 'E000-E999'] 42 | 43 | class MIMIC(datasets.GeneratorBasedBuilder): 44 | """MIMIC""" 45 | 46 | VERSION = datasets.Version("1.1.0") 47 | 48 | BUILDER_CONFIGS = [ 49 | datasets.BuilderConfig( 50 | name="mimic", version=VERSION, description="MIMIC" 51 | ), 52 | ] 53 | 54 | DEFAULT_CONFIG_NAME = "mimic" 55 | 56 | def _info(self): 57 | features = datasets.Features( 58 | { 59 | "summary_id": datasets.Value("string"), 60 | "text": datasets.Value("string"), 61 | "labels": datasets.features.Sequence(datasets.ClassLabel(names=_LABEL_NAMES)), 62 | } 63 | ) 64 | return datasets.DatasetInfo( 65 | # This is the description that will appear on the datasets page. 66 | description=_DESCRIPTION, 67 | # This defines the different columns of the dataset and their types 68 | features=features, # Here we define them above because they are different between the two configurations 69 | # If there's a common (input, target) tuple from the features, 70 | # specify them here. They'll be used if as_supervised=True in 71 | # builder.as_dataset. 
72 | supervised_keys=None, 73 | # Homepage of the dataset for documentation 74 | homepage=_HOMEPAGE, 75 | # License for the dataset if available 76 | license=_LICENSE, 77 | # Citation for the dataset 78 | citation=_CITATION, 79 | ) 80 | 81 | def _split_generators(self, dl_manager): 82 | """Returns SplitGenerators.""" 83 | data_dir = dl_manager.download_and_extract('mimic.jsonl') 84 | return [ 85 | datasets.SplitGenerator( 86 | name=datasets.Split.TRAIN, 87 | # These kwargs will be passed to _generate_examples 88 | gen_kwargs={ 89 | "filepath": data_dir, 90 | "split": "train", 91 | }, 92 | ), 93 | datasets.SplitGenerator( 94 | name=datasets.Split.TEST, 95 | # These kwargs will be passed to _generate_examples 96 | gen_kwargs={"filepath": data_dir, 97 | "split": "test"}, 98 | ), 99 | datasets.SplitGenerator( 100 | name=datasets.Split.VALIDATION, 101 | # These kwargs will be passed to _generate_examples 102 | gen_kwargs={ 103 | "filepath": data_dir, 104 | "split": "dev", 105 | }, 106 | ), 107 | ] 108 | 109 | def _generate_examples( 110 | self, filepath, split # method parameters are unpacked from `gen_kwargs` as given in `_split_generators` 111 | ): 112 | """Yields examples as (key, example) tuples.""" 113 | 114 | with open(filepath, encoding="utf-8") as f: 115 | for id_, row in enumerate(f): 116 | data = json.loads(row) 117 | if data['data_type'] == split: 118 | yield id_, { 119 | "summary_id": data["summary_id"], 120 | "text": data['text'], 121 | "labels": data["level_1"], 122 | } -------------------------------------------------------------------------------- /data/c4-dataset/c4-dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """English Wikipedia dataset containing cleaned articles.""" 18 | 19 | import json 20 | import os 21 | 22 | import datasets 23 | 24 | 25 | logger = datasets.logging.get_logger(__name__) 26 | 27 | 28 | _CITATION = """\ 29 | @ONLINE {wikidump, 30 | author = {Wikimedia Foundation}, 31 | title = {Wikimedia Downloads}, 32 | url = {https://dumps.wikimedia.org} 33 | } 34 | """ 35 | 36 | _DESCRIPTION = """\ 37 | Wikipedia dataset containing cleaned articles. 38 | The datasets are built from the Wikipedia dump 39 | (https://dumps.wikimedia.org/). Each example 40 | contains the content of one full Wikipedia article with cleaning to strip 41 | markdown and unwanted sections (references, etc.). 42 | """ 43 | 44 | _LICENSE = ( 45 | "This work is licensed under the Creative Commons Attribution-ShareAlike " 46 | "3.0 Unported License. To view a copy of this license, visit " 47 | "http://creativecommons.org/licenses/by-sa/3.0/ or send a letter to " 48 | "Creative Commons, PO Box 1866, Mountain View, CA 94042, USA." 
49 | ) 50 | 51 | _VERSION = datasets.Version("1.0.0", "") 52 | 53 | 54 | class C4Config(datasets.BuilderConfig): 55 | """BuilderConfig for Wikipedia.""" 56 | 57 | def __init__(self, dump=None, version=_VERSION, **kwargs): 58 | """BuilderConfig for Wikipedia. 59 | 60 | Args: 61 | language: string, the language code for the Wikipedia dump to use. 62 | date: string, date of the Wikipedia dump in YYYYMMDD format. A list of 63 | available dates can be found at https://dumps.wikimedia.org/enwiki/. 64 | **kwargs: keyword arguments forwarded to super. 65 | """ 66 | super().__init__( 67 | name=f"{dump}", 68 | description=f"C4 dataset for {dump} dump.", 69 | version=version, 70 | **kwargs, 71 | ) 72 | self.dump = dump 73 | 74 | 75 | _DATE = "20220301" 76 | 77 | 78 | class C4(datasets.GeneratorBasedBuilder): 79 | """Wikipedia dataset.""" 80 | 81 | # Use mirror (your.org) to avoid download caps. 82 | BUILDER_CONFIG_CLASS = C4Config 83 | BUILDER_CONFIGS = [ 84 | C4Config( 85 | dump="eval.en", 86 | ) # pylint:disable=g-complex-comprehension 87 | ] 88 | 89 | def _info(self): 90 | return datasets.DatasetInfo( 91 | description=_DESCRIPTION, 92 | features=datasets.Features( 93 | { 94 | "text": datasets.Value("string"), 95 | } 96 | ), 97 | # No default supervised_keys. 98 | supervised_keys=None, 99 | homepage="https://dumps.wikimedia.org", 100 | citation=_CITATION, 101 | ) 102 | 103 | def _split_generators(self, dl_manager): 104 | data_dir = dl_manager.download(os.path.join('./', f'c4.{self.config.dump}.tar.gz')) 105 | return [ 106 | datasets.SplitGenerator( 107 | name=datasets.Split.TRAIN, 108 | # These kwargs will be passed to _generate_examples 109 | gen_kwargs={"filepath": 'train.jsonl', 110 | "split": "train", 111 | "files": dl_manager.iter_archive(data_dir)}, 112 | ), 113 | datasets.SplitGenerator( 114 | name=datasets.Split.TEST, 115 | # These kwargs will be passed to _generate_examples 116 | gen_kwargs={"filepath": 'test.jsonl', 117 | "split": "test", 118 | "files": dl_manager.iter_archive(data_dir)}, 119 | ), 120 | datasets.SplitGenerator( 121 | name=datasets.Split.VALIDATION, 122 | # These kwargs will be passed to _generate_examples 123 | gen_kwargs={"filepath": 'dev.jsonl', 124 | "split": "validation", 125 | "files": dl_manager.iter_archive(data_dir)}, 126 | ), 127 | ] 128 | 129 | def _generate_examples(self, filepath, split, files): 130 | """This function returns the examples in the raw (text) form.""" 131 | for path, f in files: 132 | if path == filepath: 133 | for id_, row in enumerate(f): 134 | data = json.loads(row) 135 | yield id_, { 136 | "text": data['text'], 137 | } 138 | -------------------------------------------------------------------------------- /data/quality-dataset/quality-dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """QuALITY: Question Answering with Long Input Texts, Yes!""" 16 | 17 | 18 | import json 19 | import os 20 | 21 | import datasets 22 | 23 | 24 | _CITATION = """@article{pang2021quality, 25 | title={{QuALITY}: Question Answering with Long Input Texts, Yes!}, 26 | author={Pang, Richard Yuanzhe and Parrish, Alicia and Joshi, Nitish and Nangia, Nikita and Phang, Jason and Chen, 27 | Angelica and Padmakumar, Vishakh and Ma, Johnny and Thompson, Jana and He, He and Bowman, Samuel R.}, 28 | journal={arXiv preprint arXiv:2112.08608}, 29 | year={2021} 30 | }""" 31 | 32 | _DESCRIPTION = """QuALITY: Question Answering with Long Input Texts, Yes!""" 33 | 34 | _LICENSE = "CC BY-SA (Creative Commons / Attribution-ShareAlike)" 35 | 36 | _LABEL_NAMES = [f'choice_{i}' for i in range(5)] 37 | 38 | 39 | class QuALITY(datasets.GeneratorBasedBuilder): 40 | """QuALITY: Question Answering with Long Input Texts, Yes!""" 41 | 42 | VERSION = datasets.Version("1.1.0") 43 | 44 | BUILDER_CONFIGS = [ 45 | datasets.BuilderConfig( 46 | name="quality", version=VERSION, description="QuALITY: Question Answering with Long Input Texts" 47 | ), 48 | ] 49 | 50 | DEFAULT_CONFIG_NAME = "quality" 51 | 52 | def _info(self): 53 | features = datasets.Features( 54 | { 55 | "article": datasets.Value("string"), 56 | "question": datasets.Value("string"), 57 | "options": datasets.Sequence(datasets.Value("string")), 58 | "label": datasets.ClassLabel(names=_LABEL_NAMES), 59 | } 60 | ) 61 | return datasets.DatasetInfo( 62 | # This is the description that will appear on the datasets page. 63 | description=_DESCRIPTION, 64 | # This defines the different columns of the dataset and their types 65 | features=features, # Here we define them above because they are different between the two configurations 66 | # If there's a common (input, target) tuple from the features, 67 | # specify them here. They'll be used if as_supervised=True in 68 | # builder.as_dataset. 
69 | supervised_keys=None, 70 | # Homepage of the dataset for documentation 71 | homepage='https://github.com/nyu-mll/quality', 72 | # License for the dataset if available 73 | license=_LICENSE, 74 | # Citation for the dataset 75 | citation=_CITATION, 76 | ) 77 | 78 | def _split_generators(self, dl_manager): 79 | """Returns SplitGenerators.""" 80 | data_dir = dl_manager.download_and_extract('QuALITY.v1.0.zip') 81 | return [ 82 | datasets.SplitGenerator( 83 | name=datasets.Split.TRAIN, 84 | # These kwargs will be passed to _generate_examples 85 | gen_kwargs={ 86 | "filepath": os.path.join(data_dir, 'QuALITY.v1.0.htmlstripped.train'), 87 | "split": "train", 88 | }, 89 | ), 90 | datasets.SplitGenerator( 91 | name=datasets.Split.TEST, 92 | # These kwargs will be passed to _generate_examples 93 | gen_kwargs={"filepath": os.path.join(data_dir, 'QuALITY.v1.0.htmlstripped.test'), 94 | "split": "test"}, 95 | ), 96 | datasets.SplitGenerator( 97 | name=datasets.Split.VALIDATION, 98 | # These kwargs will be passed to _generate_examples 99 | gen_kwargs={ 100 | "filepath": os.path.join(data_dir, 'QuALITY.v1.0.htmlstripped.dev'), 101 | "split": "dev", 102 | }, 103 | ), 104 | ] 105 | 106 | def _generate_examples( 107 | self, filepath, split # method parameters are unpacked from `gen_kwargs` as given in `_split_generators` 108 | ): 109 | """Yields examples as (key, example) tuples.""" 110 | 111 | with open(filepath, encoding="utf-8") as f: 112 | count = 0 113 | for id_, row in enumerate(f): 114 | data = json.loads(row) 115 | for question in data['questions']: 116 | count += 1 117 | yield count, { 118 | 'article': data['article'], 119 | 'question': question['question'], 120 | 'options': question['options'], 121 | 'label': f"choice_{question['gold_label']}" 122 | } 123 | -------------------------------------------------------------------------------- /models/hat/convert_roberta_to_htf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import warnings 4 | from data import DATA_DIR 5 | from transformers import AutoModelForMaskedLM, AutoTokenizer 6 | from models.hat import HATForMaskedLM, HATConfig, HATTokenizer 7 | warnings.filterwarnings("ignore") 8 | 9 | LAYOUTS = { 10 | 'p1': 'S|S|SD|S|S|SD|S|S|SD|S|S|SD', 11 | 'l1': 'S|S|S|S|S|SD|S|SD|S|SD|S|SD', 12 | 'f12': 'S|S|S|S|S|S|S|S|S|S|S|SD|D|D|D' 13 | } 14 | 15 | 16 | def convert_roberta_to_htf(): 17 | ''' set default hyperparams in default_hyperparams.py ''' 18 | parser = argparse.ArgumentParser() 19 | 20 | # Required arguments 21 | parser.add_argument('--layout', default='p1', choices=['p1', 'l1', 'f12'], 22 | help='S|D encoders layout') 23 | parser.add_argument('--max_sentences', default=32) 24 | config = parser.parse_args() 25 | MAX_SENTENCE_LENGTH = 128 26 | MAX_SENTENCES = int(config.max_sentences) 27 | ENCODER_LAYOUT = {} 28 | for idx, block_pattern in enumerate(LAYOUTS[config.layout].split('|')): 29 | ENCODER_LAYOUT[str(idx)] = {"sentence_encoder": True if 'S' in block_pattern else False, 30 | "document_encoder": True if 'D' in block_pattern else False} 31 | 32 | NUM_HIDDEN_LAYERS = len(ENCODER_LAYOUT.keys()) 33 | ROBERTA_CHECKPOINT = 'roberta-base' 34 | 35 | # load pre-trained bert model and tokenizer 36 | roberta_model = AutoModelForMaskedLM.from_pretrained(ROBERTA_CHECKPOINT) 37 | tokenizer = AutoTokenizer.from_pretrained(ROBERTA_CHECKPOINT, model_max_length=MAX_SENTENCE_LENGTH * MAX_SENTENCES) 38 | 39 | # load dummy config and change specifications 40 | roberta_config = 
roberta_model.config 41 | htf_config = HATConfig.from_pretrained(f'{DATA_DIR}/hat') 42 | # Text length parameters 43 | htf_config.max_sentence_length = MAX_SENTENCE_LENGTH 44 | htf_config.max_sentences = MAX_SENTENCES 45 | htf_config.max_position_embeddings = MAX_SENTENCE_LENGTH + 2 46 | htf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES) 47 | htf_config.num_hidden_layers = NUM_HIDDEN_LAYERS 48 | # Transformer parameters 49 | htf_config.hidden_size = roberta_config.hidden_size 50 | htf_config.intermediate_size = roberta_config.intermediate_size 51 | htf_config.num_attention_heads = roberta_config.num_attention_heads 52 | htf_config.hidden_act = roberta_config.hidden_act 53 | htf_config.encoder_layout = ENCODER_LAYOUT 54 | # Vocabulary parameters 55 | htf_config.vocab_size = roberta_config.vocab_size 56 | htf_config.pad_token_id = roberta_config.pad_token_id 57 | htf_config.bos_token_id = roberta_config.bos_token_id 58 | htf_config.eos_token_id = roberta_config.eos_token_id 59 | htf_config.type_vocab_size = roberta_config.type_vocab_size 60 | 61 | # load dummy hi-transformer model 62 | htf_model = HATForMaskedLM.from_config(htf_config) 63 | 64 | # copy embeddings 65 | htf_model.hi_transformer.embeddings.position_embeddings.weight.data = roberta_model.roberta.embeddings.position_embeddings.weight[:MAX_SENTENCE_LENGTH+roberta_config.pad_token_id+1] 66 | htf_model.hi_transformer.embeddings.word_embeddings.load_state_dict(roberta_model.roberta.embeddings.word_embeddings.state_dict()) 67 | htf_model.hi_transformer.embeddings.token_type_embeddings.load_state_dict(roberta_model.roberta.embeddings.token_type_embeddings.state_dict()) 68 | htf_model.hi_transformer.embeddings.LayerNorm.load_state_dict(roberta_model.roberta.embeddings.LayerNorm.state_dict()) 69 | 70 | # copy transformer layers 71 | for idx in range(min(NUM_HIDDEN_LAYERS, roberta_config.num_hidden_layers)): 72 | if htf_model.config.encoder_layout[str(idx)]['sentence_encoder']: 73 | htf_model.hi_transformer.encoder.layer[idx].sentence_encoder.load_state_dict(roberta_model.roberta.encoder.layer[idx].state_dict()) 74 | if htf_model.config.encoder_layout[str(idx)]['document_encoder']: 75 | htf_model.hi_transformer.encoder.layer[idx].document_encoder.load_state_dict(roberta_model.roberta.encoder.layer[idx].state_dict()) 76 | htf_model.hi_transformer.encoder.layer[idx].position_embeddings.weight.data = roberta_model.roberta.embeddings.position_embeddings.weight[1:MAX_SENTENCES+2] 77 | 78 | # copy lm_head 79 | htf_model.lm_head.load_state_dict(roberta_model.lm_head.state_dict()) 80 | 81 | # save model 82 | htf_model.save_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-roberta') 83 | 84 | # save tokenizer 85 | tokenizer.save_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-roberta') 86 | 87 | # re-load model 88 | htf_model = HATForMaskedLM.from_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-roberta') 89 | htf_tokenizer = HATTokenizer.from_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-roberta') 90 | print(f'RoBERTa-based HAT model with layout {config.layout} is ready to run!') 91 | 92 | # input_ids = torch.randint(1, 30000, (2, 1024), dtype=torch.long) 93 | # input_ids[:, :: 128] = htf_tokenizer.cls_token_id 94 | # labels = input_ids.clone() 95 | # attention_mask = torch.ones((2, 1024), dtype=torch.int) 96 | # htf_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 97 | # roberta_model(input_ids=input_ids[:, :128], attention_mask=attention_mask[:, :128], labels=labels[:, :128]) 98 | 99 | 100 | if 
__name__ == '__main__': 101 | convert_roberta_to_htf() 102 | -------------------------------------------------------------------------------- /data/ecthr-arguments-dataset/create_dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import glob 3 | import json 4 | import re 5 | from collections import Counter 6 | import pandas as pd 7 | from ast import literal_eval 8 | USED_LABELS = ['Subsumtion', 9 | 'Vorherige Rechtsprechung des EGMR', 10 | 'Verhältnismäßigkeitsprüfung – Angemessenheit', #Proportionality 11 | 'Entscheidung des EGMR', 12 | 'Verhältnismäßigkeitsprüfung – Rechtsgrundlage', #Proportionality 13 | 'Verhältnismäßigkeitsprüfung – Legitimer Zweck', #Proportionality 14 | 'Konsens der prozessualen Parteien'] 15 | 16 | ECTHR_ARG_TYPES = ['Application', 17 | 'Precedent', 18 | 'Proportionality', #Proportionality 19 | 'Decision', 20 | 'Legal Basis', #Proportionality 21 | 'Legitimate Purpose', #Proportionality 22 | 'Non Contestation'] 23 | 24 | all_labels = [] 25 | article_paragraphs = [] 26 | sample_ids = [] 27 | article_paragraph_labels = [] 28 | 29 | 30 | def fix_labels(labels): 31 | labels = list(set([re.sub('[BI]-', '', label) for label in labels])) 32 | if len(labels) >= 2 and 'O' in labels: 33 | labels.remove('O') 34 | labels = [ECTHR_ARG_TYPES[USED_LABELS.index(label)] for label in labels if label in USED_LABELS] 35 | if len(labels) == 0: 36 | labels = ['O'] 37 | return labels 38 | 39 | short_chunks = [] 40 | 41 | for subset in ['train', 'val', 'test']: 42 | for filename in glob.glob(f'../mining-legal-arguments/data/{subset}/argType/*.csv'): 43 | with open(filename) as file: 44 | df = pd.read_csv(file, sep='\t', encoding='utf-8') 45 | df['labels'] = df['labels'].map(lambda x: literal_eval(x)) 46 | df['tokens'] = df['tokens'].map(lambda x: literal_eval(x)) 47 | paragraphs = [] 48 | paragraph_labels = [] 49 | temp_paragraph = '' 50 | temp_labels = [] 51 | for tokens, token_labels in zip(df['tokens'], df['labels']): 52 | paragraph = re.sub(r'([\(\-\[]) ', r'\1', re.sub(r' ([\.\)\,\:\;\'\-\]]|\'s)', r'\1', ' '.join(tokens))) 53 | labels = fix_labels(token_labels) 54 | if len(labels) > 1: 55 | print() 56 | if re.match('(FOR THESE REASONS, THE COURT|for these reasons, the court unanimously)', paragraph): 57 | break 58 | if not re.match('\d{2,}\.', paragraph) and not re.match('[IVX]+\. [A-Z]{2,}', paragraph): 59 | if len(paragraph.split()) <= 5 and re.match('(([A-Z]|\d)\.|\([a-zα-ω]+\))', paragraph): 60 | short_chunks.append(paragraph) 61 | continue 62 | temp_paragraph = temp_paragraph + '\n' + paragraph 63 | temp_labels.extend(labels) 64 | continue 65 | elif len(paragraphs) and len(temp_paragraph): 66 | paragraphs[-1] = paragraphs[-1] + '' + temp_paragraph 67 | paragraph_labels[-1] = list(set(copy.deepcopy(paragraph_labels[-1]) + copy.deepcopy(temp_labels))) 68 | if len(paragraph_labels[-1]) > 1 and 'O' in paragraph_labels[-1]: 69 | paragraph_labels[-1].remove('O') 70 | temp_paragraph = '' 71 | temp_labels = [] 72 | if len(paragraph.split()) <= 10 and not re.match('\d{2,}\.', paragraph) and not re.match('[IVX]+\. [A-Z]{2,}', paragraph): 73 | continue 74 | if re.match('[IVX]+\. 
[A-Z]{2,}', paragraph) and len(paragraphs): 75 | article_paragraphs.append(copy.deepcopy(paragraphs)) 76 | article_paragraph_labels.append(copy.deepcopy(paragraph_labels)) 77 | sample_ids.append(filename.split('argType/')[1].replace('.csv', '')) 78 | paragraphs = [] 79 | paragraph_labels = [] 80 | paragraphs.append(paragraph) 81 | paragraph_labels.append(labels) 82 | article_paragraphs.append(copy.deepcopy(paragraphs)) 83 | article_paragraph_labels.append(copy.deepcopy(paragraph_labels)) 84 | sample_ids.append(filename.split('argType/')[1].replace('.csv', '')) 85 | 86 | article_paragraphs_clean = [] 87 | article_paragraph_labels_clean = [] 88 | for paragraphs, paragraph_labels in zip(article_paragraphs, article_paragraph_labels): 89 | if len(paragraphs) == len(paragraph_labels): 90 | article_paragraphs_clean.append(copy.deepcopy(paragraphs[:32])) 91 | article_paragraph_labels_clean.append(copy.deepcopy(paragraph_labels[:32])) 92 | all_labels.extend([label for label_group in copy.deepcopy(paragraph_labels[:32]) for label in label_group]) 93 | 94 | label_counts = Counter(all_labels) 95 | n_paragraphs = [] 96 | for paragraphs in article_paragraphs_clean: 97 | if len(paragraphs) <= 32: 98 | n_paragraphs.append(32) 99 | elif len(paragraphs) <= 64: 100 | n_paragraphs.append(64) 101 | elif len(paragraphs) <= 128: 102 | n_paragraphs.append(128) 103 | else: 104 | n_paragraphs.append('Long') 105 | 106 | par_counts = Counter(n_paragraphs) 107 | print(par_counts.most_common()) 108 | 109 | count = 0 110 | with open('ecthr_arguments.jsonl', 'w') as file: 111 | for paragraphs, labels, sample_id in zip(article_paragraphs_clean, article_paragraph_labels_clean, sample_ids): 112 | count += 1 113 | if count <= 900: 114 | data_type = 'train' 115 | elif count <= 1000: 116 | data_type = 'dev' 117 | elif count <= 1100: 118 | data_type = 'test' 119 | else: 120 | break 121 | if labels is None: 122 | print() 123 | else: 124 | for paragraph_labels in labels: 125 | if paragraph_labels is None: 126 | print() 127 | file.write(json.dumps({'case_id': sample_id, 'paragraphs': paragraphs, 'labels': labels, 'data_type': data_type}) + '\n') 128 | 129 | 130 | label_counts = {'train': [], 'dev': [], 'test': []} 131 | with open('ecthr_arguments.jsonl', ) as file: 132 | for line in file: 133 | data = json.loads(line) 134 | label_counts[data['data_type']].extend([label for par_labels in data['labels'] for label in par_labels]) 135 | 136 | for key in label_counts: 137 | print(Counter(label_counts[key]).most_common()) -------------------------------------------------------------------------------- /models/longformer/convert_roberta_to_lf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import copy 5 | import warnings 6 | from data import DATA_DIR 7 | from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig 8 | warnings.filterwarnings("ignore") 9 | 10 | 11 | def convert_roberta_to_htf(): 12 | ''' set default hyperparams in default_hyperparams.py ''' 13 | parser = argparse.ArgumentParser() 14 | 15 | # Required arguments 16 | parser.add_argument('--max_sentences', default=32) 17 | parser.add_argument('--num_hidden_layers', default=12) 18 | config = parser.parse_args() 19 | MAX_SENTENCE_LENGTH = 128 20 | MAX_SENTENCES = int(config.max_sentences) 21 | NUM_HIDDEN_LAYERS = int(config.num_hidden_layers) 22 | BERT_CHECKPOINT = 'roberta-base' 23 | 24 | # load pre-trained bert model and tokenizer 25 | roberta_model = 
AutoModelForMaskedLM.from_pretrained(BERT_CHECKPOINT) 26 | tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT, model_max_length=MAX_SENTENCE_LENGTH * MAX_SENTENCES) 27 | 28 | # load dummy config and change specifications 29 | roberta_config = roberta_model.config 30 | lf_config = AutoConfig.from_pretrained('allenai/longformer-base-4096') 31 | # Text length parameters 32 | lf_config.max_position_embeddings = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES) + roberta_config.pad_token_id + 2 33 | lf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES) 34 | lf_config.num_hidden_layers = NUM_HIDDEN_LAYERS 35 | # Transformer parameters 36 | lf_config.hidden_size = roberta_config.hidden_size 37 | lf_config.intermediate_size = roberta_config.intermediate_size 38 | lf_config.num_attention_heads = roberta_config.num_attention_heads 39 | lf_config.hidden_act = roberta_config.hidden_act 40 | lf_config.attention_window = [MAX_SENTENCE_LENGTH] * NUM_HIDDEN_LAYERS 41 | # Vocabulary parameters 42 | lf_config.vocab_size = roberta_config.vocab_size 43 | lf_config.pad_token_id = roberta_config.pad_token_id 44 | lf_config.bos_token_id = roberta_config.bos_token_id 45 | lf_config.eos_token_id = roberta_config.eos_token_id 46 | lf_config.cls_token_id = tokenizer.cls_token_id 47 | lf_config.sep_token_id = tokenizer.sep_token_id 48 | lf_config.type_vocab_size = roberta_config.type_vocab_size 49 | 50 | # load dummy hi-transformer model 51 | lf_model = AutoModelForMaskedLM.from_config(lf_config) 52 | 53 | # copy embeddings 54 | k = 2 55 | step = roberta_config.max_position_embeddings - 2 56 | while k < lf_config.max_position_embeddings - 1: 57 | if k + step >= lf_config.max_position_embeddings: 58 | lf_model.longformer.embeddings.position_embeddings.weight.data[k:] = roberta_model.roberta.embeddings.position_embeddings.weight[2:(roberta_config.max_position_embeddings + 2 - k)] 59 | else: 60 | lf_model.longformer.embeddings.position_embeddings.weight.data[k:(k + step)] = roberta_model.roberta.embeddings.position_embeddings.weight[2:] 61 | k += step 62 | lf_model.longformer.embeddings.word_embeddings.load_state_dict(roberta_model.roberta.embeddings.word_embeddings.state_dict()) 63 | lf_model.longformer.embeddings.token_type_embeddings.load_state_dict(roberta_model.roberta.embeddings.token_type_embeddings.state_dict()) 64 | lf_model.longformer.embeddings.LayerNorm.load_state_dict(roberta_model.roberta.embeddings.LayerNorm.state_dict()) 65 | 66 | # copy transformer layers 67 | roberta_model.roberta.encoder.layer = roberta_model.roberta.encoder.layer[:NUM_HIDDEN_LAYERS] 68 | for i in range(len(roberta_model.roberta.encoder.layer)): 69 | # generic 70 | lf_model.longformer.encoder.layer[i].intermediate.dense = copy.deepcopy( 71 | roberta_model.roberta.encoder.layer[i].intermediate.dense) 72 | lf_model.longformer.encoder.layer[i].output.dense = copy.deepcopy( 73 | roberta_model.roberta.encoder.layer[i].output.dense) 74 | lf_model.longformer.encoder.layer[i].output.LayerNorm = copy.deepcopy( 75 | roberta_model.roberta.encoder.layer[i].output.LayerNorm) 76 | # attention output 77 | lf_model.longformer.encoder.layer[i].attention.output.dense = copy.deepcopy( 78 | roberta_model.roberta.encoder.layer[i].attention.output.dense) 79 | lf_model.longformer.encoder.layer[i].attention.output.LayerNorm = copy.deepcopy( 80 | roberta_model.roberta.encoder.layer[i].attention.output.LayerNorm) 81 | # local q,k,v 82 | lf_model.longformer.encoder.layer[i].attention.self.query = copy.deepcopy( 83 | 
roberta_model.roberta.encoder.layer[i].attention.self.query) 84 | lf_model.longformer.encoder.layer[i].attention.self.key = copy.deepcopy( 85 | roberta_model.roberta.encoder.layer[i].attention.self.key) 86 | lf_model.longformer.encoder.layer[i].attention.self.value = copy.deepcopy( 87 | roberta_model.roberta.encoder.layer[i].attention.self.value) 88 | # global q,k,v 89 | lf_model.longformer.encoder.layer[i].attention.self.query_global = copy.deepcopy( 90 | roberta_model.roberta.encoder.layer[i].attention.self.query) 91 | lf_model.longformer.encoder.layer[i].attention.self.key_global = copy.deepcopy( 92 | roberta_model.roberta.encoder.layer[i].attention.self.key) 93 | lf_model.longformer.encoder.layer[i].attention.self.value_global = copy.deepcopy( 94 | roberta_model.roberta.encoder.layer[i].attention.self.value) 95 | 96 | # copy lm_head 97 | lf_model.lm_head.load_state_dict(roberta_model.lm_head.state_dict()) 98 | 99 | # check position ids 100 | # batch = tokenizer(['this is a dog', 'this is a cat'], return_tensors='pt') 101 | # lf_model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']) 102 | 103 | # save model 104 | lf_model.save_pretrained(f'{DATA_DIR}/PLMs/longformer-roberta-{NUM_HIDDEN_LAYERS}') 105 | 106 | # save tokenizer 107 | tokenizer.save_pretrained(f'{DATA_DIR}/PLMs/longformer-roberta-{NUM_HIDDEN_LAYERS}') 108 | 109 | # re-load model 110 | lf_model = AutoModelForMaskedLM.from_pretrained(f'{DATA_DIR}/PLMs/longformer-roberta-{NUM_HIDDEN_LAYERS}') 111 | lf_tokenizer = AutoTokenizer.from_pretrained(f'{DATA_DIR}/PLMs/longformer-roberta-{NUM_HIDDEN_LAYERS}') 112 | # batch = tokenizer(['this is a dog', 'this is a cat'], return_tensors='pt') 113 | # lf_model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']) 114 | print(f'RoBERTa-based Longformer model is ready to run!') 115 | 116 | 117 | if __name__ == '__main__': 118 | convert_roberta_to_htf() 119 | -------------------------------------------------------------------------------- /models/longformer/convert_bert_to_lf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import copy 5 | import warnings 6 | from data import DATA_DIR 7 | from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig 8 | warnings.filterwarnings("ignore") 9 | 10 | 11 | def convert_bert_to_htf(): 12 | ''' set default hyperparams in default_hyperparams.py ''' 13 | parser = argparse.ArgumentParser() 14 | 15 | # Required arguments 16 | parser.add_argument('--max_sentences', default=8) 17 | config = parser.parse_args() 18 | MAX_SENTENCE_LENGTH = 128 19 | MAX_SENTENCES = int(config.max_sentences) 20 | NUM_HIDDEN_LAYERS = 6 21 | BERT_LAYERS = NUM_HIDDEN_LAYERS 22 | BERT_CHECKPOINT = f'google/bert_uncased_L-{str(BERT_LAYERS)}_H-256_A-4' 23 | 24 | # load pre-trained bert model and tokenizer 25 | bert_model = AutoModelForMaskedLM.from_pretrained(BERT_CHECKPOINT) 26 | tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT, model_max_length=MAX_SENTENCE_LENGTH * MAX_SENTENCES) 27 | 28 | # load dummy config and change specifications 29 | bert_config = bert_model.config 30 | lf_config = AutoConfig.from_pretrained('allenai/longformer-base-4096') 31 | # Text length parameters 32 | lf_config.max_position_embeddings = int(MAX_SENTENCE_LENGTH * 8) + bert_config.pad_token_id + 2 33 | lf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES) 34 | lf_config.num_hidden_layers = NUM_HIDDEN_LAYERS 35 | # Transformer parameters 36 | 
lf_config.hidden_size = bert_config.hidden_size 37 | lf_config.intermediate_size = bert_config.intermediate_size 38 | lf_config.num_attention_heads = bert_config.num_attention_heads 39 | lf_config.hidden_act = bert_config.hidden_act 40 | lf_config.attention_window = [MAX_SENTENCE_LENGTH] * NUM_HIDDEN_LAYERS 41 | # Vocabulary parameters 42 | lf_config.vocab_size = bert_config.vocab_size 43 | lf_config.pad_token_id = bert_config.pad_token_id 44 | lf_config.bos_token_id = bert_config.bos_token_id 45 | lf_config.eos_token_id = bert_config.eos_token_id 46 | lf_config.cls_token_id = tokenizer.cls_token_id 47 | lf_config.sep_token_id = tokenizer.sep_token_id 48 | lf_config.type_vocab_size = bert_config.type_vocab_size 49 | 50 | # load dummy hi-transformer model 51 | lf_model = AutoModelForMaskedLM.from_config(lf_config) 52 | 53 | # copy embeddings 54 | lf_model.longformer.embeddings.position_embeddings.weight.data[0] = torch.zeros((bert_config.hidden_size,)) 55 | k = 1 56 | step = bert_config.max_position_embeddings - 1 57 | while k < lf_config.max_position_embeddings - 1: 58 | if k + step >= lf_config.max_position_embeddings: 59 | lf_model.longformer.embeddings.position_embeddings.weight.data[k:] = bert_model.bert.embeddings.position_embeddings.weight[1:(bert_config.max_position_embeddings + 1 - k)] 60 | else: 61 | lf_model.longformer.embeddings.position_embeddings.weight.data[k:(k + step)] = bert_model.bert.embeddings.position_embeddings.weight[1:] 62 | k += step 63 | lf_model.longformer.embeddings.word_embeddings.load_state_dict(bert_model.bert.embeddings.word_embeddings.state_dict()) 64 | lf_model.longformer.embeddings.token_type_embeddings.load_state_dict(bert_model.bert.embeddings.token_type_embeddings.state_dict()) 65 | lf_model.longformer.embeddings.LayerNorm.load_state_dict(bert_model.bert.embeddings.LayerNorm.state_dict()) 66 | 67 | # copy transformer layers 68 | for i in range(len(bert_model.bert.encoder.layer)): 69 | # generic 70 | lf_model.longformer.encoder.layer[i].intermediate.dense = copy.deepcopy( 71 | bert_model.bert.encoder.layer[i].intermediate.dense) 72 | lf_model.longformer.encoder.layer[i].output.dense = copy.deepcopy( 73 | bert_model.bert.encoder.layer[i].output.dense) 74 | lf_model.longformer.encoder.layer[i].output.LayerNorm = copy.deepcopy( 75 | bert_model.bert.encoder.layer[i].output.LayerNorm) 76 | # attention output 77 | lf_model.longformer.encoder.layer[i].attention.output.dense = copy.deepcopy( 78 | bert_model.bert.encoder.layer[i].attention.output.dense) 79 | lf_model.longformer.encoder.layer[i].attention.output.LayerNorm = copy.deepcopy( 80 | bert_model.bert.encoder.layer[i].attention.output.LayerNorm) 81 | # local q,k,v 82 | lf_model.longformer.encoder.layer[i].attention.self.query = copy.deepcopy( 83 | bert_model.bert.encoder.layer[i].attention.self.query) 84 | lf_model.longformer.encoder.layer[i].attention.self.key = copy.deepcopy( 85 | bert_model.bert.encoder.layer[i].attention.self.key) 86 | lf_model.longformer.encoder.layer[i].attention.self.value = copy.deepcopy( 87 | bert_model.bert.encoder.layer[i].attention.self.value) 88 | # global q,k,v 89 | lf_model.longformer.encoder.layer[i].attention.self.query_global = copy.deepcopy( 90 | bert_model.bert.encoder.layer[i].attention.self.query) 91 | lf_model.longformer.encoder.layer[i].attention.self.key_global = copy.deepcopy( 92 | bert_model.bert.encoder.layer[i].attention.self.key) 93 | lf_model.longformer.encoder.layer[i].attention.self.value_global = copy.deepcopy( 94 | 
bert_model.bert.encoder.layer[i].attention.self.value) 95 | 96 | # copy lm_head 97 | lf_model.lm_head.dense.load_state_dict(bert_model.cls.predictions.transform.dense.state_dict()) 98 | lf_model.lm_head.layer_norm.load_state_dict(bert_model.cls.predictions.transform.LayerNorm.state_dict()) 99 | lf_model.lm_head.decoder.load_state_dict(bert_model.cls.predictions.decoder.state_dict()) 100 | lf_model.lm_head.bias = copy.deepcopy(bert_model.cls.predictions.bias) 101 | 102 | # check position ids 103 | # batch = tokenizer(['this is a dog', 'this is a cat'], return_tensors='pt') 104 | # lf_model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']) 105 | 106 | # save model 107 | lf_model.save_pretrained(f'{DATA_DIR}/PLMs/longformer') 108 | 109 | # save tokenizer 110 | tokenizer.save_pretrained(f'{DATA_DIR}/PLMs/longformer') 111 | 112 | # re-load model 113 | lf_model = AutoModelForMaskedLM.from_pretrained(f'{DATA_DIR}/PLMs/longformer') 114 | lf_tokenizer = AutoTokenizer.from_pretrained(f'{DATA_DIR}/PLMs/longformer') 115 | # batch = tokenizer(['this is a dog', 'this is a cat'], return_tensors='pt') 116 | # lf_model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']) 117 | print(f'Longformer model is ready to run!') 118 | 119 | 120 | if __name__ == '__main__': 121 | convert_bert_to_htf() 122 | -------------------------------------------------------------------------------- /models/hat/convert_bert_to_htf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import copy 5 | import warnings 6 | from data import DATA_DIR 7 | from transformers import AutoModelForMaskedLM, AutoTokenizer 8 | from models.hat import HATForMaskedLM, HATConfig, HATTokenizer 9 | warnings.filterwarnings("ignore") 10 | 11 | LAYOUTS = { 12 | 's1': 'SD|SD|SD|SD|SD|SD', 13 | 's2': 'S|SD|D|S|SD|D|S|SD|D', 14 | 'p1': 'S|SD|S|SD|S|SD|S|SD', 15 | 'p2': 'S|S|SD|S|S|SD|S|S|SD', 16 | 'e1': 'SD|SD|SD|S|S|S|S|S|S', 17 | 'e2': 'S|SD|D|S|SD|D|S|S|S|S', 18 | 'l1': 'S|S|S|S|S|S|SD|SD|SD', 19 | 'l2': 'S|S|S|S|S|SD|D|S|SD|D', 20 | 'b1': 'S|S|SD|D|S|SD|D|S|S|S', 21 | 'b2': 'S|S|SD|SD|SD|S|S|S|S', 22 | 'f12': 'S|S|S|S|S|S|S|S|S|S|S|S', 23 | 'f8': 'S|S|S|S|S|S|S|S', 24 | 'f6': 'S|S|S|S|S|S', 25 | } 26 | 27 | 28 | def convert_bert_to_htf(): 29 | ''' set default hyperparams in default_hyperparams.py ''' 30 | parser = argparse.ArgumentParser() 31 | 32 | # Required arguments 33 | parser.add_argument('--warmup_strategy', default='grouped', choices=['linear', 'grouped', 'random', 'embeds-only', 'none'], 34 | help='linear: S|D encoders are warm-started independently (one-by-one)' 35 | 'grouped: pairs of S|D are warm-started with weights from the very same level' 36 | 'random: D encoders are not warm-started' 37 | 'embeds-only: No warm-starting, except embeddings' 38 | 'none: No warm-starting') 39 | parser.add_argument('--layout', default='s1', choices=['s1', 's2', 'p1', 'p2', 'e1', 'e2', 40 | 'l1', 'l2', 'b1', 'b2', 'f12', 'f8', 'f6'], 41 | help='S|D encoders layout') 42 | parser.add_argument('--max_sentences', default=8) 43 | config = parser.parse_args() 44 | MAX_SENTENCE_LENGTH = 128 45 | MAX_SENTENCES = int(config.max_sentences) 46 | ENCODER_LAYOUT = {} 47 | for idx, block_pattern in enumerate(LAYOUTS[config.layout].split('|')): 48 | ENCODER_LAYOUT[str(idx)] = {"sentence_encoder": True if 'S' in block_pattern else False, 49 | "document_encoder": True if 'D' in block_pattern else False} 50 | 51 | NUM_HIDDEN_LAYERS = len(ENCODER_LAYOUT.keys()) 52 
| BERT_LAYERS = NUM_HIDDEN_LAYERS if config.warmup_strategy != 'linear' else NUM_HIDDEN_LAYERS*2 53 | BERT_LAYERS = BERT_LAYERS + 1 if BERT_LAYERS % 2 else BERT_LAYERS 54 | BERT_CHECKPOINT = f'google/bert_uncased_L-{str(BERT_LAYERS)}_H-256_A-4' 55 | 56 | # load pre-trained bert model and tokenizer 57 | bert_model = AutoModelForMaskedLM.from_pretrained(BERT_CHECKPOINT) 58 | tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT, model_max_length=MAX_SENTENCE_LENGTH * MAX_SENTENCES) 59 | 60 | # load dummy config and change specifications 61 | bert_config = bert_model.config 62 | htf_config = HATConfig.from_pretrained(f'{DATA_DIR}/hat') 63 | # Text length parameters 64 | htf_config.max_sentence_length = MAX_SENTENCE_LENGTH 65 | htf_config.max_sentences = MAX_SENTENCES 66 | htf_config.max_position_embeddings = MAX_SENTENCE_LENGTH 67 | htf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES) 68 | htf_config.num_hidden_layers = NUM_HIDDEN_LAYERS 69 | # Transformer parameters 70 | htf_config.hidden_size = bert_config.hidden_size 71 | htf_config.intermediate_size = bert_config.intermediate_size 72 | htf_config.num_attention_heads = bert_config.num_attention_heads 73 | htf_config.hidden_act = bert_config.hidden_act 74 | htf_config.encoder_layout = ENCODER_LAYOUT 75 | # Vocabulary parameters 76 | htf_config.vocab_size = bert_config.vocab_size 77 | htf_config.pad_token_id = bert_config.pad_token_id 78 | htf_config.bos_token_id = bert_config.bos_token_id 79 | htf_config.eos_token_id = bert_config.eos_token_id 80 | htf_config.type_vocab_size = bert_config.type_vocab_size 81 | 82 | # load dummy hi-transformer model 83 | htf_model = HATForMaskedLM.from_config(htf_config) 84 | 85 | if config.warmup_strategy != 'none': 86 | # copy embeddings 87 | htf_model.hi_transformer.embeddings.position_embeddings.weight.data[0] = torch.zeros((bert_config.hidden_size,)) 88 | htf_model.hi_transformer.embeddings.position_embeddings.weight.data[1:] = bert_model.bert.embeddings.position_embeddings.weight[1:MAX_SENTENCE_LENGTH+htf_config.pad_token_id+1] 89 | htf_model.hi_transformer.embeddings.word_embeddings.load_state_dict(bert_model.bert.embeddings.word_embeddings.state_dict()) 90 | htf_model.hi_transformer.embeddings.token_type_embeddings.load_state_dict(bert_model.bert.embeddings.token_type_embeddings.state_dict()) 91 | htf_model.hi_transformer.embeddings.LayerNorm.load_state_dict(bert_model.bert.embeddings.LayerNorm.state_dict()) 92 | 93 | if config.warmup_strategy != 'embeds-only': 94 | # copy transformer layers 95 | if config.warmup_strategy != 'linear': 96 | for idx in range(NUM_HIDDEN_LAYERS): 97 | if htf_model.config.encoder_layout[str(idx)]['sentence_encoder']: 98 | htf_model.hi_transformer.encoder.layer[idx].sentence_encoder.load_state_dict(bert_model.bert.encoder.layer[idx].state_dict()) 99 | if htf_model.config.encoder_layout[str(idx)]['document_encoder']: 100 | if config.warmup_strategy == 'grouped': 101 | htf_model.hi_transformer.encoder.layer[idx].document_encoder.load_state_dict(bert_model.bert.encoder.layer[idx].state_dict()) 102 | htf_model.hi_transformer.encoder.layer[idx].position_embeddings.weight.data = bert_model.bert.embeddings.position_embeddings.weight[1:MAX_SENTENCES+2] 103 | else: 104 | for idx, l_idx in enumerate(range(0, NUM_HIDDEN_LAYERS*2, 2)): 105 | if htf_model.config.encoder_layout[str(idx)]['sentence_encoder']: 106 | htf_model.hi_transformer.encoder.layer[idx].sentence_encoder.load_state_dict(bert_model.bert.encoder.layer[l_idx].state_dict()) 107 | if 
htf_model.config.encoder_layout[str(idx)]['document_encoder']: 108 | htf_model.hi_transformer.encoder.layer[idx].document_encoder.load_state_dict(bert_model.bert.encoder.layer[l_idx+1].state_dict()) 109 | htf_model.hi_transformer.encoder.layer[idx].position_embeddings.weight.data = bert_model.bert.embeddings.position_embeddings.weight[1:MAX_SENTENCES+2] 110 | 111 | # copy lm_head 112 | htf_model.lm_head.dense.load_state_dict(bert_model.cls.predictions.transform.dense.state_dict()) 113 | htf_model.lm_head.layer_norm.load_state_dict(bert_model.cls.predictions.transform.LayerNorm.state_dict()) 114 | htf_model.lm_head.decoder.load_state_dict(bert_model.cls.predictions.decoder.state_dict()) 115 | htf_model.lm_head.bias = copy.deepcopy(bert_model.cls.predictions.bias) 116 | 117 | # save model 118 | htf_model.save_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-{config.warmup_strategy}') 119 | 120 | # save tokenizer 121 | tokenizer.save_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-{config.warmup_strategy}') 122 | 123 | # re-load model 124 | htf_model = HATForMaskedLM.from_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-{config.warmup_strategy}') 125 | htf_tokenizer = HATTokenizer.from_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-{config.warmup_strategy}') 126 | print(f'HAT model with layout {config.layout} and warm-up strategy {config.warmup_strategy} is ready to run!') 127 | 128 | 129 | if __name__ == '__main__': 130 | convert_bert_to_htf() 131 | -------------------------------------------------------------------------------- /data/ecthr-arguments-dataset/ecthr-arguments-dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ECtHRArguments""" 16 | 17 | import json 18 | import os 19 | import textwrap 20 | 21 | import datasets 22 | 23 | MAIN_CITATION = """\ 24 | @article{Habernal.et.al.2022.arg, 25 | author = {Habernal, Ivan and Faber, Daniel and Recchia, Nicola and 26 | Bretthauer, Sebastian and Gurevych, Iryna and 27 | Spiecker genannt Döhmann, Indra and Burchard, Christoph}, 28 | title = {{Mining Legal Arguments in Court Decisions}}, 29 | journal = {arXiv preprint}, 30 | year = {2022}, 31 | doi = {10.48550/arXiv.2208.06178}, 32 | }""" 33 | 34 | _DESCRIPTION = """\ 35 | The dataset contains approx. 300 cases from the European Court of Human Rights (ECtHR). For each case, the dataset provides 36 | a list of argumentative paragraphs from the case analysis. Spans in each paragraph has been labeled with one or more out 37 | of 13 argument types. We re-formulate this task, as a sequential paragraph classification task, where each paragraph is 38 | labelled with one or more labels. The input of the model is the list of paragraphs of a case, and the output is the set 39 | of relevant argument types per paragraph. 
40 | """ 41 | 42 | ECTHR_ARG_TYPES = ['Application', 'Precedent', 'Proportionality', 'Decision', 43 | 'Legal Basis', 'Legitimate Purpose', 'Non Contestation'] 44 | 45 | 46 | class ECtHRArgumentsConfig(datasets.BuilderConfig): 47 | """BuilderConfig for ECtHRArguments.""" 48 | 49 | def __init__( 50 | self, 51 | text_column, 52 | label_column, 53 | url, 54 | data_url, 55 | data_file, 56 | citation, 57 | label_classes=None, 58 | multi_label=None, 59 | dev_column="dev", 60 | **kwargs, 61 | ): 62 | """BuilderConfig for ECtHRArguments. 63 | 64 | Args: 65 | text_column: ``string`, name of the column in the jsonl file corresponding 66 | to the text 67 | label_column: `string`, name of the column in the jsonl file corresponding 68 | to the label 69 | url: `string`, url for the original project 70 | data_url: `string`, url to download the zip file from 71 | data_file: `string`, filename for data set 72 | citation: `string`, citation for the data set 73 | url: `string`, url for information about the data set 74 | label_classes: `list[string]`, the list of classes if the label is 75 | categorical. If not provided, then the label will be of type 76 | `datasets.Value('float32')`. 77 | multi_label: `boolean`, True if the task is multi-label 78 | dev_column: `string`, name for the development subset 79 | **kwargs: keyword arguments forwarded to super. 80 | """ 81 | super(ECtHRArgumentsConfig, self).__init__(version=datasets.Version("1.0.0", ""), **kwargs) 82 | self.text_column = text_column 83 | self.label_column = label_column 84 | self.label_classes = label_classes 85 | self.multi_label = multi_label 86 | self.dev_column = dev_column 87 | self.url = url 88 | self.data_url = data_url 89 | self.data_file = data_file 90 | self.citation = citation 91 | 92 | 93 | class LexGLUE(datasets.GeneratorBasedBuilder): 94 | """LexGLUE: A Benchmark Dataset for Legal Language Understanding in English. Version 1.0""" 95 | 96 | BUILDER_CONFIGS = [ 97 | ECtHRArgumentsConfig( 98 | name="ecthr-arguments-dataset", 99 | description=textwrap.dedent( 100 | """\ 101 | The UKLEX dataset consists of UK laws that have been labeled with concepts. 102 | Given a document, the task is to predict its labels (concepts). 
103 | """ 104 | ), 105 | text_column="paragraphs", 106 | label_column="labels", 107 | label_classes=ECTHR_ARG_TYPES, 108 | multi_label=True, 109 | dev_column="dev", 110 | data_url=f"ecthr_arguments.tar.gz", 111 | data_file="ecthr_arguments.jsonl", 112 | url="https://github.com/trusthlt/mining-legal-arguments", 113 | citation=textwrap.dedent( 114 | """@article{Habernal.et.al.2022.arg, 115 | author = {Habernal, Ivan and Faber, Daniel and Recchia, Nicola and 116 | Bretthauer, Sebastian and Gurevych, Iryna and 117 | Spiecker genannt Döhmann, Indra and Burchard, Christoph}, 118 | title = {{Mining Legal Arguments in Court Decisions}}, 119 | journal = {arXiv preprint}, 120 | year = {2022}, 121 | doi = {10.48550/arXiv.2208.06178}, 122 | }""" 123 | ), 124 | ) 125 | ] 126 | 127 | def _info(self): 128 | features = {"text": datasets.features.Sequence(datasets.Value("string")), 129 | "labels": datasets.features.Sequence(datasets.features.Sequence(datasets.ClassLabel(names=self.config.label_classes)))} 130 | return datasets.DatasetInfo( 131 | description=self.config.description, 132 | features=datasets.Features(features), 133 | homepage=self.config.url, 134 | citation=self.config.citation + "\n" + MAIN_CITATION, 135 | ) 136 | 137 | def _split_generators(self, dl_manager): 138 | data_dir = dl_manager.download_and_extract(self.config.data_url) 139 | return [ 140 | datasets.SplitGenerator( 141 | name=datasets.Split.TRAIN, 142 | # These kwargs will be passed to _generate_examples 143 | gen_kwargs={"filepath": os.path.join(data_dir, self.config.data_file), "split": "train"}, 144 | ), 145 | datasets.SplitGenerator( 146 | name=datasets.Split.TEST, 147 | # These kwargs will be passed to _generate_examples 148 | gen_kwargs={"filepath": os.path.join(data_dir, self.config.data_file), "split": "test"}, 149 | ), 150 | datasets.SplitGenerator( 151 | name=datasets.Split.VALIDATION, 152 | # These kwargs will be passed to _generate_examples 153 | gen_kwargs={ 154 | "filepath": os.path.join(data_dir, self.config.data_file), 155 | "split": self.config.dev_column, 156 | }, 157 | ), 158 | ] 159 | 160 | def _generate_examples(self, filepath, split): 161 | """This function returns the examples in the raw (text) form.""" 162 | with open(filepath, "r", encoding="utf-8") as f: 163 | for id_, row in enumerate(f): 164 | data = json.loads(row) 165 | labels = [list(set(par_labels)).remove('O') if 'O' in par_labels else list(set(par_labels)) for 166 | par_labels in data[self.config.label_column]] 167 | if data["data_type"] == split: 168 | yield id_, { 169 | "text": data[self.config.text_column], 170 | "labels": labels, 171 | } -------------------------------------------------------------------------------- /models/big_bird/modeling_big_bird.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """PyTorch BigBird model.""" 16 | 17 | from dataclasses import dataclass 18 | from typing import Optional, Tuple, Union 19 | 20 | import torch 21 | import torch.utils.checkpoint 22 | from torch import nn 23 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, CosineEmbeddingLoss 24 | from transformers.models.big_bird.modeling_big_bird import BigBirdPreTrainedModel, BigBirdModel 25 | 26 | from transformers.modeling_outputs import ( 27 | ModelOutput 28 | ) 29 | 30 | @dataclass 31 | class SentenceClassifierOutput(ModelOutput): 32 | """ 33 | Base class for outputs of sentence classification models. 34 | 35 | Args: 36 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : 37 | Classification loss. 38 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): 39 | Classification scores (before SoftMax). 40 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 41 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 42 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 43 | 44 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 45 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 46 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 47 | sequence_length)`. 48 | sentence_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 49 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 50 | sequence_length)`. 51 | 52 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 53 | heads. 
54 | """ 55 | 56 | loss: Optional[Tuple[torch.FloatTensor]] = None 57 | logits: torch.FloatTensor = None 58 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 59 | attentions: Optional[Tuple[torch.FloatTensor]] = None 60 | sentence_attentions: Optional[Tuple[torch.FloatTensor]] = None 61 | 62 | 63 | class BigBirdSentencizer(nn.Module): 64 | def __init__(self, config): 65 | super().__init__() 66 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 67 | self.activation = nn.Tanh() 68 | self.max_sentence_length = config.max_sentence_length 69 | 70 | def forward(self, hidden_states): 71 | sentence_repr_hidden_states = hidden_states[:, ::self.max_sentence_length] 72 | sentence_outputs = self.dense(sentence_repr_hidden_states) 73 | sentence_outputs = self.activation(sentence_outputs) 74 | return sentence_outputs 75 | 76 | 77 | class BigBirdModelForSentenceClassification(BigBirdPreTrainedModel): 78 | _keys_to_ignore_on_load_missing = [r"position_ids"] 79 | 80 | def __init__(self, config): 81 | super().__init__(config) 82 | self.num_labels = config.num_labels 83 | self.config = config 84 | 85 | self.bert = BigBirdModel(config, add_pooling_layer=False) 86 | self.sentencizer = BigBirdSentencizer(config) 87 | classifier_dropout = ( 88 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 89 | ) 90 | self.dropout = nn.Dropout(classifier_dropout) 91 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 92 | 93 | # Initialize weights and apply final processing 94 | self.post_init() 95 | 96 | @classmethod 97 | def from_config(cls, config): 98 | return cls._from_config(config) 99 | 100 | def forward( 101 | self, 102 | input_ids=None, 103 | attention_mask=None, 104 | token_type_ids=None, 105 | position_ids=None, 106 | inputs_embeds=None, 107 | labels=None, 108 | output_attentions=None, 109 | output_hidden_states=None, 110 | return_dict=None, 111 | ): 112 | r""" 113 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 114 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 115 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 116 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
117 | """ 118 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 119 | 120 | outputs = self.bert( 121 | input_ids, 122 | attention_mask=attention_mask, 123 | token_type_ids=token_type_ids, 124 | position_ids=position_ids, 125 | inputs_embeds=inputs_embeds, 126 | output_attentions=output_attentions, 127 | output_hidden_states=output_hidden_states, 128 | return_dict=return_dict, 129 | ) 130 | sequence_output = outputs[0] 131 | sentence_outputs = self.sentencizer(sequence_output) 132 | sentence_outputs = self.dropout(sentence_outputs) 133 | logits = self.classifier(sentence_outputs) 134 | 135 | loss = None 136 | if labels is not None: 137 | if self.config.problem_type is None: 138 | if self.num_labels == 1: 139 | self.config.problem_type = "regression" 140 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 141 | self.config.problem_type = "single_label_classification" 142 | else: 143 | self.config.problem_type = "multi_label_classification" 144 | 145 | if self.config.problem_type == "regression": 146 | loss_fct = MSELoss() 147 | if self.num_labels == 1: 148 | loss = loss_fct(logits.view(-1, 1).squeeze(), labels.view(-1).squeeze()) 149 | else: 150 | loss = loss_fct(logits.view(-1, 1), labels.view(-1)) 151 | elif self.config.problem_type == "single_label_classification": 152 | loss_fct = CrossEntropyLoss() 153 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 154 | elif self.config.problem_type == "multi_label_classification": 155 | loss_fct = BCEWithLogitsLoss() 156 | mask = labels[:, :, 0] != -1 157 | loss = loss_fct(logits[mask], labels[mask]) 158 | 159 | if not return_dict: 160 | output = (logits,) + outputs[2:] 161 | return ((loss,) + output) if loss is not None else output 162 | 163 | return SentenceClassifierOutput( 164 | loss=loss, 165 | logits=logits, 166 | hidden_states=outputs.hidden_states, 167 | attentions=outputs.attentions, 168 | sentence_attentions=None 169 | ) 170 | 171 | -------------------------------------------------------------------------------- /data/contractnli-dataset/contractnli-dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ContractNLI""" 16 | 17 | import json 18 | import os 19 | import textwrap 20 | 21 | import datasets 22 | 23 | MAIN_CITATION = """\ 24 | @inproceedings{koreeda-manning-2021-contractnli-dataset, 25 | title = "{C}ontract{NLI}: A Dataset for Document-level Natural Language Inference for Contracts", 26 | author = "Koreeda, Yuta and 27 | Manning, Christopher", 28 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2021", 29 | month = nov, 30 | year = "2021", 31 | address = "Punta Cana, Dominican Republic", 32 | publisher = "Association for Computational Linguistics", 33 | url = "https://aclanthology.org/2021.findings-emnlp.164", 34 | doi = "10.18653/v1/2021.findings-emnlp.164", 35 | pages = "1907--1919", 36 | }""" 37 | 38 | _DESCRIPTION = """\ 39 | The ContractNLI dataset consists of Non-Disclosure Agreements (NDAs). All NDAs have been labeled based 40 | on several hypothesis templates as entailment, neutral or contradiction. In this version of the task 41 | (Task B), the input consists of the full document. 42 | """ 43 | 44 | LABELS = ["contradiction", "entailment", "neutral"] 45 | 46 | 47 | class ContractNLIConfig(datasets.BuilderConfig): 48 | """BuilderConfig for ContractNLI.""" 49 | 50 | def __init__( 51 | self, 52 | text_column, 53 | label_column, 54 | url, 55 | data_url, 56 | data_file, 57 | citation, 58 | label_classes=None, 59 | multi_label=None, 60 | dev_column="dev", 61 | **kwargs, 62 | ): 63 | """BuilderConfig for ContractNLI. 64 | 65 | Args: 66 | text_column: ``string`, name of the column in the jsonl file corresponding 67 | to the text 68 | label_column: `string`, name of the column in the jsonl file corresponding 69 | to the label 70 | url: `string`, url for the original project 71 | data_url: `string`, url to download the zip file from 72 | data_file: `string`, filename for data set 73 | citation: `string`, citation for the data set 74 | url: `string`, url for information about the data set 75 | label_classes: `list[string]`, the list of classes if the label is 76 | categorical. If not provided, then the label will be of type 77 | `datasets.Value('float32')`. 78 | multi_label: `boolean`, True if the task is multi-label 79 | dev_column: `string`, name for the development subset 80 | **kwargs: keyword arguments forwarded to super. 81 | """ 82 | super(ContractNLIConfig, self).__init__(version=datasets.Version("1.0.0", ""), **kwargs) 83 | self.text_column = text_column 84 | self.label_column = label_column 85 | self.label_classes = label_classes 86 | self.multi_label = multi_label 87 | self.dev_column = dev_column 88 | self.url = url 89 | self.data_url = data_url 90 | self.data_file = data_file 91 | self.citation = citation 92 | 93 | 94 | class LexGLUE(datasets.GeneratorBasedBuilder): 95 | """LexGLUE: A Benchmark Dataset for Legal Language Understanding in English. Version 1.0""" 96 | 97 | BUILDER_CONFIGS = [ 98 | ContractNLIConfig( 99 | name="contractnli", 100 | description=textwrap.dedent( 101 | """\ 102 | The ContractNLI dataset consists of Non-Disclosure Agreements (NDAs). All NDAs have been labeled based 103 | on several hypothesis templates as entailment, neutral or contradiction. In this version of the task 104 | (Task B), the input consists of the full document. 
105 | """ 106 | ), 107 | text_column="premise", 108 | label_column="label", 109 | label_classes=LABELS, 110 | multi_label=False, 111 | dev_column="dev", 112 | data_url="contract_nli.zip", 113 | data_file="contract_nli_long.jsonl", 114 | url="https://stanfordnlp.github.io/contract-nli/", 115 | citation=textwrap.dedent( 116 | """\ 117 | @inproceedings{koreeda-manning-2021-contractnli-dataset, 118 | title = "{C}ontract{NLI}: A Dataset for Document-level Natural Language Inference for Contracts", 119 | author = "Koreeda, Yuta and 120 | Manning, Christopher", 121 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2021", 122 | month = nov, 123 | year = "2021", 124 | address = "Punta Cana, Dominican Republic", 125 | publisher = "Association for Computational Linguistics", 126 | url = "https://aclanthology.org/2021.findings-emnlp.164", 127 | doi = "10.18653/v1/2021.findings-emnlp.164", 128 | pages = "1907--1919", 129 | } 130 | }""" 131 | ), 132 | ) 133 | ] 134 | 135 | def _info(self): 136 | features = {"premise": datasets.Value("string"), "hypothesis": datasets.Value("string"), 137 | 'label': datasets.ClassLabel(names=LABELS)} 138 | return datasets.DatasetInfo( 139 | description=self.config.description, 140 | features=datasets.Features(features), 141 | homepage=self.config.url, 142 | citation=self.config.citation + "\n" + MAIN_CITATION, 143 | ) 144 | 145 | def _split_generators(self, dl_manager): 146 | data_dir = dl_manager.download_and_extract(self.config.data_url) 147 | return [ 148 | datasets.SplitGenerator( 149 | name=datasets.Split.TRAIN, 150 | # These kwargs will be passed to _generate_examples 151 | gen_kwargs={"filepath": os.path.join(data_dir, self.config.data_file), "split": "train"}, 152 | ), 153 | datasets.SplitGenerator( 154 | name=datasets.Split.TEST, 155 | # These kwargs will be passed to _generate_examples 156 | gen_kwargs={"filepath": os.path.join(data_dir, self.config.data_file), "split": "test"}, 157 | ), 158 | datasets.SplitGenerator( 159 | name=datasets.Split.VALIDATION, 160 | # These kwargs will be passed to _generate_examples 161 | gen_kwargs={ 162 | "filepath": os.path.join(data_dir, self.config.data_file), 163 | "split": self.config.dev_column, 164 | }, 165 | ), 166 | ] 167 | 168 | def _generate_examples(self, filepath, split): 169 | """This function returns the examples in the raw (text) form.""" 170 | with open(filepath, "r", encoding="utf-8") as f: 171 | sid = -1 172 | for id_, row in enumerate(f): 173 | data = json.loads(row) 174 | if data["subset"] == split: 175 | for sample in data['hypothesises/labels']: 176 | sid += 1 177 | yield sid, { 178 | "premise": data["premise"], 179 | "hypothesis": sample['hypothesis'], 180 | "label": sample['label'], 181 | } -------------------------------------------------------------------------------- /models/hat/configuration_hat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | """ HAT configuration""" 14 | from collections import OrderedDict 15 | from typing import Mapping 16 | 17 | from transformers.onnx import OnnxConfig 18 | from transformers.utils import logging 19 | from transformers import PretrainedConfig 20 | 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | HAT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 25 | "kiddothe2b/hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096/resolve/main/config.json", 26 | "kiddothe2b/adhoc-hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-base-4096/resolve/main/config.json", 27 | } 28 | 29 | 30 | class HATConfig(PretrainedConfig): 31 | r""" 32 | This is the configuration class to store the configuration of a :class:`~transformers.HAT`. 33 | It is used to instantiate a HAT model according to the specified arguments, 34 | defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration 35 | to that of the HAT `kiddothe2b/hierarchical-transformer-base-4096 36 | `__ architecture. 37 | 38 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model 39 | outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. 40 | 41 | 42 | Args: 43 | vocab_size (:obj:`int`, `optional`, defaults to 30522): 44 | Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the 45 | :obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or 46 | :class:`~transformers.TFBertModel`. 47 | max_sentences (:obj:`int`, `optional`, defaults to 64): 48 | The maximum number of sentences that this model might ever be used with. 49 | max_sentence_size (:obj:`int`, `optional`, defaults to 128): 50 | The maximum sentence length that this model might ever be used with. 51 | model_max_length (:obj:`int`, `optional`, defaults to 8192): 52 | The maximum sequence length (max_sentences * max_sentence_size) that this model might ever be used with 53 | encoder_layout (:obj:`Dict`): 54 | The sentence/document encoder layout. 55 | hidden_size (:obj:`int`, `optional`, defaults to 768): 56 | Dimensionality of the encoder layers and the pooler layer. 57 | num_hidden_layers (:obj:`int`, `optional`, defaults to 12): 58 | Number of hidden layers in the Transformer encoder. 59 | num_attention_heads (:obj:`int`, `optional`, defaults to 12): 60 | Number of attention heads for each attention layer in the Transformer encoder. 61 | intermediate_size (:obj:`int`, `optional`, defaults to 3072): 62 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 63 | hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): 64 | The non-linear activation function (function or string) in the encoder and pooler. If string, 65 | :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. 66 | hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): 67 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 68 | attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): 69 | The dropout ratio for the attention probabilities. 70 | max_position_embeddings (:obj:`int`, `optional`, defaults to 512): 71 | The maximum sequence length that this model might ever be used with. 
Typically set this to something large 72 | just in case (e.g., 512 or 1024 or 2048). 73 | type_vocab_size (:obj:`int`, `optional`, defaults to 2): 74 | The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or 75 | :class:`~transformers.TFBertModel`. 76 | initializer_range (:obj:`float`, `optional`, defaults to 0.02): 77 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 78 | layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): 79 | The epsilon used by the layer normalization layers. 80 | position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): 81 | Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, 82 | :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on 83 | :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.) 84 | `__. For more information on :obj:`"relative_key_query"`, please refer to 85 | `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) 86 | `__. 87 | use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): 88 | Whether or not the model should return the last key/values attentions (not used by all models). Only 89 | relevant if ``config.is_decoder=True``. 90 | classifier_dropout (:obj:`float`, `optional`): 91 | The dropout ratio for the classification head. 92 | """ 93 | model_type = "hierarchical-transformer" 94 | 95 | def __init__( 96 | self, 97 | vocab_size=30522, 98 | hidden_size=768, 99 | max_sentences=64, 100 | max_sentence_size=128, 101 | model_max_length=8192, 102 | num_hidden_layers=12, 103 | num_attention_heads=12, 104 | intermediate_size=3072, 105 | hidden_act="gelu", 106 | hidden_dropout_prob=0.1, 107 | attention_probs_dropout_prob=0.1, 108 | max_position_embeddings=512, 109 | type_vocab_size=2, 110 | initializer_range=0.02, 111 | layer_norm_eps=1e-12, 112 | pad_token_id=0, 113 | position_embedding_type="absolute", 114 | encoder_layout=None, 115 | use_cache=True, 116 | classifier_dropout=None, 117 | **kwargs 118 | ): 119 | super().__init__(pad_token_id=pad_token_id, **kwargs) 120 | 121 | self.vocab_size = vocab_size 122 | self.hidden_size = hidden_size 123 | self.max_sentences = max_sentences 124 | self.max_sentence_size = max_sentence_size 125 | self.model_max_length = model_max_length 126 | self.encoder_layout = encoder_layout 127 | self.num_hidden_layers = num_hidden_layers 128 | self.num_attention_heads = num_attention_heads 129 | self.hidden_act = hidden_act 130 | self.intermediate_size = intermediate_size 131 | self.hidden_dropout_prob = hidden_dropout_prob 132 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 133 | self.max_position_embeddings = max_position_embeddings 134 | self.type_vocab_size = type_vocab_size 135 | self.initializer_range = initializer_range 136 | self.layer_norm_eps = layer_norm_eps 137 | self.position_embedding_type = position_embedding_type 138 | self.use_cache = use_cache 139 | self.classifier_dropout = classifier_dropout 140 | 141 | 142 | class HATOnnxConfig(OnnxConfig): 143 | @property 144 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 145 | return OrderedDict( 146 | [ 147 | ("input_ids", {0: "batch", 1: "sequence"}), 148 | ("attention_mask", {0: "batch", 1: "sequence"}), 149 | ] 150 | ) -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
1 | # Hierarchical Attention Transformers (HATs)
2 | 
3 | Implementation of Hierarchical Attention Transformers (HATs) presented in _"An Exploration of Hierarchical Attention Transformers for Efficient Long Document Classification"_ by Chalkidis et al. (2022). HATs use a hierarchical attention scheme, which is a combination of segment-wise and cross-segment attention operations. You can think of segments as paragraphs or sentences.
4 | 
5 | 
6 | 
7 | 
8 | ## Citation
9 | 
10 | If you use HAT in your research, please cite:
11 | 
12 | [An Exploration of Hierarchical Attention Transformers for Efficient Long Document Classification](https://arxiv.org/abs/2210.05529). Ilias Chalkidis, Xiang Dai, Manos Fergadiotis, Prodromos Malakasiotis, and Desmond Elliott. 2022. arXiv:2210.05529 (Preprint).
13 | 
14 | ```
15 | @misc{chalkidis-etal-2022-hat,
16 |   url = {https://arxiv.org/abs/2210.05529},
17 |   author = {Chalkidis, Ilias and Dai, Xiang and Fergadiotis, Manos and Malakasiotis, Prodromos and Elliott, Desmond},
18 |   title = {An Exploration of Hierarchical Attention Transformers for Efficient Long Document Classification},
19 |   publisher = {arXiv},
20 |   year = {2022},
21 | }
22 | ```
23 | 
24 | ## Implementation Details
25 | 
26 | The repository supports several variants of the HAT architecture. The implementation of HAT is built on top of HuggingFace Transformers in PyTorch and is available at `models/hat/modelling_hat.py`. The layout of stacked segment-wise (SW) and cross-segment (CS) encoders is specified in the configuration file with the `encoder_layout` parameter.
27 | 
28 | 
29 | 
30 | 
31 | * **_Ad-Hoc (AH):_** An ad-hoc (partially pre-trained) HAT comprises an initial stack of L-SWE shared segment-wise encoders taken from a pre-trained transformer-based model, followed by L-CSE ad-hoc cross-segment encoders. In this case, the model first encodes and contextualizes token representations per segment, and then builds higher-order segment-level representations, e.g., a model with 8 segment-wise and 4 cross-segment encoders has 12 effective transformer blocks (Layout: S/S/S/S/S/S/S/S/D/D/D/D).
32 | 
33 | ```json
34 | "encoder_layout": {
35 |   "0": {"sentence_encoder": true, "document_encoder": false},
36 |   "1": {"sentence_encoder": true, "document_encoder": false},
37 |   "2": {"sentence_encoder": true, "document_encoder": false},
38 |   "3": {"sentence_encoder": true, "document_encoder": false},
39 |   "4": {"sentence_encoder": true, "document_encoder": false},
40 |   "5": {"sentence_encoder": true, "document_encoder": false},
41 |   "6": {"sentence_encoder": true, "document_encoder": false},
42 |   "7": {"sentence_encoder": true, "document_encoder": false},
43 |   "8": {"sentence_encoder": false, "document_encoder": true},
44 |   "9": {"sentence_encoder": false, "document_encoder": true},
45 |   "10": {"sentence_encoder": false, "document_encoder": true},
46 |   "11": {"sentence_encoder": false, "document_encoder": true}
47 | }
48 | ```
49 | 
50 | * **_Interleaved (I):_** An interleaved HAT comprises a stack of L paired segment-wise and cross-segment encoders,
51 | e.g., a 6-layer model has 12 effective transformer blocks (Layout: SD/SD/SD/SD/SD/SD).
52 | 
53 | ```json
54 | "encoder_layout": {
55 |   "0": {"sentence_encoder": true, "document_encoder": true},
56 |   "1": {"sentence_encoder": true, "document_encoder": true},
57 |   "2": {"sentence_encoder": true, "document_encoder": true},
58 |   "3": {"sentence_encoder": true, "document_encoder": true},
59 |   "4": {"sentence_encoder": true, "document_encoder": true},
60 |   "5": {"sentence_encoder": true, "document_encoder": true}
61 | }
62 | ```
63 | * **_Early-Contextualization (EC):_** An early-contextualized HAT comprises an initial stack of L-P paired segment-wise and cross-segment encoders, followed by a stack of L-SWE segment-wise encoders. In this case, cross-segment attention (contextualization) is only performed at the initial layers of the model, e.g., a 6-layer model has 8 effective transformer blocks (Layout: SD/SD/S/S/S/S).
64 | 
65 | ```json
66 | "encoder_layout": {
67 |   "0": {"sentence_encoder": true, "document_encoder": true},
68 |   "1": {"sentence_encoder": true, "document_encoder": true},
69 |   "2": {"sentence_encoder": true, "document_encoder": false},
70 |   "3": {"sentence_encoder": true, "document_encoder": false},
71 |   "4": {"sentence_encoder": true, "document_encoder": false},
72 |   "5": {"sentence_encoder": true, "document_encoder": false}
73 | }
74 | ```
75 | 
76 | 
77 | * **_Late-Contextualization (LC):_** A late-contextualized HAT comprises an initial stack of L-SWE segment-wise encoders, followed by a stack of L-P paired segment-wise and cross-segment encoders. In this case, cross-segment attention (contextualization) is only performed in the latter layers of the model, e.g., a 6-layer model has 8 effective transformer blocks (Layout: S/S/S/S/SD/SD).
78 | 
79 | 
80 | ```json
81 | "encoder_layout": {
82 |   "0": {"sentence_encoder": true, "document_encoder": false},
83 |   "1": {"sentence_encoder": true, "document_encoder": false},
84 |   "2": {"sentence_encoder": true, "document_encoder": false},
85 |   "3": {"sentence_encoder": true, "document_encoder": false},
86 |   "4": {"sentence_encoder": true, "document_encoder": true},
87 |   "5": {"sentence_encoder": true, "document_encoder": true}
88 | }
89 | ```
90 | 
91 | In this study, we examine the efficacy of 8 alternative layouts:
92 | 
93 | ```json
94 | {
95 |   "I1": "SD|SD|SD|SD|SD|SD",
96 |   "I2": "S|SD|D|S|SD|D|S|SD|D",
97 |   "I3": "S|SD|S|SD|S|SD|S|SD",
98 |   "LC1": "S|S|S|S|S|S|SD|SD|SD",
99 |   "LC2": "S|S|S|S|S|SD|D|S|SD|D",
100 |   "EC1": "S|S|SD|D|S|SD|D|S|S|S",
101 |   "EC2": "S|S|SD|SD|SD|S|S|S|S",
102 |   "AH": "S|S|S|S|S|S|S|S|S|S|S|S"
103 | }
104 | 
105 | ```
106 | 
107 | ## Available Models on HuggingFace Hub
108 | 
109 | | Model Name | Layers | Hidden Units | Attention Heads | Vocab | Parameters |
110 | |------------|--------|--------------|-----------------|-------|------------|
111 | | [`kiddothe2b/hierarchical-transformer-base-4096`](https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096) | 16 | 768 | 12 | 50K | 152M |
112 | | [`kiddothe2b/longformer-base-4096`](https://huggingface.co/kiddothe2b/longformer-base-4096) | 12 | 768 | 12 | 50K | 152M |
113 | | [`kiddothe2b/adhoc-hierarchical-transformer-base-4096`](https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-base-4096) | 16 | 768 | 12 | 50K | 140M |
114 | | [`kiddothe2b/adhoc-hierarchical-transformer-I1-mini-1024`](https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-I1-mini-1024) | 12 | 256 | 4 | 32K
| 18M | 115 | | [`kiddothe2b/adhoc-hierarchical-transformer-I3-mini-1024`](https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-I2-mini-1024) | 12 | 256 | 4 | 32K | 18M | 116 | | [`kiddothe2b/adhoc-hierarchical-transformer-LC1-mini-1024`](https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-LC1-mini-1024) | 12 | 256 | 4 | 32K | 18M | 117 | | [`kiddothe2b/adhoc-hierarchical-transformer-EC2-mini-1024`](https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-EC2-mini-1024) | 12 | 256 | 4 | 32K | 18M | 118 | | [`kiddothe2b/longformer-mini-1024`](https://huggingface.co/kiddothe2b/longformer-mini-1024) | 6 | 256 | 4 | 32K | 14M | 119 | 120 | 121 | ## Requirements 122 | 123 | Make sure that all required packages are installed: 124 | 125 | ``` 126 | torch>=1.11.0 127 | transformers>=4.18.0 128 | datasets>=2.0.0 129 | tokenizers>=0.11.0 130 | scikit-learn>=1.0.0 131 | tqdm>=4.62.0 132 | nltk>=3.7.0 133 | ``` 134 | 135 | ## How to run experiments? 136 | 137 | You can use the shell scripts provided in the `running_scripts` directory to pre-train new models or fine-tune the ones released. 138 | 139 | Try on Google Colab: https://colab.research.google.com/drive/15feh49wqBshgkcvbO6QypvJoa3dG6P5S?usp=sharing 140 | 141 | ### I still have open questions... 142 | 143 | Please post your question on [Discussions](https://github.com/coastalcph/hi-transformers/discussions) section or communicate with the corresponding author via e-mail. 144 | -------------------------------------------------------------------------------- /models/util/efficiency_metrics_bert_models.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import warnings 3 | 4 | import torch 5 | import time 6 | from transformers import AutoConfig 7 | 8 | import numpy as np 9 | from data import DATA_DIR 10 | from models.hat import HATForMaskedLM, HATConfig, HATForSequenceClassification, \ 11 | HATForMultipleChoice 12 | from models.longformer import LongformerForMaskedLM, LongformerModelForSequenceClassification, LongformerForMultipleChoice 13 | 14 | warnings.filterwarnings("ignore") 15 | 16 | LAYOUTS = { 17 | # 'f6': 'S|S|S|S|S|S', 18 | # 'f8': 'S|S|S|S|S|S', 19 | 's1': 'SD|SD|SD|SD|SD|SD', 20 | 's2': 'S|SD|D|S|SD|D|S|SD|D', 21 | 'p1': 'S|SD|S|SD|S|SD|S|SD', 22 | 'p2': 'S|S|SD|S|S|SD|S|S|SD', 23 | 'e1': 'SD|SD|SD|S|S|S|S|S|S', 24 | 'e2': 'S|SD|D|S|SD|D|S|S|S|S', 25 | 'l1': 'S|S|S|S|S|S|SD|SD|SD', 26 | 'l2': 'S|S|S|S|S|SD|D|S|SD|D', 27 | # 'b1': 'S|S|SD|D|S|SD|D|S|S|S', 28 | # 'b2': 'S|S|SD|SD|SD|S|S|S|S', 29 | } 30 | 31 | 32 | TASK_MODEL = {'lm': {'longformer': LongformerForMaskedLM, 'hilm': HATForMaskedLM}, 33 | 'doc_cls': {'longformer': LongformerModelForSequenceClassification, 'hilm': HATForSequenceClassification}, 34 | 'mc_qa': {'longformer': LongformerForMultipleChoice, 'hilm': HATForMultipleChoice}, 35 | } 36 | 37 | 38 | def test_memory_usage(model, steps=40, batch_size=2, seq_length=1024, mode='test', task_type='lm'): 39 | model.to('cuda') 40 | if task_type != 'mc_qa': 41 | input_ids = torch.randint(1, 30000, (batch_size, seq_length), dtype=torch.long).to('cuda') 42 | input_ids[:, :: 128] = model.config.bos_token_id 43 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int).to('cuda') 44 | else: 45 | input_ids = torch.randint(1, 30000, (batch_size, 2, seq_length), dtype=torch.long).to('cuda') 46 | input_ids[:, :: 128] = model.config.bos_token_id 47 | attention_mask = torch.ones((batch_size, 2, seq_length), dtype=torch.int).to('cuda') 48 | if 
mode == 'train': 49 | if task_type == 'lm': 50 | labels = input_ids.clone() 51 | else: 52 | labels = torch.ones((batch_size, ), dtype=torch.int).long().to('cuda') 53 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) 54 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, steps) 55 | max_time = [] 56 | max_mem = [] 57 | for _ in range(steps): 58 | torch.cuda.reset_peak_memory_stats() 59 | start = time.time() 60 | if mode == 'train': 61 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 62 | loss = outputs.loss 63 | loss.backward(loss) 64 | optimizer.step() 65 | lr_scheduler.step() 66 | optimizer.zero_grad() 67 | else: 68 | with torch.no_grad(): 69 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) 70 | end = time.time() 71 | total_time = (end - start) 72 | max_time.append(total_time) 73 | max_mem.append(torch.cuda.max_memory_allocated() / 1e9) 74 | 75 | return np.mean(max_mem), np.mean(max_time) 76 | 77 | 78 | def efficiency_metrics(): 79 | MAX_SENTENCE_LENGTH = 128 80 | CONFIGS = [{'num_hidden_layers': 6, 81 | 'hidden_size': 256, 82 | 'intermediate_size': 1024, 83 | 'num_attention_heads': 4}, 84 | # {'num_hidden_layers': 6, 85 | # 'hidden_size': 768, 86 | # 'intermediate_size': 3072, 87 | # 'num_attention_heads': 12} 88 | ] 89 | 90 | for mode in ['train', 'test']: 91 | print(F'MODE: {mode.upper()}') 92 | for task in ['lm', 'doc_cls', 'mc_qa']: 93 | for CONFIG in CONFIGS: 94 | print('-' * 150) 95 | print(F'TASK: {task.upper()}\t' 96 | F'NUM LAYERS: {CONFIG["num_hidden_layers"]}\t' 97 | F'NUM HIDDEN: {CONFIG["hidden_size"]}\t' 98 | F'ATTENTION HEADS: {CONFIG["num_attention_heads"]}') 99 | print('-' * 150) 100 | 101 | for max_sentences in [8]: 102 | print('-' * 150) 103 | print(F'MAX SEQ LENGTH: {int(max_sentences * MAX_SENTENCE_LENGTH)}') 104 | print('-' * 150) 105 | 106 | lf_config = AutoConfig.from_pretrained('allenai/longformer-base-4096') 107 | lf_config.num_hidden_layers = CONFIG['num_hidden_layers'] 108 | # Transformer parameters 109 | lf_config.hidden_size = CONFIG['hidden_size'] 110 | lf_config.intermediate_size = CONFIG['intermediate_size'] 111 | lf_config.num_attention_heads = CONFIG['num_attention_heads'] 112 | # Vocabulary parameters 113 | lf_config.vocab_size = 32000 114 | lf_config.type_vocab_size = 2 115 | lf_config.model_max_length = int(MAX_SENTENCE_LENGTH * max_sentences) 116 | lf_config.max_position_embeddings = int(MAX_SENTENCE_LENGTH * max_sentences) + 2 117 | lf_config.attention_window = [128] * CONFIG['num_hidden_layers'] 118 | lf_config.max_sentence_length = MAX_SENTENCE_LENGTH 119 | lf_config.max_sentences = max_sentences 120 | lf_config.cls_token_id = 100 121 | # load dummy longformer model 122 | htf_model = TASK_MODEL[task]['longformer'].from_config(lf_config) 123 | model_total_params = sum(p.numel() for p in htf_model.longformer.parameters() if p.requires_grad) 124 | model_total_params = model_total_params / 1e6 125 | memory_use, time_use = test_memory_usage(htf_model, seq_length=lf_config.model_max_length, mode=mode, task_type=task) 126 | lf_mem_use = copy.deepcopy(memory_use) 127 | lf_time_use = copy.deepcopy(time_use) 128 | print(f'Longformer model has {model_total_params:.1f}M number of parameters ' 129 | f'and {memory_use:.2f}GB peak memory use and {time_use:.3f} batch/second!') 130 | print('-' * 150) 131 | 132 | for layout in LAYOUTS: 133 | ENCODER_LAYOUT = {} 134 | for idx, block_pattern in enumerate(LAYOUTS[layout].split('|')): 135 | ENCODER_LAYOUT[str(idx)] = 
{"sentence_encoder": True if 'S' in block_pattern else False, 136 | "document_encoder": True if 'D' in block_pattern else False} 137 | 138 | # load dummy config and change specifications 139 | htf_config = HATConfig.from_pretrained(f'{DATA_DIR}/hat') 140 | # Text length parameters 141 | htf_config.max_sentence_length = MAX_SENTENCE_LENGTH 142 | htf_config.max_sentences = max_sentences 143 | htf_config.max_position_embeddings = MAX_SENTENCE_LENGTH 144 | htf_config.model_max_length = int(MAX_SENTENCE_LENGTH * max_sentences) 145 | htf_config.num_hidden_layers = len(ENCODER_LAYOUT.keys()) 146 | # Transformer parameters 147 | htf_config.hidden_size = CONFIG['hidden_size'] 148 | htf_config.intermediate_size = CONFIG['intermediate_size'] 149 | htf_config.num_attention_heads = CONFIG['num_attention_heads'] 150 | htf_config.encoder_layout = ENCODER_LAYOUT 151 | # Vocabulary parameters 152 | htf_config.vocab_size = 32000 153 | htf_config.type_vocab_size = 2 154 | 155 | # load dummy hat model 156 | htf_model = TASK_MODEL[task]['hilm'].from_config(htf_config) 157 | model_total_params = sum(p.numel() for p in htf_model.hat.parameters() if p.requires_grad) 158 | model_total_params = model_total_params / 1e6 159 | memory_use, time_use = test_memory_usage(htf_model, seq_length=int(MAX_SENTENCE_LENGTH * max_sentences), mode=mode, task_type=task) 160 | mem_gains = (lf_mem_use / memory_use) - 1 161 | time_gains = (lf_time_use / time_use) - 1 162 | print(f'Hi-transformer model with layout {layout} has {model_total_params:.1f}M number of parameters ' 163 | f'{memory_use:.2f}GB peak memory use (-{mem_gains*100:.2f}%) and {time_use:.3f} batch/second (-{time_gains*100:.2f}%)!') 164 | 165 | 166 | if __name__ == '__main__': 167 | efficiency_metrics() 168 | -------------------------------------------------------------------------------- /models/util/benchmark_original_models.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import time 5 | from transformers import AutoConfig, AutoModelForSequenceClassification 6 | from models.longformer import LongformerModelForSentenceClassification 7 | from models.big_bird import BigBirdModelForSentenceClassification 8 | import numpy as np 9 | warnings.filterwarnings("ignore") 10 | 11 | 12 | def test_memory_usage(model, steps=100, batch_size=2, seq_length=4096, mode='test', task_type='lm'): 13 | model.to('cuda') 14 | if task_type != 'mc_qa': 15 | input_ids = torch.randint(1, 40000, (batch_size, seq_length), dtype=torch.long).to('cuda') 16 | input_ids[:, :: 128] = model.config.bos_token_id 17 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int).to('cuda') 18 | else: 19 | input_ids = torch.randint(1, 40000, (batch_size, 2, seq_length), dtype=torch.long).to('cuda') 20 | input_ids[:, :: 128] = model.config.bos_token_id 21 | attention_mask = torch.ones((batch_size, 2, seq_length), dtype=torch.int).to('cuda') 22 | if mode == 'train': 23 | if task_type == 'lm': 24 | labels = input_ids.clone() 25 | elif task_type == 'sent_cls': 26 | labels = torch.ones((batch_size, 32), dtype=torch.int).long().to('cuda') 27 | else: 28 | labels = torch.ones((batch_size, ), dtype=torch.int).long().to('cuda') 29 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) 30 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, steps) 31 | max_time = [] 32 | max_mem = [] 33 | for _ in range(steps): 34 | torch.cuda.reset_peak_memory_stats() 35 | start = time.time() 36 | if mode == 'train': 37 | 
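            # One full optimization step: forward pass with labels, backward pass,
            # then optimizer and LR-scheduler updates (timed and measured below).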
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 38 | loss = outputs.loss 39 | loss.backward(loss) 40 | optimizer.step() 41 | lr_scheduler.step() 42 | optimizer.zero_grad() 43 | else: 44 | with torch.no_grad(): 45 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) 46 | end = time.time() 47 | total_time = (end - start) 48 | max_time.append(total_time) 49 | max_mem.append(torch.cuda.max_memory_allocated() / 1e9) 50 | 51 | return np.mean(max_mem), np.mean(max_time) 52 | 53 | 54 | def estimate_model_size(): 55 | for mode in ['train', 'test']: 56 | print(F'MODE: {mode.upper()}') 57 | for task in ['doc_cls', 'sent_cls']: 58 | MAX_SENTENCE_LENGTH = 128 59 | roberta_config = AutoConfig.from_pretrained('roberta-base') 60 | print('-' * 150) 61 | print(F'TASK: {task.upper()}\t' 62 | F'NUM LAYERS: {roberta_config.num_hidden_layers}\t' 63 | F'NUM HIDDEN: {roberta_config.hidden_size}\t' 64 | F'ATTENTION HEADS: {roberta_config.num_attention_heads}') 65 | print('-' * 150) 66 | MAX_SENTENCES = 32 67 | print('-' * 150) 68 | print(F'MAX SEQ LENGTH: {int(MAX_SENTENCES * MAX_SENTENCE_LENGTH)}') 69 | print('-' * 150) 70 | # load dummy longformer model 71 | lf_config = AutoConfig.from_pretrained('allenai/longformer-base-4096') 72 | lf_config.num_labels = 2 73 | lf_config.max_sentence_length = 128 74 | lf_config.max_sentences = 32 75 | lf_config.cls_token_id = lf_config.bos_token_id 76 | lf_config.sep_token_id = lf_config.eos_token_id 77 | if task == 'doc_cls': 78 | htf_model = AutoModelForSequenceClassification.from_config(lf_config) 79 | else: 80 | htf_model = LongformerModelForSentenceClassification.from_config(lf_config) 81 | model_total_params = sum(p.numel() for p in htf_model.longformer.parameters() if p.requires_grad) 82 | model_total_params = model_total_params / 1e6 83 | memory_use, time_use = test_memory_usage(htf_model, seq_length=4096, mode=mode, 84 | task_type=task) 85 | print(f'Original Longformer (12-layer) model has {model_total_params:.1f}M number of parameters ' 86 | f'and {memory_use:.2f}GB peak memory use and {time_use:.3f} batch/second!') 87 | print('-' * 150) 88 | 89 | # load dummy bigbird model 90 | lf_config = AutoConfig.from_pretrained('google/bigbird-roberta-base') 91 | lf_config.num_labels = 2 92 | lf_config.max_sentence_length = 128 93 | lf_config.max_sentences = 32 94 | lf_config.cls_token_id = lf_config.bos_token_id 95 | lf_config.sep_token_id = lf_config.eos_token_id 96 | if task == 'doc_cls': 97 | htf_model = AutoModelForSequenceClassification.from_config(lf_config) 98 | else: 99 | htf_model = BigBirdModelForSentenceClassification.from_config(lf_config) 100 | model_total_params = sum(p.numel() for p in htf_model.bert.parameters() if p.requires_grad) 101 | model_total_params = model_total_params / 1e6 102 | memory_use, time_use = test_memory_usage(htf_model, seq_length=4096, mode=mode, 103 | task_type=task) 104 | print(f'Original BigBird (12-layer) model has {model_total_params:.1f}M number of parameters ' 105 | f'and {memory_use:.2f}GB peak memory use and {time_use:.3f} batch/second!') 106 | print('-' * 150) 107 | 108 | 109 | if __name__ == '__main__': 110 | estimate_model_size() 111 | 112 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 113 | # MODE: TRAIN 114 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 115 | # 
TASK: DOC_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12 116 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 117 | # MAX SEQ LENGTH: 4096 118 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 119 | # Original Longformer (12-layer) model has 148.1M number of parameters and 17.76GB peak memory use and 0.852 batch/second! 120 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 121 | # Original BigBird (12-layer) model has 127.5M number of parameters and 18.84GB peak memory use and 0.795 batch/second! 122 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 123 | # TASK: SENT_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12 124 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 125 | # MAX SEQ LENGTH: 4096 126 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 127 | # Original Longformer (12-layer) model has 148.1M number of parameters and 18.37GB peak memory use and 0.895 batch/second! 128 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 129 | # Original BigBird (12-layer) model has 126.9M number of parameters and 18.84GB peak memory use and 0.795 batch/second! 130 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 131 | # MODE: TEST 132 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 133 | # TASK: DOC_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12 134 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 135 | # MAX SEQ LENGTH: 4096 136 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 137 | # Original Longformer (12-layer) model has 148.1M number of parameters and 1.70GB peak memory use and 0.223 batch/second! 138 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 139 | # Original BigBird (12-layer) model has 127.5M number of parameters and 1.76GB peak memory use and 0.207 batch/second! 
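# (The figures above and below were measured with the defaults of test_memory_usage:
#  batch_size=2, seq_length=4096, on a single CUDA device.)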
140 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 141 | # TASK: SENT_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12 142 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 143 | # MAX SEQ LENGTH: 4096 144 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 145 | # Original Longformer (12-layer) model has 148.1M number of parameters and 1.71GB peak memory use and 0.236 batch/second! 146 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 147 | # Original BigBird (12-layer) model has 126.9M number of parameters and 1.76GB peak memory use and 0.207 batch/second! 148 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ -------------------------------------------------------------------------------- /models/util/efficiency_metrics_roberta.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import warnings 3 | 4 | import torch 5 | import time 6 | from transformers import AutoConfig 7 | import numpy as np 8 | 9 | from data import DATA_DIR 10 | from models.hat import HATForMaskedLM, HATConfig, HATForSequenceClassification, \ 11 | HATForMultipleChoice, HATModelForSequentialSentenceClassification 12 | from models.longformer import LongformerForMaskedLM, LongformerModelForSequenceClassification, LongformerForMultipleChoice , \ 13 | LongformerModelForSentenceClassification 14 | warnings.filterwarnings("ignore") 15 | 16 | LAYOUTS = { 17 | 'f12': 'S|S|S|S|S|S|S|S|S|S|S|SD|D|D|D', 18 | 'p1': 'S|S|SD|S|S|SD|S|S|SD|S|S|SD', 19 | } 20 | 21 | TASK_MODEL = {'lm': {'longformer': LongformerForMaskedLM, 'hilm': HATForMaskedLM}, 22 | 'doc_cls': {'longformer': LongformerModelForSequenceClassification, 'hilm': HATForSequenceClassification}, 23 | 'mc_qa': {'longformer': LongformerForMultipleChoice, 'hilm': HATForMultipleChoice}, 24 | 'sent_cls': {'longformer': LongformerModelForSentenceClassification, 'hilm': HATModelForSequentialSentenceClassification}, 25 | } 26 | 27 | 28 | def test_memory_usage(model, steps=40, batch_size=2, seq_length=4096, mode='test', task_type='lm'): 29 | model.to('cuda') 30 | if task_type != 'mc_qa': 31 | input_ids = torch.randint(1, 40000, (batch_size, seq_length), dtype=torch.long).to('cuda') 32 | input_ids[:, :: 128] = model.config.bos_token_id 33 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int).to('cuda') 34 | else: 35 | input_ids = torch.randint(1, 40000, (batch_size, 2, seq_length), dtype=torch.long).to('cuda') 36 | input_ids[:, :: 128] = model.config.bos_token_id 37 | attention_mask = torch.ones((batch_size, 2, seq_length), dtype=torch.int).to('cuda') 38 | if mode == 'train': 39 | if task_type == 'lm': 40 | labels = input_ids.clone() 41 | elif task_type == 'sent_cls': 42 | labels = torch.ones((batch_size, 32), dtype=torch.int).long().to('cuda') 43 | else: 44 | labels = torch.ones((batch_size, ), dtype=torch.int).long().to('cuda') 45 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) 46 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, steps) 47 | max_time = [] 
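    # Per-step wall-clock time and peak CUDA memory are collected in these two lists;
    # their means are returned at the end of the function.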
48 | max_mem = [] 49 | for _ in range(steps): 50 | torch.cuda.reset_peak_memory_stats() 51 | start = time.time() 52 | if mode == 'train': 53 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 54 | loss = outputs.loss 55 | loss.backward(loss) 56 | optimizer.step() 57 | lr_scheduler.step() 58 | optimizer.zero_grad() 59 | else: 60 | with torch.no_grad(): 61 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) 62 | end = time.time() 63 | total_time = (end - start) 64 | max_time.append(total_time) 65 | max_mem.append(torch.cuda.max_memory_allocated() / 1e9) 66 | 67 | return np.mean(max_mem), np.mean(max_time) 68 | 69 | 70 | def efficiency_metrics(): 71 | for mode in ['train', 'test']: 72 | print(F'MODE: {mode.upper()}') 73 | for task in TASK_MODEL: 74 | MAX_SENTENCE_LENGTH = 128 75 | roberta_config = AutoConfig.from_pretrained('roberta-base') 76 | print('-' * 150) 77 | print(F'TASK: {task.upper()}\t' 78 | F'NUM LAYERS: {roberta_config.num_hidden_layers}\t' 79 | F'NUM HIDDEN: {roberta_config.hidden_size}\t' 80 | F'ATTENTION HEADS: {roberta_config.num_attention_heads}') 81 | print('-' * 150) 82 | MAX_SENTENCES = 32 83 | print('-' * 150) 84 | print(F'MAX SEQ LENGTH: {int(MAX_SENTENCES * MAX_SENTENCE_LENGTH)}') 85 | print('-' * 150) 86 | lf_config = AutoConfig.from_pretrained('allenai/longformer-base-4096') 87 | lf_config.num_hidden_layers = 12 88 | # Transformer parameters 89 | lf_config.hidden_size = roberta_config.hidden_size 90 | lf_config.intermediate_size = roberta_config.intermediate_size 91 | lf_config.num_attention_heads = roberta_config.num_attention_heads 92 | # Vocabulary parameters 93 | lf_config.vocab_size = roberta_config.vocab_size 94 | lf_config.type_vocab_size = 2 95 | lf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES) 96 | lf_config.max_sentence_length = int(MAX_SENTENCE_LENGTH) 97 | lf_config.max_sentences = int(MAX_SENTENCES) 98 | lf_config.max_position_embeddings = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES) + 2 99 | lf_config.attention_window = [128] * roberta_config.num_hidden_layers 100 | lf_config.cls_token_id = 100 101 | lf_config.num_labels = 2 102 | # load dummy longformer model 103 | htf_model = TASK_MODEL[task]['longformer'].from_config(lf_config) 104 | model_total_params = sum(p.numel() for p in htf_model.longformer.parameters() if p.requires_grad) 105 | model_total_params = model_total_params / 1e6 106 | memory_use, time_use = test_memory_usage(htf_model, seq_length=lf_config.model_max_length, mode=mode, task_type=task) 107 | lf_mem_use = copy.deepcopy(memory_use) 108 | lf_time_use = copy.deepcopy(time_use) 109 | print(f'Longformer (12-layer) model has {model_total_params:.1f}M number of parameters ' 110 | f'and {memory_use:.2f}GB peak memory use and {time_use:.3f} batch/second!') 111 | print('-' * 150) 112 | for layout in LAYOUTS: 113 | ENCODER_LAYOUT = {} 114 | for idx, block_pattern in enumerate(LAYOUTS[layout].split('|')): 115 | ENCODER_LAYOUT[str(idx)] = {"sentence_encoder": True if 'S' in block_pattern else False, 116 | "document_encoder": True if 'D' in block_pattern else False} 117 | 118 | # load dummy config and change specifications 119 | htf_config = HATConfig.from_pretrained(f'{DATA_DIR}/hi-transformer') 120 | # Text length parameters 121 | htf_config.max_sentence_length = MAX_SENTENCE_LENGTH 122 | htf_config.MAX_SENTENCES = MAX_SENTENCES 123 | htf_config.max_position_embeddings = MAX_SENTENCE_LENGTH 124 | htf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES) 125 | 
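                # Note: max_position_embeddings is set to the segment length (128) rather than the
                # full document length, which suggests token position embeddings are applied per segment.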
htf_config.num_hidden_layers = len(ENCODER_LAYOUT.keys()) 126 | # Transformer parameters 127 | htf_config.hidden_size = roberta_config.hidden_size 128 | htf_config.intermediate_size = roberta_config.intermediate_size 129 | htf_config.num_attention_heads = roberta_config.num_attention_heads 130 | htf_config.encoder_layout = ENCODER_LAYOUT 131 | # Vocabulary parameters 132 | htf_config.vocab_size = roberta_config.vocab_size 133 | htf_config.type_vocab_size = 2 134 | lf_config.num_labels = 2 135 | # load dummy hi-transformer model 136 | htf_model = TASK_MODEL[task]['hilm'].from_config(htf_config) 137 | model_total_params = sum(p.numel() for p in htf_model.hat.parameters() if p.requires_grad) 138 | model_total_params = model_total_params / 1e6 139 | memory_use, time_use = test_memory_usage(htf_model, seq_length=int(MAX_SENTENCE_LENGTH * MAX_SENTENCES), mode=mode, task_type=task) 140 | mem_gains = (lf_mem_use / memory_use) - 1 141 | time_gains = (lf_time_use / time_use) - 1 142 | print(f'Hi-transformer model with layout {layout} has {model_total_params:.1f}M number of parameters ' 143 | f'{memory_use:.2f}GB peak memory use (-{mem_gains*100:.2f}%) and {time_use:.3f} batch/second (-{time_gains*100:.2f}%)!') 144 | 145 | 146 | if __name__ == '__main__': 147 | efficiency_metrics() 148 | 149 | 150 | 151 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 152 | # MODE: TRAIN 153 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 154 | # TASK: SENT_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12 155 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 156 | # MAX SEQ LENGTH: 4096 157 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 158 | # Longformer (12-layer) model has 148.1M number of parameters and 10.77GB peak memory use and 0.459 batch/second! 159 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 160 | # Hi-transformer model with layout f12 has 152.2M number of parameters 8.96GB peak memory use (-20.24%) and 0.343 batch/second (-33.80%)! 161 | # Hi-transformer model with layout p1 has 152.2M number of parameters 8.97GB peak memory use (-20.11%) and 0.344 batch/second (-33.42%)! 162 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 163 | # MODE: TEST 164 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 165 | # TASK: SENT_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12 166 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 167 | # MAX SEQ LENGTH: 4096 168 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 169 | # Longformer (12-layer) model has 148.1M number of parameters and 0.98GB peak memory use and 0.131 batch/second! 
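# (The percentages in parentheses are computed as (longformer_value / hat_value) - 1,
#  i.e. they report how much lower HAT's peak memory and mean seconds-per-step are
#  relative to the Longformer baseline above.)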
170 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 171 | # Hi-transformer model with layout f12 has 152.2M number of parameters 0.90GB peak memory use (-8.59%) and 0.115 batch/second (-13.96%)! 172 | # Hi-transformer model with layout p1 has 152.2M number of parameters 0.90GB peak memory use (-8.59%) and 0.114 batch/second (-14.45%)! 173 | # ------------------------------------------------------------------------------------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /models/big_bird/tokenization_big_bird.py: -------------------------------------------------------------------------------- 1 | """Tokenization classes for BigBird.""" 2 | import torch 3 | from transformers import AutoTokenizer 4 | from transformers.models.big_bird.configuration_big_bird import BigBirdConfig 5 | from transformers.utils import logging 6 | from nltk import sent_tokenize 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class BigbirdTokenizer: 11 | def __init__(self, tokenizer=None): 12 | self._tokenizer = tokenizer 13 | self.config = BigBirdConfig.from_pretrained(self._tokenizer.name_or_path) 14 | # hardcoded values 15 | self.config.max_sentence_size = 128 16 | self.config.max_sentence_length = 128 17 | self.config.max_sentences = 32 18 | self.config.model_max_length = 4096 19 | self._tokenizer.model_max_length = self.model_max_length 20 | self.type2id = {'input_ids': (self._tokenizer.sep_token_id, self._tokenizer.pad_token_id), 21 | 'token_type_ids': (0, 0), 22 | 'attention_mask': (1, 0), 23 | 'special_tokens_mask': (1, -100)} 24 | 25 | @property 26 | def model_max_length(self): 27 | return self.config.model_max_length 28 | 29 | @property 30 | def mask_token(self): 31 | return self._tokenizer.mask_token 32 | 33 | @property 34 | def mask_token_id(self): 35 | return self._tokenizer.mask_token_id 36 | 37 | @property 38 | def pad_token_id(self): 39 | return self._tokenizer.pad_token_id 40 | 41 | @property 42 | def cls_token_id(self): 43 | return self._tokenizer.cls_token_id 44 | 45 | @property 46 | def sep_token_id(self): 47 | return self._tokenizer.sep_token_id 48 | 49 | @property 50 | def vocab(self): 51 | return self._tokenizer.vocab 52 | 53 | def __len__(self): 54 | """ 55 | Size of the full vocabulary with the added tokens. 
56 | """ 57 | return len(self._tokenizer) 58 | 59 | def pad(self, *args, **kwargs): 60 | return self._tokenizer.pad(*args, **kwargs) 61 | 62 | def convert_tokens_to_ids(self, *args, **kwargs): 63 | return self._tokenizer.convert_tokens_to_ids(*args, **kwargs) 64 | 65 | def batch_decode(self, *args, **kwargs): 66 | return self._tokenizer.batch_decode(*args, **kwargs) 67 | 68 | def decode(self, *args, **kwargs): 69 | return self._tokenizer.decode(*args, **kwargs) 70 | 71 | def tokenize(self, text, **kwargs): 72 | return self._tokenizer.tokenize(text, **kwargs) 73 | 74 | def encode(self, text, **kwargs): 75 | input_ids = self._tokenizer.encode_plus(text, add_special_tokens=False, **kwargs) 76 | input_ids = self.chunks(input_ids[: self.model_max_length - self.config.max_sentences], 77 | chunk_size=self.config.max_sentence_length, special_id=self.type2id['input_ids']) 78 | 79 | for idx, _ in enumerate(input_ids): 80 | input_ids[idx][0] = self._tokenizer.cls_token_id 81 | 82 | return input_ids 83 | 84 | def get_special_tokens_mask(self, *args, **kwargs): 85 | return self._tokenizer.get_special_tokens_mask(*args, **kwargs) 86 | 87 | @classmethod 88 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 89 | return cls(tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)) 90 | 91 | def save_pretrained(self, *args, **kwargs): 92 | return self._tokenizer.save_pretrained( *args, **kwargs) 93 | 94 | def __call__(self, texts, **kwargs): 95 | greedy_chunking = kwargs.pop('greedy_chunking', None) 96 | if isinstance(texts[0], list): 97 | batch = self.auto_chunking(texts, **kwargs) 98 | else: 99 | if greedy_chunking: 100 | # fixed uniform chunking 101 | batch = self.uniform_chunking(texts, **kwargs) 102 | else: 103 | # dynamic sentence splitting and grouping 104 | batch = self.sentence_splitting(texts, **kwargs) 105 | 106 | for idx, _ in enumerate(batch['input_ids']): 107 | batch['input_ids'][idx][0] = self._tokenizer.cls_token_id 108 | 109 | if kwargs['padding']: 110 | batch = self.pad(batch, 111 | padding=kwargs['padding'], 112 | max_length=kwargs['max_length'], 113 | pad_to_multiple_of=kwargs['max_length']) 114 | 115 | return batch 116 | 117 | def auto_chunking(self, texts, **kwargs): 118 | batch = {} 119 | for text_idx, text in enumerate(texts): 120 | example_batch = self._tokenizer(text, add_special_tokens=False, **kwargs) 121 | for input_key in example_batch: 122 | key_inputs_list = [] 123 | for idx, example in enumerate(example_batch[input_key][:self.config.max_sentences]): 124 | key_inputs_list.append(self.pad_sentence(example, special_id=self.type2id[input_key])) 125 | if isinstance(key_inputs_list[0], list): 126 | key_inputs_list = [token for sentence in key_inputs_list for token in sentence] 127 | else: 128 | key_inputs_list = torch.stack(key_inputs_list) 129 | if input_key in batch: 130 | batch[input_key].append(key_inputs_list) 131 | else: 132 | batch[input_key] = [key_inputs_list] 133 | 134 | return batch 135 | 136 | def uniform_chunking(self, texts, **kwargs): 137 | original_batch = self._tokenizer(texts, add_special_tokens=False, **kwargs) 138 | batch = {input_type: [] for input_type in original_batch} 139 | for input_type in original_batch: 140 | fixed_batch = [] 141 | for example in original_batch[input_type]: 142 | fixed_batch.append(self.chunks(example[: self.model_max_length - self.config.max_sentences], 143 | chunk_size=self.config.max_sentence_length, 144 | special_id=self.type2id[input_type])) 145 | batch[input_type] = fixed_batch if 
isinstance(fixed_batch[0], list) else torch.stack(fixed_batch) 146 | return batch 147 | 148 | def chunks(self, flat_inputs, chunk_size=128, special_id=0): 149 | if isinstance(flat_inputs, list): 150 | return self.list_chunks(flat_inputs, chunk_size, special_id) 151 | else: 152 | return self.tensor_chunks(flat_inputs, chunk_size, special_id) 153 | 154 | def list_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)): 155 | """Yield successive n-sized chunks from lst.""" 156 | structured_inputs = [[special_id[0] if sum(flat_inputs[i:i + chunk_size-1]) else special_id[1]] 157 | + flat_inputs[i:i + chunk_size-1] for i in range(0, len(flat_inputs), chunk_size-1)] 158 | return [token_input for sentence_inputs in structured_inputs for token_input in sentence_inputs] 159 | 160 | def tensor_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)): 161 | """Yield successive n-sized chunks from lst.""" 162 | structured_inputs = torch.stack([torch.cat((torch.tensor([special_id[0] if flat_inputs[i:i + chunk_size-1].sum() else special_id[1]], dtype=torch.int), 163 | flat_inputs[i:i + chunk_size-1])) for i in range(0, len(flat_inputs), chunk_size-1)]) 164 | return structured_inputs.reshape(-1) 165 | 166 | def sentence_splitting(self, texts, **kwargs): 167 | fixed_batch = [] 168 | doc_out = {} 169 | for text in texts: 170 | # sentence splitting 171 | sentences = sent_tokenize(text) 172 | # tokenization of sentences 173 | sentences = self._tokenizer(sentences, add_special_tokens=False, padding=False, truncation=False) 174 | # sentence grouping - merging short sentences to minimize padding 175 | doc_out = self.sentence_grouping(sentences) 176 | fixed_batch.append(doc_out) 177 | # batchify examples 178 | batch = {input_type: [] for input_type in doc_out} 179 | for input_type in batch: 180 | batch[input_type] = [example[input_type] for example in fixed_batch] 181 | if not isinstance(batch[input_type][0], list): 182 | batch[input_type] = torch.stack(batch[input_type]) 183 | 184 | return batch 185 | 186 | def sentence_grouping(self, sentences): 187 | doc_out = {input_type: [] for input_type in sentences} 188 | for input_type in sentences: 189 | tmp_doc = [] 190 | tmp_sentence = [] 191 | for example in sentences[input_type]: 192 | if len(tmp_doc) >= self.config.max_sentences: 193 | break 194 | if len(tmp_sentence) + len(example) <= self.config.max_sentence_length - 1: 195 | tmp_sentence.extend(example) 196 | else: 197 | tmp_doc.append(self.pad_sentence(tmp_sentence if len(tmp_sentence) else example, 198 | chunk_size=self.config.max_sentence_length, 199 | special_id=self.type2id[input_type])) 200 | tmp_sentence = example if len(tmp_sentence) else example[self.config.max_sentence_length:] 201 | if len(tmp_sentence) and len(tmp_doc) < self.config.max_sentences: 202 | tmp_doc.append(self.pad_sentence(tmp_sentence, 203 | chunk_size=self.config.max_sentence_length, 204 | special_id=self.type2id[input_type])) 205 | doc_out[input_type] = [token for sentence in tmp_doc for token in sentence] 206 | return doc_out 207 | 208 | def pad_sentence(self, flat_input, chunk_size=128, special_id=(0, 0)): 209 | if isinstance(flat_input, list): 210 | return [special_id[0]] + flat_input[:chunk_size-1] + [self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1) 211 | else: 212 | return torch.cat((torch.tensor([special_id[0] if flat_input[:chunk_size-1].sum() 213 | else special_id[1]], dtype=torch.int), 214 | flat_input[:chunk_size-1], 215 | torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), 
dtype=torch.int) 216 | )) 217 | 218 | 219 | if __name__ == "__main__": 220 | tokenizer = BigbirdTokenizer.from_pretrained('roberta-base') 221 | inputs = tokenizer([' '.join(['dog'] * 8192), 222 | ' '.join(['cat'] * 7000), 223 | ' '.join(['mouse'] * 5000)], 224 | padding=True, max_length=8192, truncation=True 225 | ) 226 | print() 227 | -------------------------------------------------------------------------------- /models/longformer/tokenization_longformer.py: -------------------------------------------------------------------------------- 1 | """Tokenization classes for Longformer.""" 2 | import torch 3 | from transformers import AutoTokenizer 4 | from transformers.models.longformer.configuration_longformer import LongformerConfig 5 | from transformers.utils import logging 6 | try: 7 | from nltk import sent_tokenize 8 | except: 9 | raise Exception('NLTK is not installed! Install it with `pip install nltk`...') 10 | logger = logging.get_logger(__name__) 11 | 12 | 13 | class LongformerTokenizer: 14 | def __init__(self, tokenizer=None): 15 | self._tokenizer = tokenizer 16 | self.config = LongformerConfig.from_pretrained(self._tokenizer.name_or_path) 17 | # hardcoded values 18 | self.config.max_sentence_size = 128 19 | self.config.max_sentence_length = 128 20 | self.config.max_sentences = 32 21 | self.config.model_max_length = 4096 22 | self._tokenizer.model_max_length = self.model_max_length 23 | self.type2id = {'input_ids': (self._tokenizer.sep_token_id, self._tokenizer.pad_token_id), 24 | 'token_type_ids': (0, 0), 25 | 'attention_mask': (1, 0), 26 | 'special_tokens_mask': (1, -100)} 27 | 28 | @property 29 | def model_max_length(self): 30 | return self.config.model_max_length 31 | 32 | @property 33 | def mask_token(self): 34 | return self._tokenizer.mask_token 35 | 36 | @property 37 | def mask_token_id(self): 38 | return self._tokenizer.mask_token_id 39 | 40 | @property 41 | def pad_token_id(self): 42 | return self._tokenizer.pad_token_id 43 | 44 | @property 45 | def cls_token_id(self): 46 | return self._tokenizer.cls_token_id 47 | 48 | @property 49 | def sep_token_id(self): 50 | return self._tokenizer.sep_token_id 51 | 52 | @property 53 | def vocab(self): 54 | return self._tokenizer.vocab 55 | 56 | def __len__(self): 57 | """ 58 | Size of the full vocabulary with the added tokens. 
59 | """ 60 | return len(self._tokenizer) 61 | 62 | def pad(self, *args, **kwargs): 63 | return self._tokenizer.pad(*args, **kwargs) 64 | 65 | def convert_tokens_to_ids(self, *args, **kwargs): 66 | return self._tokenizer.convert_tokens_to_ids(*args, **kwargs) 67 | 68 | def batch_decode(self, *args, **kwargs): 69 | return self._tokenizer.batch_decode(*args, **kwargs) 70 | 71 | def decode(self, *args, **kwargs): 72 | return self._tokenizer.decode(*args, **kwargs) 73 | 74 | def tokenize(self, text, **kwargs): 75 | return self._tokenizer.tokenize(text, **kwargs) 76 | 77 | def encode(self, text, **kwargs): 78 | input_ids = self._tokenizer.encode_plus(text, add_special_tokens=False, **kwargs) 79 | input_ids = self.chunks(input_ids[: self.model_max_length - self.config.max_sentences], 80 | chunk_size=self.config.max_sentence_length, special_id=self.type2id['input_ids']) 81 | 82 | for idx, _ in enumerate(input_ids): 83 | input_ids[idx][0] = self._tokenizer.cls_token_id 84 | 85 | return input_ids 86 | 87 | def get_special_tokens_mask(self, *args, **kwargs): 88 | return self._tokenizer.get_special_tokens_mask(*args, **kwargs) 89 | 90 | @classmethod 91 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 92 | return cls(tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)) 93 | 94 | def save_pretrained(self, *args, **kwargs): 95 | return self._tokenizer.save_pretrained( *args, **kwargs) 96 | 97 | def __call__(self, texts, **kwargs): 98 | greedy_chunking = kwargs.pop('greedy_chunking', None) 99 | if isinstance(texts[0], list): 100 | batch = self.auto_chunking(texts, **kwargs) 101 | else: 102 | if greedy_chunking: 103 | # fixed uniform chunking 104 | batch = self.uniform_chunking(texts, **kwargs) 105 | else: 106 | # dynamic sentence splitting and grouping 107 | batch = self.sentence_splitting(texts, **kwargs) 108 | 109 | for idx, _ in enumerate(batch['input_ids']): 110 | batch['input_ids'][idx][0] = self._tokenizer.cls_token_id 111 | 112 | if kwargs['padding']: 113 | batch = self.pad(batch, 114 | padding=kwargs['padding'], 115 | max_length=kwargs['max_length'], 116 | pad_to_multiple_of=kwargs['max_length']) 117 | 118 | return batch 119 | 120 | def auto_chunking(self, texts, **kwargs): 121 | batch = {} 122 | for text_idx, text in enumerate(texts): 123 | example_batch = self._tokenizer(text, add_special_tokens=False, **kwargs) 124 | for input_key in example_batch: 125 | key_inputs_list = [] 126 | for idx, example in enumerate(example_batch[input_key][:self.config.max_sentences]): 127 | key_inputs_list.append(self.pad_sentence(example, special_id=self.type2id[input_key])) 128 | if isinstance(key_inputs_list[0], list): 129 | key_inputs_list = [token for sentence in key_inputs_list for token in sentence] 130 | else: 131 | key_inputs_list = torch.stack(key_inputs_list) 132 | if input_key in batch: 133 | batch[input_key].append(key_inputs_list) 134 | else: 135 | batch[input_key] = [key_inputs_list] 136 | 137 | return batch 138 | 139 | def uniform_chunking(self, texts, **kwargs): 140 | original_batch = self._tokenizer(texts, add_special_tokens=False, **kwargs) 141 | batch = {input_type: [] for input_type in original_batch} 142 | for input_type in original_batch: 143 | fixed_batch = [] 144 | for example in original_batch[input_type]: 145 | fixed_batch.append(self.chunks(example[: self.model_max_length - self.config.max_sentences], 146 | chunk_size=self.config.max_sentence_length, 147 | special_id=self.type2id[input_type])) 148 | batch[input_type] = fixed_batch if 
isinstance(fixed_batch[0], list) else torch.stack(fixed_batch) 149 | return batch 150 | 151 | def chunks(self, flat_inputs, chunk_size=128, special_id=0): 152 | if isinstance(flat_inputs, list): 153 | return self.list_chunks(flat_inputs, chunk_size, special_id) 154 | else: 155 | return self.tensor_chunks(flat_inputs, chunk_size, special_id) 156 | 157 | def list_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)): 158 | """Yield successive n-sized chunks from lst.""" 159 | structured_inputs = [[special_id[0] if sum(flat_inputs[i:i + chunk_size-1]) else special_id[1]] 160 | + flat_inputs[i:i + chunk_size-1] for i in range(0, len(flat_inputs), chunk_size-1)] 161 | return [token_input for sentence_inputs in structured_inputs for token_input in sentence_inputs] 162 | 163 | def tensor_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)): 164 | """Yield successive n-sized chunks from lst.""" 165 | structured_inputs = torch.stack([torch.cat((torch.tensor([special_id[0] if flat_inputs[i:i + chunk_size-1].sum() else special_id[1]], dtype=torch.int), 166 | flat_inputs[i:i + chunk_size-1])) for i in range(0, len(flat_inputs), chunk_size-1)]) 167 | return structured_inputs.reshape(-1) 168 | 169 | def sentence_splitting(self, texts, **kwargs): 170 | fixed_batch = [] 171 | doc_out = {} 172 | for text in texts: 173 | # sentence splitting 174 | sentences = sent_tokenize(text) 175 | # tokenization of sentences 176 | sentences = self._tokenizer(sentences, add_special_tokens=False, padding=False, truncation=False) 177 | # sentence grouping - merging short sentences to minimize padding 178 | doc_out = self.sentence_grouping(sentences) 179 | fixed_batch.append(doc_out) 180 | # batchify examples 181 | batch = {input_type: [] for input_type in doc_out} 182 | for input_type in batch: 183 | batch[input_type] = [example[input_type] for example in fixed_batch] 184 | if not isinstance(batch[input_type][0], list): 185 | batch[input_type] = torch.stack(batch[input_type]) 186 | 187 | return batch 188 | 189 | def sentence_grouping(self, sentences): 190 | doc_out = {input_type: [] for input_type in sentences} 191 | for input_type in sentences: 192 | tmp_doc = [] 193 | tmp_sentence = [] 194 | for example in sentences[input_type]: 195 | if len(tmp_doc) >= self.config.max_sentences: 196 | break 197 | if len(tmp_sentence) + len(example) <= self.config.max_sentence_length - 1: 198 | tmp_sentence.extend(example) 199 | else: 200 | tmp_doc.append(self.pad_sentence(tmp_sentence if len(tmp_sentence) else example, 201 | chunk_size=self.config.max_sentence_length, 202 | special_id=self.type2id[input_type])) 203 | tmp_sentence = example if len(tmp_sentence) else example[self.config.max_sentence_length:] 204 | if len(tmp_sentence) and len(tmp_doc) < self.config.max_sentences: 205 | tmp_doc.append(self.pad_sentence(tmp_sentence, 206 | chunk_size=self.config.max_sentence_length, 207 | special_id=self.type2id[input_type])) 208 | doc_out[input_type] = [token for sentence in tmp_doc for token in sentence] 209 | return doc_out 210 | 211 | def pad_sentence(self, flat_input, chunk_size=128, special_id=(0, 0)): 212 | if isinstance(flat_input, list): 213 | return [special_id[0]] + flat_input[:chunk_size-1] + [self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1) 214 | else: 215 | return torch.cat((torch.tensor([special_id[0] if flat_input[:chunk_size-1].sum() 216 | else special_id[1]], dtype=torch.int), 217 | flat_input[:chunk_size-1], 218 | torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), 
dtype=torch.int) 219 | )) 220 | 221 | 222 | if __name__ == "__main__": 223 | tokenizer = LongformerTokenizer.from_pretrained('roberta-base') 224 | inputs = tokenizer([' '.join(['dog'] * 8192), 225 | ' '.join(['cat'] * 7000), 226 | ' '.join(['mouse'] * 5000)], 227 | padding=True, max_length=8192, truncation=True 228 | ) 229 | print() 230 | -------------------------------------------------------------------------------- /models/hat/tokenization_hat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | """Tokenization classes for HAT.""" 14 | import torch 15 | from transformers import RobertaTokenizer, BertTokenizer 16 | from .configuration_hat import HATConfig 17 | from transformers.utils import logging 18 | try: 19 | from nltk import sent_tokenize 20 | except: 21 | raise Exception('NLTK is not installed! Install it with `pip install nltk`...') 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | class HATTokenizer: 26 | def __init__(self, tokenizer=None): 27 | self._tokenizer = tokenizer 28 | self.config = HATConfig.from_pretrained(self._tokenizer.name_or_path) 29 | self._tokenizer.model_max_length = self.model_max_length 30 | self.type2id = {'input_ids': (self._tokenizer.cls_token_id, self._tokenizer.pad_token_id), 31 | 'token_type_ids': (0, 0), 32 | 'attention_mask': (1, 0), 33 | 'special_tokens_mask': (1, -100)} 34 | 35 | @property 36 | def model_max_length(self): 37 | return self.config.model_max_length 38 | 39 | @property 40 | def mask_token(self): 41 | return self._tokenizer.mask_token 42 | 43 | @property 44 | def mask_token_id(self): 45 | return self._tokenizer.mask_token_id 46 | 47 | @property 48 | def pad_token_id(self): 49 | return self._tokenizer.pad_token_id 50 | 51 | @property 52 | def cls_token_id(self): 53 | return self._tokenizer.cls_token_id 54 | 55 | @property 56 | def sep_token_id(self): 57 | return self._tokenizer.sep_token_id 58 | 59 | @property 60 | def vocab(self): 61 | return self._tokenizer.vocab 62 | 63 | def __len__(self): 64 | """ 65 | Size of the full vocabulary with the added tokens. 
66 | """ 67 | return len(self._tokenizer) 68 | 69 | def pad(self, *args, **kwargs): 70 | return self._tokenizer.pad(*args, **kwargs) 71 | 72 | def convert_tokens_to_ids(self, *args, **kwargs): 73 | return self._tokenizer.convert_tokens_to_ids(*args, **kwargs) 74 | 75 | def batch_decode(self, *args, **kwargs): 76 | return self._tokenizer.batch_decode(*args, **kwargs) 77 | 78 | def decode(self, *args, **kwargs): 79 | return self._tokenizer.decode(*args, **kwargs) 80 | 81 | def tokenize(self, text, **kwargs): 82 | return self._tokenizer.tokenize(text, **kwargs) 83 | 84 | def encode(self, text, **kwargs): 85 | input_ids = self._tokenizer.encode_plus(text, add_special_tokens=False, **kwargs) 86 | input_ids = self.chunks(input_ids[: self.model_max_length - self.config.max_sentences], 87 | chunk_size=self.config.max_sentence_length, special_id=self.type2id['input_ids']) 88 | return input_ids 89 | 90 | def get_special_tokens_mask(self, *args, **kwargs): 91 | return self._tokenizer.get_special_tokens_mask(*args, **kwargs) 92 | 93 | @classmethod 94 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 95 | try: 96 | tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) 97 | except: 98 | tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) 99 | return cls(tokenizer=tokenizer) 100 | 101 | def save_pretrained(self, *args, **kwargs): 102 | return self._tokenizer.save_pretrained( *args, **kwargs) 103 | 104 | def __call__(self, text, **kwargs): 105 | greedy_chunking = kwargs.pop('greedy_chunking', None) 106 | text_pair = kwargs.pop('text_pair', None) 107 | if isinstance(text[0], list): 108 | batch = self.auto_chunking(text, **kwargs) 109 | elif greedy_chunking: 110 | # fixed uniform chunking 111 | batch = self.uniform_chunking(text, **kwargs) 112 | else: 113 | # dynamic sentence splitting and grouping 114 | batch = self.sentence_splitting(text, **kwargs) 115 | 116 | if text_pair: 117 | batch_b = self._tokenizer(text_pair, add_special_tokens=False, 118 | padding=False, truncation=False) 119 | for idx, sample in enumerate(batch['input_ids']): 120 | n_sentences = sum(sample[::self.config.max_sentence_size]) 121 | for input_key in batch: 122 | batch[input_key][idx][self.config.max_sentence_size * n_sentences: 123 | self.config.max_sentence_size * (n_sentences + 1)] = \ 124 | self.pad_sentence(batch_b[input_key][idx], 125 | special_id=(self.sep_token_id, self.pad_token_id) 126 | if input_key == 'input_ids' else self.type2id[input_key]) 127 | 128 | return batch 129 | 130 | def uniform_chunking(self, texts, **kwargs): 131 | original_batch = self._tokenizer(texts, add_special_tokens=False, **kwargs) 132 | batch = {input_type: [] for input_type in original_batch} 133 | for input_type in original_batch: 134 | fixed_batch = [] 135 | for example in original_batch[input_type]: 136 | fixed_batch.append(self.chunks(example[: self.model_max_length - self.config.max_sentences], 137 | chunk_size=self.config.max_sentence_length, 138 | special_id=self.type2id[input_type])) 139 | batch[input_type] = fixed_batch if isinstance(fixed_batch[0], list) else torch.stack(fixed_batch) 140 | 141 | if kwargs['padding']: 142 | batch = self.pad(batch, 143 | padding=kwargs['padding'], 144 | max_length=kwargs['max_length'], 145 | pad_to_multiple_of=kwargs['max_length']) 146 | 147 | return batch 148 | 149 | def auto_chunking(self, texts, **kwargs): 150 | batch = {} 151 | for text_idx, text in enumerate(texts): 152 | example_batch = self._tokenizer(text, 
add_special_tokens=False, **kwargs) 153 | for input_key in example_batch: 154 | key_inputs_list = [] 155 | for idx, example in enumerate(example_batch[input_key][:self.config.max_sentences]): 156 | key_inputs_list.append(self.pad_sentence(example, special_id=self.type2id[input_key])) 157 | if isinstance(key_inputs_list[0], list): 158 | key_inputs_list = [token for sentence in key_inputs_list for token in sentence] 159 | else: 160 | key_inputs_list = torch.stack([token for sentence in key_inputs_list for token in sentence]) 161 | if input_key in batch: 162 | batch[input_key].append(key_inputs_list) 163 | else: 164 | batch[input_key] = [key_inputs_list] 165 | 166 | if kwargs['padding']: 167 | batch = self.pad(batch, 168 | padding=kwargs['padding'], 169 | max_length=kwargs['max_length'], 170 | pad_to_multiple_of=kwargs['max_length']) 171 | 172 | return batch 173 | 174 | def chunks(self, flat_inputs, chunk_size=128, special_id=0): 175 | if isinstance(flat_inputs, list): 176 | return self.list_chunks(flat_inputs, chunk_size, special_id) 177 | else: 178 | return self.tensor_chunks(flat_inputs, chunk_size, special_id) 179 | 180 | def list_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)): 181 | """Yield successive n-sized chunks from lst.""" 182 | structured_inputs = [[special_id[0] if sum(flat_inputs[i:i + chunk_size-1]) else special_id[1]] 183 | + flat_inputs[i:i + chunk_size-1] for i in range(0, len(flat_inputs), chunk_size-1)] 184 | return [token_input for sentence_inputs in structured_inputs for token_input in sentence_inputs] 185 | 186 | def tensor_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)): 187 | """Yield successive n-sized chunks from lst.""" 188 | structured_inputs = torch.stack([torch.cat((torch.tensor([special_id[0] if flat_inputs[i:i + chunk_size-1].sum() else special_id[1]], dtype=torch.int), 189 | flat_inputs[i:i + chunk_size-1])) for i in range(0, len(flat_inputs), chunk_size-1)]) 190 | return structured_inputs.reshape(-1) 191 | 192 | def sentence_splitting(self, texts, **kwargs): 193 | fixed_batch = [] 194 | doc_out = {} 195 | for text in texts: 196 | # sentence splitting 197 | sentences = sent_tokenize(text) 198 | # tokenization of sentences 199 | sentences = self._tokenizer(sentences, add_special_tokens=False, padding=False, truncation=False) 200 | # sentence grouping - merging short sentences to minimize padding 201 | doc_out = self.sentence_grouping(sentences) 202 | fixed_batch.append(doc_out) 203 | # batchify examples 204 | batch = {input_type: [] for input_type in doc_out} 205 | for input_type in batch: 206 | batch[input_type] = [example[input_type] for example in fixed_batch] 207 | if not isinstance(batch[input_type][0], list): 208 | batch[input_type] = torch.stack(batch[input_type]) 209 | 210 | if kwargs['padding']: 211 | batch = self.pad(batch, 212 | padding=kwargs['padding'], 213 | max_length=kwargs['max_length'], 214 | pad_to_multiple_of=kwargs['max_length']) 215 | 216 | return batch 217 | 218 | def sentence_grouping(self, sentences): 219 | doc_out = {input_type: [] for input_type in sentences} 220 | for input_type in sentences: 221 | tmp_doc = [] 222 | tmp_sentence = [] 223 | for example in sentences[input_type]: 224 | if len(tmp_doc) >= self.config.max_sentences: 225 | break 226 | if len(tmp_sentence) + len(example) <= self.config.max_sentence_length - 1: 227 | tmp_sentence.extend(example) 228 | else: 229 | tmp_doc.append(self.pad_sentence(tmp_sentence if len(tmp_sentence) else example, 230 | chunk_size=self.config.max_sentence_length, 231 
| special_id=self.type2id[input_type])) 232 | tmp_sentence = example if len(tmp_sentence) else example[self.config.max_sentence_length:] 233 | if len(tmp_sentence) and len(tmp_doc) < self.config.max_sentences: 234 | tmp_doc.append(self.pad_sentence(tmp_sentence, 235 | chunk_size=self.config.max_sentence_length, 236 | special_id=self.type2id[input_type])) 237 | doc_out[input_type] = [token for sentence in tmp_doc for token in sentence] 238 | return doc_out 239 | 240 | def pad_sentence(self, flat_input, chunk_size=128, special_id=(0, 0)): 241 | if isinstance(flat_input, list): 242 | return [special_id[0]] + flat_input[:chunk_size-1] + [self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1) 243 | else: 244 | return torch.cat((torch.tensor([special_id[0] if flat_input[:chunk_size-1].sum() 245 | else special_id[1]], dtype=torch.int), 246 | flat_input[:chunk_size-1], 247 | torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int) 248 | )) 249 | 250 | @classmethod 251 | def register_for_auto_class(cls, auto_class="AutoModel"): 252 | """ 253 | Register this class with a given auto class. This should only be used for custom models as the ones in the 254 | library are already mapped with an auto class. 255 | 256 | This API is experimental and may have some slight breaking changes in the next releases. 257 | 258 | Args: 259 | auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`): 260 | The auto class to register this new model with. 261 | """ 262 | if not isinstance(auto_class, str): 263 | auto_class = auto_class.__name__ 264 | 265 | import transformers.models.auto as auto_module 266 | 267 | if not hasattr(auto_module, auto_class): 268 | raise ValueError(f"{auto_class} is not a valid auto class.") 269 | 270 | cls._auto_class = auto_class 271 | 272 | -------------------------------------------------------------------------------- /evaluation/run_sequential_sentence_classification.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ Finetuning models on the ECtHR Argument Mining dataset (e.g. Bert, RoBERTa, LEGAL-BERT).""" 4 | 5 | import logging 6 | import os 7 | import random 8 | import sys 9 | from dataclasses import dataclass, field 10 | from typing import Optional 11 | 12 | import datasets 13 | from datasets import load_dataset 14 | from sklearn.metrics import f1_score, classification_report 15 | import glob 16 | import shutil 17 | 18 | import transformers 19 | from transformers import ( 20 | Trainer, 21 | AutoConfig, 22 | EvalPrediction, 23 | HfArgumentParser, 24 | TrainingArguments, 25 | set_seed, 26 | EarlyStoppingCallback, 27 | ) 28 | from transformers.trainer_utils import get_last_checkpoint 29 | from transformers.utils import check_min_version 30 | from transformers.utils.versions import require_version 31 | from data_collator import DataCollatorForMultiLabelClassification 32 | from models.hat import HATModelForSequentialSentenceClassification, HATTokenizer, HATConfig 33 | from models.longformer import LongformerModelForSentenceClassification, LongformerTokenizer 34 | from models.big_bird import BigBirdModelForSentenceClassification, BigbirdTokenizer 35 | 36 | 37 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
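# The HAT, Longformer and BigBird classes imported above are the repo-local wrappers from models/; one of them is instantiated further down by matching 'hat', 'longformer' or 'bigbird' in --model_name_or_path.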
38 | check_min_version("4.9.0") 39 | 40 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 41 | 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | @dataclass 46 | class DataTrainingArguments: 47 | """ 48 | Arguments pertaining to what data we are going to input our model for training and eval. 49 | 50 | Using `HfArgumentParser` we can turn this class 51 | into argparse arguments to be able to specify them on 52 | the command line. 53 | """ 54 | 55 | max_seq_length: Optional[int] = field( 56 | default=4096, 57 | metadata={ 58 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 59 | "than this will be truncated, sequences shorter will be padded." 60 | }, 61 | ) 62 | max_sentences: int = field( 63 | default=32, 64 | metadata={ 65 | "help": "The maximum number of sentences after tokenization. Sequences longer " 66 | "than this will be truncated." 67 | }, 68 | ) 69 | overwrite_cache: bool = field( 70 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 71 | ) 72 | pad_to_max_length: bool = field( 73 | default=True, 74 | metadata={ 75 | "help": "Whether to pad all samples to `max_seq_length`. " 76 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 77 | }, 78 | ) 79 | max_train_samples: Optional[int] = field( 80 | default=None, 81 | metadata={ 82 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 83 | "value if set." 84 | }, 85 | ) 86 | max_eval_samples: Optional[int] = field( 87 | default=None, 88 | metadata={ 89 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 90 | "value if set." 91 | }, 92 | ) 93 | max_predict_samples: Optional[int] = field( 94 | default=None, 95 | metadata={ 96 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 97 | "value if set." 98 | }, 99 | ) 100 | dataset_name: Optional[str] = field( 101 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 102 | ) 103 | dataset_config_name: Optional[str] = field( 104 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 105 | ) 106 | task: Optional[str] = field( 107 | default='ecthr_arguments', 108 | metadata={ 109 | "help": "Define downstream task" 110 | }, 111 | ) 112 | server_ip: Optional[str] = field(default=None, metadata={"help": "For distant debugging."}) 113 | server_port: Optional[str] = field(default=None, metadata={"help": "For distant debugging."}) 114 | 115 | 116 | @dataclass 117 | class ModelArguments: 118 | """ 119 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
120 | """ 121 | 122 | model_name_or_path: str = field( 123 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 124 | ) 125 | config_name: Optional[str] = field( 126 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 127 | ) 128 | tokenizer_name: Optional[str] = field( 129 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 130 | ) 131 | cache_dir: Optional[str] = field( 132 | default=None, 133 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 134 | ) 135 | use_fast_tokenizer: bool = field( 136 | default=True, 137 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 138 | ) 139 | model_revision: str = field( 140 | default="main", 141 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 142 | ) 143 | use_auth_token: bool = field( 144 | default=True, 145 | metadata={ 146 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 147 | "with private models)." 148 | }, 149 | ) 150 | 151 | 152 | def main(): 153 | # See all possible arguments in src/transformers/training_args.py 154 | # or by passing the --help flag to this script. 155 | # We now keep distinct sets of args, for a cleaner separation of concerns. 156 | 157 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 158 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 159 | 160 | # Setup distant debugging if needed 161 | if data_args.server_ip and data_args.server_port: 162 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 163 | import ptvsd 164 | 165 | print("Waiting for debugger attach") 166 | ptvsd.enable_attach(address=(data_args.server_ip, data_args.server_port), redirect_output=True) 167 | ptvsd.wait_for_attach() 168 | 169 | # Setup logging 170 | logging.basicConfig( 171 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 172 | datefmt="%m/%d/%Y %H:%M:%S", 173 | handlers=[logging.StreamHandler(sys.stdout)], 174 | ) 175 | 176 | log_level = training_args.get_process_log_level() 177 | logger.setLevel(log_level) 178 | datasets.utils.logging.set_verbosity(log_level) 179 | transformers.utils.logging.set_verbosity(log_level) 180 | transformers.utils.logging.enable_default_handler() 181 | transformers.utils.logging.enable_explicit_format() 182 | 183 | # Log on each process the small summary: 184 | logger.warning( 185 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 186 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 187 | ) 188 | logger.info(f"Training/evaluation parameters {training_args}") 189 | 190 | # Detecting last checkpoint. 191 | last_checkpoint = None 192 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 193 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 194 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 195 | raise ValueError( 196 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 197 | "Use --overwrite_output_dir to overcome." 
198 | ) 199 | elif last_checkpoint is not None: 200 | logger.info( 201 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 202 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 203 | ) 204 | 205 | # Set seed before initializing model. 206 | set_seed(training_args.seed) 207 | 208 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 209 | # download the dataset. 210 | # Downloading and loading eurlex dataset from the hub. 211 | if training_args.do_train: 212 | train_dataset = load_dataset( 213 | data_args.dataset_name, 214 | data_args.dataset_config_name, 215 | split="train", 216 | data_dir=data_args.dataset_name, 217 | cache_dir=model_args.cache_dir, 218 | ) 219 | 220 | if training_args.do_eval: 221 | eval_dataset = load_dataset( 222 | data_args.dataset_name, 223 | data_args.dataset_config_name, 224 | split="validation", 225 | data_dir=data_args.dataset_name, 226 | cache_dir=model_args.cache_dir, 227 | ) 228 | 229 | if training_args.do_predict: 230 | predict_dataset = load_dataset( 231 | data_args.dataset_name, 232 | data_args.dataset_config_name, 233 | split="test", 234 | data_dir=data_args.dataset_name, 235 | cache_dir=model_args.cache_dir, 236 | ) 237 | # Labels 238 | label_list = list(range(train_dataset.features['labels'].feature.feature.num_classes)) 239 | label_names = train_dataset.features['labels'].feature.feature.names 240 | num_labels = len(label_list) 241 | 242 | # Load pretrained model and tokenizer 243 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 244 | # download model & vocab. 245 | 246 | if 'hat' in model_args.model_name_or_path: 247 | config = HATConfig.from_pretrained( 248 | model_args.model_name_or_path, 249 | num_labels=num_labels, 250 | finetuning_task="ecthr-args", 251 | cache_dir=model_args.cache_dir, 252 | revision=model_args.model_revision, 253 | ) 254 | tokenizer = HATTokenizer.from_pretrained( 255 | model_args.model_name_or_path, 256 | cache_dir=model_args.cache_dir, 257 | use_fast=model_args.use_fast_tokenizer, 258 | revision=model_args.model_revision, 259 | ) 260 | model = HATModelForSequentialSentenceClassification.from_pretrained( 261 | model_args.model_name_or_path, 262 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 263 | config=config, 264 | cache_dir=model_args.cache_dir, 265 | revision=model_args.model_revision, 266 | ) 267 | elif 'longformer' in model_args.model_name_or_path: 268 | config = AutoConfig.from_pretrained( 269 | model_args.model_name_or_path, 270 | num_labels=num_labels, 271 | finetuning_task="ecthr-args", 272 | cache_dir=model_args.cache_dir, 273 | revision=model_args.model_revision, 274 | ) 275 | config.max_sentence_size = 128 276 | config.max_sentence_length = 128 277 | config.max_sentences = data_args.max_sentences 278 | config.model_max_length = 4096 279 | config.cls_token_id = config.bos_token_id 280 | config.sep_token_id = config.eos_token_id 281 | tokenizer = LongformerTokenizer.from_pretrained( 282 | model_args.model_name_or_path, 283 | cache_dir=model_args.cache_dir, 284 | use_fast=model_args.use_fast_tokenizer, 285 | revision=model_args.model_revision, 286 | ) 287 | model = LongformerModelForSentenceClassification.from_pretrained( 288 | model_args.model_name_or_path, 289 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 290 | config=config, 291 | cache_dir=model_args.cache_dir, 292 | 
revision=model_args.model_revision, 293 | ) 294 | elif 'bigbird' in model_args.model_name_or_path: 295 | config = AutoConfig.from_pretrained( 296 | model_args.model_name_or_path, 297 | num_labels=num_labels, 298 | finetuning_task="ecthr-args", 299 | cache_dir=model_args.cache_dir, 300 | revision=model_args.model_revision, 301 | ) 302 | config.max_sentence_size = 128 303 | config.max_sentence_length = 128 304 | config.max_sentences = data_args.max_sentences 305 | config.model_max_length = 4096 306 | config.cls_token_id = config.bos_token_id 307 | config.sep_token_id = config.eos_token_id 308 | tokenizer = BigbirdTokenizer.from_pretrained( 309 | model_args.model_name_or_path, 310 | cache_dir=model_args.cache_dir, 311 | use_fast=model_args.use_fast_tokenizer, 312 | revision=model_args.model_revision, 313 | ) 314 | model = BigBirdModelForSentenceClassification.from_pretrained( 315 | model_args.model_name_or_path, 316 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 317 | config=config, 318 | cache_dir=model_args.cache_dir, 319 | revision=model_args.model_revision, 320 | ) 321 | 322 | # Preprocessing the datasets 323 | # Padding strategy 324 | if data_args.pad_to_max_length: 325 | padding = "max_length" 326 | else: 327 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 328 | padding = False 329 | 330 | # for document, labels in zip(train_dataset['text'], train_dataset['labels']): 331 | # for paragraph, par_labels in zip(document, labels): 332 | # par_labels = [label_names[label] for label in par_labels] 333 | # if len(par_labels) > 1: 334 | # print() 335 | 336 | def preprocess_function(examples): 337 | # Tokenize the texts 338 | batch = tokenizer( 339 | examples['text'], 340 | padding=padding, 341 | max_length=data_args.max_seq_length, 342 | truncation=True, 343 | ) 344 | 345 | label_ids = [] 346 | for idx, labels in enumerate(examples["labels"]): 347 | par_label_ids = [] 348 | for par_labels in labels[:tokenizer.config.max_sentences]: 349 | par_label_ids.append([1.0 if label in par_labels else 0.0 for label in label_list]) 350 | par_label_ids.extend([[-1.0] * len(label_list)] * (tokenizer.config.max_sentences - len(par_label_ids))) 351 | label_ids.append(par_label_ids) 352 | 353 | batch["label_ids"] = label_ids 354 | 355 | return batch 356 | 357 | if training_args.do_train: 358 | if data_args.max_train_samples is not None: 359 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 360 | with training_args.main_process_first(desc="train dataset map pre-processing"): 361 | train_dataset = train_dataset.map( 362 | preprocess_function, 363 | batched=True, 364 | load_from_cache_file=not data_args.overwrite_cache, 365 | desc="Running tokenizer on train dataset", 366 | ) 367 | # Log a few random samples from the training set: 368 | for index in random.sample(range(len(train_dataset)), 3): 369 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 370 | 371 | if training_args.do_eval: 372 | if data_args.max_eval_samples is not None: 373 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 374 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 375 | eval_dataset = eval_dataset.map( 376 | preprocess_function, 377 | batched=True, 378 | load_from_cache_file=not data_args.overwrite_cache, 379 | desc="Running tokenizer on validation dataset", 380 | ) 381 | 382 | if training_args.do_predict: 383 | if data_args.max_predict_samples is not None: 384 | 
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 385 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 386 | predict_dataset = predict_dataset.map( 387 | preprocess_function, 388 | batched=True, 389 | load_from_cache_file=not data_args.overwrite_cache, 390 | desc="Running tokenizer on prediction dataset", 391 | ) 392 | 393 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 394 | # predictions and label_ids field) and has to return a dictionary string to float. 395 | def compute_metrics(p: EvalPrediction): 396 | # Fix scores 397 | from multi_label_utils import fix_multi_label_scores 398 | y_true, y_pred = fix_multi_label_scores(p.predictions, p.label_ids, 399 | unpad_sequences=True, flatten_sequences=True) 400 | 401 | # Compute scores 402 | macro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=0) 403 | micro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='micro', zero_division=0) 404 | logger.info(classification_report(y_true=y_true, y_pred=y_pred, 405 | target_names=label_names + ['None'], zero_division=0)+'\n') 406 | return {'macro-f1': macro_f1, 'micro-f1': micro_f1} 407 | 408 | # Initialize our Trainer 409 | trainer = Trainer( 410 | model=model, 411 | args=training_args, 412 | train_dataset=train_dataset if training_args.do_train else None, 413 | eval_dataset=eval_dataset if training_args.do_eval else None, 414 | compute_metrics=compute_metrics, 415 | tokenizer=tokenizer, 416 | data_collator=DataCollatorForMultiLabelClassification(tokenizer), 417 | callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] 418 | ) 419 | 420 | # Training 421 | if training_args.do_train: 422 | checkpoint = None 423 | if training_args.resume_from_checkpoint is not None: 424 | checkpoint = training_args.resume_from_checkpoint 425 | elif last_checkpoint is not None: 426 | checkpoint = last_checkpoint 427 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 428 | metrics = train_result.metrics 429 | max_train_samples = ( 430 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 431 | ) 432 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 433 | 434 | trainer.save_model() # Saves the tokenizer too for easy upload 435 | 436 | trainer.log_metrics("train", metrics) 437 | trainer.save_metrics("train", metrics) 438 | trainer.save_state() 439 | 440 | # Evaluation 441 | if training_args.do_eval: 442 | logger.info("*** Evaluate ***") 443 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 444 | 445 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 446 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 447 | 448 | trainer.log_metrics("eval", metrics) 449 | trainer.save_metrics("eval", metrics) 450 | 451 | # Prediction 452 | if training_args.do_predict: 453 | logger.info("*** Predict ***") 454 | predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict") 455 | 456 | max_predict_samples = ( 457 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 458 | ) 459 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 460 | 461 | trainer.log_metrics("predict", metrics) 462 | trainer.save_metrics("predict", metrics) 463 | 464 | output_predict_file = os.path.join(training_args.output_dir, 
"test_predictions.csv") 465 | report_predict_file = os.path.join(training_args.output_dir, "classification_report.txt") 466 | if trainer.is_world_process_zero(): 467 | with open(output_predict_file, "w") as writer: 468 | try: 469 | for index, pred_list in enumerate(predictions[0]): 470 | pred_line = '\t'.join([f'{pred:.5f}' for pred in pred_list]) 471 | writer.write(f"{index}\t{pred_line}\n") 472 | except: 473 | try: 474 | for index, pred_list in enumerate(predictions): 475 | pred_line = '\t'.join([f'{pred:.5f}' for pred in pred_list]) 476 | writer.write(f"{index}\t{pred_line}\n") 477 | except: 478 | pass 479 | 480 | # Clean up checkpoints 481 | checkpoints = [filepath for filepath in glob.glob(f'{training_args.output_dir}/*/') if '/checkpoint' in filepath] 482 | for checkpoint in checkpoints: 483 | shutil.rmtree(checkpoint) 484 | 485 | 486 | if __name__ == "__main__": 487 | main() 488 | --------------------------------------------------------------------------------