├── models
│   ├── __init__.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── efficiency_metrics_bert_models.py
│   │   ├── benchmark_original_models.py
│   │   └── efficiency_metrics_roberta.py
│   ├── big_bird
│   │   ├── __init__.py
│   │   ├── modeling_big_bird.py
│   │   └── tokenization_big_bird.py
│   ├── longformer
│   │   ├── __init__.py
│   │   ├── convert_roberta_to_lf.py
│   │   ├── convert_bert_to_lf.py
│   │   └── tokenization_longformer.py
│   └── hat
│       ├── __init__.py
│       ├── convert_roberta_to_htf.py
│       ├── convert_bert_to_htf.py
│       ├── configuration_hat.py
│       └── tokenization_hat.py
├── data
│   ├── PLMs
│   │   └── placeholder.txt
│   ├── __init__.py
│   ├── .DS_Store
│   ├── roberta
│   │   ├── .DS_Store
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer_config.json
│   │   └── config.json
│   ├── figures
│   │   ├── hat_encoder.png
│   │   └── hat_layouts.png
│   ├── hi-transformer
│   │   └── .DS_Store
│   ├── quality-dataset
│   │   ├── read_quality.py
│   │   └── quality-dataset.py
│   ├── hat
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer_config.json
│   │   └── config.json
│   ├── check-data.py
│   ├── check_training_arguments.py
│   ├── c4-dataset
│   │   ├── transform-data.py
│   │   └── c4-dataset.py
│   ├── mimic-dataset
│   │   └── mimic-dataset.py
│   ├── ecthr-arguments-dataset
│   │   ├── create_dataset.py
│   │   └── ecthr-arguments-dataset.py
│   └── contractnli-dataset
│       └── contractnli-dataset.py
├── requirements.txt
├── running_scripts
│   ├── fine_tuning
│   │   ├── train_sentence_order_prediction.sh
│   │   ├── train_masked_sentence_mcqa.sh
│   │   ├── train_document_classification.sh
│   │   ├── train_ecthr_classification.sh
│   │   ├── train_mimic_classification.sh
│   │   ├── train_contract_nli.sh
│   │   ├── train_ecthr_arguments_classification.sh
│   │   └── train_quality_mcqa.sh
│   └── pre_training
│       ├── train_longformer.sh
│       ├── train_longformer_base.sh
│       ├── train_hat.sh
│       └── train_hat_base.sh
├── LICENSE
├── evaluation
│   ├── multi_label_utils.py
│   ├── data_collator.py
│   └── run_sequential_sentence_classification.py
├── language_modelling
│   ├── prepare_dataset.py
│   ├── xla_spawn.py
│   ├── train_tokenizer.py
│   └── text_featurization.py
├── .gitignore
└── README.md
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/PLMs/placeholder.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | DATA_DIR = os.path.dirname(os.path.realpath(__file__))
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coastalcph/hierarchical-transformers/HEAD/data/.DS_Store
--------------------------------------------------------------------------------
/data/roberta/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coastalcph/hierarchical-transformers/HEAD/data/roberta/.DS_Store
--------------------------------------------------------------------------------
/data/figures/hat_encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coastalcph/hierarchical-transformers/HEAD/data/figures/hat_encoder.png
--------------------------------------------------------------------------------
/data/figures/hat_layouts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coastalcph/hierarchical-transformers/HEAD/data/figures/hat_layouts.png
--------------------------------------------------------------------------------
/data/hi-transformer/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coastalcph/hierarchical-transformers/HEAD/data/hi-transformer/.DS_Store
--------------------------------------------------------------------------------
/data/quality-dataset/read_quality.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 |
3 | datasets = load_dataset('../quality-dataset')
4 |
5 | print(datasets)
--------------------------------------------------------------------------------
/models/big_bird/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_big_bird import BigBirdModelForSentenceClassification
2 | from .tokenization_big_bird import BigbirdTokenizer
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.11.0
2 | transformers>=4.18.0
3 | datasets>=2.0.0
4 | tokenizers>=0.11.0
5 | scikit-learn>=1.0.0
6 | accelerate>=0.19.0
7 | tqdm>=4.62.0
8 | nltk>=3.7.0
--------------------------------------------------------------------------------
/data/hat/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"bos_token": "", "eos_token": "", "unk_token": "", "sep_token": "", "pad_token": "", "cls_token": "", "mask_token": ""}
--------------------------------------------------------------------------------
/data/roberta/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"bos_token": "", "eos_token": "", "unk_token": "", "sep_token": "", "pad_token": "", "cls_token": "", "mask_token": ""}
--------------------------------------------------------------------------------
/data/roberta/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"model_max_length": 512, "bos_token": "", "eos_token": "", "unk_token": "", "sep_token": "", "pad_token": "", "cls_token": "", "mask_token": "", "tokenizer_class": "PreTrainedTokenizerFast"}
--------------------------------------------------------------------------------
/models/longformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_longformer import LongformerModelForSentenceClassification, LongformerModelForPreTraining, \
2 | LongformerModelForSequenceClassification, LongformerForMaskedLM, LongformerForMultipleChoice
3 | from .tokenization_longformer import LongformerTokenizer
4 |
--------------------------------------------------------------------------------
/data/check-data.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 |
3 | dataset = load_dataset('data/wikipedia-dataset', '20200501.en', data_dir='data/wikipedia-dataset')
4 |
5 | print(f'Train subset: {len(dataset["train"])}')
6 | print(f'Validation subset: {len(dataset["validation"])}')
7 | print(f'Test subset: {len(dataset["test"])}')
8 |
--------------------------------------------------------------------------------
/data/hat/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"model_max_length": 4096, "bos_token": "", "eos_token": "", "unk_token": "", "sep_token": "", "pad_token": "", "cls_token": "", "mask_token": "", "tokenizer_class": "PreTrainedTokenizerFast", "auto_map": {"AutoTokenizer": ["tokenization_hat.HATTokenizer", "tokenization_hat.HATTokenizer"]}}
--------------------------------------------------------------------------------
/models/hat/__init__.py:
--------------------------------------------------------------------------------
1 | from .modelling_hat import HATForMaskedLM, HATForSequenceClassification, \
2 | HATModelForBoWPreTraining, HATModelForVICRegPreTraining, HATModelForSimCLRPreTraining, \
3 | HATModelForSequentialSentenceClassification, HATForMultipleChoice
4 | from .tokenization_hat import HATTokenizer
5 | from .configuration_hat import HATConfig
6 |
--------------------------------------------------------------------------------
/data/check_training_arguments.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 |
4 |
5 | def check_args():
6 | parser = argparse.ArgumentParser()
7 |
8 | # Required arguments
9 | parser.add_argument('--path', default=None)
10 | config = parser.parse_args()
11 |
12 | training_args = torch.load(config.path)
13 | print(training_args)
14 |
15 |
16 | if __name__ == '__main__':
17 | check_args()
18 |
--------------------------------------------------------------------------------
/data/roberta/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "RobertaForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "bos_token_id": 1,
7 | "eos_token_id": 2,
8 | "hidden_act": "gelu",
9 | "hidden_dropout_prob": 0.1,
10 | "hidden_size": 768,
11 | "initializer_range": 0.02,
12 | "intermediate_size": 3072,
13 | "layer_norm_eps": 1e-05,
14 | "max_position_embeddings": 512,
15 | "model_type": "roberta",
16 | "num_attention_heads": 12,
17 | "num_hidden_layers": 12,
18 | "pad_token_id": 0,
19 | "type_vocab_size": 1,
20 | "vocab_size": 50000
21 | }
--------------------------------------------------------------------------------
/data/c4-dataset/transform-data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | import tqdm
5 | from datasets import load_dataset
6 |
7 |
8 | def create_eval_streamable_wiki():
9 | for subset, size in zip(['train', 'validation'], [60000, 20000]):
10 | dataset = load_dataset('c4', 'en', split=subset, streaming=True)
11 | shuffled_dataset = dataset.shuffle(seed=42, buffer_size=10_000)
12 | valid_samples = []
13 | count = 0
14 | for sample in tqdm.tqdm(iter(shuffled_dataset)):
15 | if 4000 > len(sample['text'].split()) > 1000:
16 | valid_samples.append(sample['text'])
17 | count += 1
18 | if count >= 2000000:
19 | break
20 |
21 | print(f'TOTAL SAMPLES MEET CRITERIA: {len(valid_samples)}')
22 | samples = random.sample(valid_samples, k=size)
23 |
24 | with open(f'evaluation_{subset}.jsonl', 'w') as file:
25 | for sample in samples:
26 | file.write(json.dumps({'text': sample})+'\n')
27 |
28 |
29 | if __name__ == '__main__':
30 | create_eval_streamable_wiki()
31 |
--------------------------------------------------------------------------------
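The evaluation_{train,validation}.jsonl files written above can be read back with the generic JSON loader of the datasets library; a minimal sketch, assuming the files sit in the current working directory:

from datasets import load_dataset

# Each line is a {'text': ...} record written by create_eval_streamable_wiki().
eval_sets = load_dataset(
    'json',
    data_files={'train': 'evaluation_train.jsonl', 'validation': 'evaluation_validation.jsonl'},
)
print(eval_sets)
print(len(eval_sets['train'][0]['text'].split()), 'words in the first training document')
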
/running_scripts/fine_tuning/train_sentence_order_prediction.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-eval"
2 | export PYTHONPATH=.
3 |
4 | MODEL_NAME='hi-transformer-p1-roberta-mlm'
5 | MODEL_MAX_LENGTH=4096
6 | MAX_SENTENCES=32
7 |
8 | python evaluation/run_sentence_order.py \
9 | --model_name_or_path data/PLMs/${MODEL_NAME} \
10 | --dataset_name ./data/wikipedia-dataset \
11 | --dataset_config_name eval.en \
12 | --do_train \
13 | --do_eval \
14 | --do_predict \
15 | --output_dir data/PLMs/${MODEL_NAME}-sop \
16 | --overwrite_output_dir \
17 | --evaluation_strategy epoch \
18 | --save_strategy epoch \
19 | --num_train_epochs 20 \
20 | --load_best_model_at_end \
21 | --metric_for_best_model accuracy_score \
22 | --greater_is_better True \
23 | --save_total_limit 5 \
24 | --learning_rate 1e-5 \
25 | --per_device_train_batch_size 32 \
26 | --per_device_eval_batch_size 32 \
27 | --lr_scheduler_type linear \
28 | --warmup_ratio 0.05 \
29 | --max_seq_length ${MODEL_MAX_LENGTH} \
30 | --max_sentences ${MAX_SENTENCES} \
31 | --pad_to_max_length
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ilias Chalkidis
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/running_scripts/fine_tuning/train_masked_sentence_mcqa.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-eval"
2 | export PYTHONPATH=.
3 |
4 | MODEL_NAME='hat-p1-roberta-mlm'
5 | MODEL_MAX_LENGTH=4096
6 | MAX_SENTENCES=32
7 |
8 | python evaluation/run_masked_sentence_prediction.py \
9 | --model_name_or_path data/PLMs/${MODEL_NAME} \
10 | --dataset_name ./data/wikipedia-dataset \
11 | --dataset_config_name eval.en \
12 | --do_train \
13 | --do_eval \
14 | --do_predict \
15 | --output_dir data/PLMs/${MODEL_NAME}-mcqa-sbert \
16 | --overwrite_output_dir \
17 | --evaluation_strategy epoch \
18 | --save_strategy epoch \
19 | --num_train_epochs 20 \
20 | --load_best_model_at_end \
21 | --metric_for_best_model accuracy_score \
22 | --greater_is_better True \
23 | --save_total_limit 5 \
24 | --learning_rate 1e-5 \
25 | --per_device_train_batch_size 32 \
26 | --per_device_eval_batch_size 32 \
27 | --lr_scheduler_type linear \
28 | --warmup_ratio 0.05 \
29 | --max_seq_length ${MODEL_MAX_LENGTH} \
30 | --max_sentences ${MAX_SENTENCES} \
31 | --pad_to_max_length \
32 | --sentence_bert_path all-MiniLM-L6-v2
--------------------------------------------------------------------------------
/running_scripts/fine_tuning/train_document_classification.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-eval"
2 | export PYTHONPATH=.
3 |
4 | MODEL_NAME='hat-s1-grouped-mlm'
5 | POOLING_METHOD='max'
6 | MODEL_MAX_LENGTH=1024
7 | MAX_SENTENCES=8
8 |
9 | python evaluation/run_document_classification.py \
10 | --model_name_or_path data/PLMs/${MODEL_NAME} \
11 | --pooling ${POOLING_METHOD} \
12 | --dataset_name multi_eurlex \
13 | --dataset_config_name en \
14 | --do_train \
15 | --do_eval \
16 | --do_predict \
17 | --output_dir data/PLMs/${MODEL_NAME}-${POOLING_METHOD}-dc \
18 | --overwrite_output_dir \
19 | --evaluation_strategy epoch \
20 | --save_strategy epoch \
21 | --num_train_epochs 20 \
22 | --load_best_model_at_end \
23 | --metric_for_best_model micro_f1 \
24 | --greater_is_better True \
25 | --save_total_limit 5 \
26 | --learning_rate 1e-5 \
27 | --per_device_train_batch_size 32 \
28 | --per_device_eval_batch_size 32 \
29 | --lr_scheduler_type linear \
30 | --warmup_ratio 0.05 \
31 | --max_seq_length ${MODEL_MAX_LENGTH} \
32 | --max_sentences ${MAX_SENTENCES} \
33 | --pad_to_max_length
--------------------------------------------------------------------------------
/running_scripts/fine_tuning/train_ecthr_classification.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-eval"
2 | export PYTHONPATH=.
3 |
4 | MODEL_NAME='allenai/longformer-base-4096'
5 | POOLING_METHOD='max'
6 | MODEL_MAX_LENGTH=4096
7 | MAX_SENTENCES=32
8 |
9 | python evaluation/run_document_classification.py \
10 | --model_name_or_path ${MODEL_NAME} \
11 | --pooling ${POOLING_METHOD} \
12 | --dataset_name lex_glue \
13 | --dataset_config_name ecthr_b \
14 | --do_train \
15 | --do_eval \
16 | --do_predict \
17 | --output_dir data/PLMs/${MODEL_NAME}-${POOLING_METHOD}-ecthr \
18 | --overwrite_output_dir \
19 | --evaluation_strategy epoch \
20 | --save_strategy epoch \
21 | --num_train_epochs 20 \
22 | --load_best_model_at_end \
23 | --metric_for_best_model micro_f1 \
24 | --greater_is_better True \
25 | --save_total_limit 5 \
26 | --learning_rate 1e-5 \
27 | --per_device_train_batch_size 32 \
28 | --per_device_eval_batch_size 32 \
29 | --lr_scheduler_type linear \
30 | --warmup_ratio 0.05 \
31 | --max_seq_length ${MODEL_MAX_LENGTH} \
32 | --max_sentences ${MAX_SENTENCES} \
33 | --pad_to_max_length
--------------------------------------------------------------------------------
/running_scripts/fine_tuning/train_mimic_classification.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-eval"
2 | export PYTHONPATH=.
3 |
4 | MODEL_NAME='allenai/longformer-base-4096'
5 | POOLING_METHOD='max'
6 | MODEL_MAX_LENGTH=4096
7 | MAX_SENTENCES=32
8 |
9 | python evaluation/run_document_classification.py \
10 | --model_name_or_path ${MODEL_NAME} \
11 | --pooling ${POOLING_METHOD} \
12 | --dataset_name data/mimic-dataset \
13 | --dataset_config_name mimic \
14 | --do_train \
15 | --do_eval \
16 | --do_predict \
17 | --output_dir data/PLMs/${MODEL_NAME}-${POOLING_METHOD}-mimic \
18 | --overwrite_output_dir \
19 | --evaluation_strategy epoch \
20 | --save_strategy epoch \
21 | --num_train_epochs 20 \
22 | --load_best_model_at_end \
23 | --metric_for_best_model micro_f1 \
24 | --greater_is_better True \
25 | --save_total_limit 5 \
26 | --learning_rate 1e-5 \
27 | --per_device_train_batch_size 8 \
28 | --per_device_eval_batch_size 8 \
29 | --lr_scheduler_type linear \
30 | --warmup_ratio 0.05 \
31 | --max_seq_length ${MODEL_MAX_LENGTH} \
32 | --max_sentences ${MAX_SENTENCES} \
33 | --pad_to_max_length
--------------------------------------------------------------------------------
/running_scripts/fine_tuning/train_contract_nli.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-eval"
2 | export PYTHONPATH=.
3 |
4 | MODEL_NAME='allenai/longformer-base-4096'
5 | POOLING_METHOD='last'
6 | MODEL_MAX_LENGTH=4096
7 | MAX_SENTENCES=32
8 |
9 | python evaluation/run_document_nli.py \
10 | --model_name_or_path ${MODEL_NAME} \
11 | --pooling ${POOLING_METHOD} \
12 | --dataset_name data/contractnli-dataset \
13 | --dataset_config_name contractnli \
14 | --do_train \
15 | --do_eval \
16 | --do_predict \
17 | --output_dir data/PLMs/${MODEL_NAME}-${POOLING_METHOD}-contractnli \
18 | --overwrite_output_dir \
19 | --evaluation_strategy epoch \
20 | --save_strategy epoch \
21 | --num_train_epochs 20 \
22 | --load_best_model_at_end \
23 | --metric_for_best_model micro_f1 \
24 | --greater_is_better True \
25 | --save_total_limit 5 \
26 | --learning_rate 1e-5 \
27 | --per_device_train_batch_size 32 \
28 | --per_device_eval_batch_size 32 \
29 | --lr_scheduler_type linear \
30 | --warmup_ratio 0.05 \
31 | --max_seq_length ${MODEL_MAX_LENGTH} \
32 | --max_sentences ${MAX_SENTENCES} \
33 | --pad_to_max_length
--------------------------------------------------------------------------------
/running_scripts/fine_tuning/train_ecthr_arguments_classification.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-eval"
2 | export PYTHONPATH=.
3 | export CUDA_VISIBLE_DEVICES=1,2,3,4
4 |
5 | MODEL_NAME='google/bigbird-roberta-base'
6 | MODEL_MAX_LENGTH=4096
7 | MAX_SENTENCES=32
8 |
9 | python evaluation/run_sequential_sentence_classification.py \
10 | --dataset_name ../data/ecthr-arguments-dataset \
11 | --dataset_config_name ecthr-arguments-dataset \
12 | --model_name_or_path ${MODEL_NAME} \
13 | --do_train \
14 | --do_eval \
15 | --do_predict \
16 | --output_dir data/PLMs/${MODEL_NAME}-ecthr-args \
17 | --overwrite_output_dir \
18 | --evaluation_strategy epoch \
19 | --save_strategy epoch \
20 | --num_train_epochs 20 \
21 | --load_best_model_at_end \
22 | --metric_for_best_model micro-f1 \
23 | --greater_is_better True \
24 | --save_total_limit 5 \
25 | --learning_rate 3e-5 \
26 | --per_device_train_batch_size 2 \
27 | --per_device_eval_batch_size 2 \
28 | --lr_scheduler_type linear \
29 | --warmup_ratio 0.05 \
30 | --max_seq_length ${MODEL_MAX_LENGTH} \
31 | --max_sentences ${MAX_SENTENCES} \
32 | --pad_to_max_length
33 |
--------------------------------------------------------------------------------
/running_scripts/pre_training/train_longformer.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-pretrain"
2 | export PYTHONPATH=.
3 |
4 | MODEL_MAX_LENGTH=1024
5 | MAX_SENTENCES=8
6 |
7 | python models/longformer/convert_bert_to_lf.py --max_sentences ${MAX_SENTENCES}
8 |
9 | python language_modelling/run_mlm_stream.py \
10 | --model_name_or_path data/PLMs/longformer \
11 | --dataset_name ./data/wikipedia-dataset \
12 | --dataset_config_name 20200501.en \
13 | --do_train \
14 | --do_eval \
15 | --output_dir data/PLMs/longformer-global-mlm \
16 | --overwrite_output_dir \
17 | --logging_steps 500 \
18 | --evaluation_strategy steps \
19 | --eval_steps 10000 \
20 | --save_strategy steps \
21 | --save_steps 10000 \
22 | --save_total_limit 5 \
23 | --max_steps 50000 \
24 | --learning_rate 1e-4 \
25 | --per_device_train_batch_size 32 \
26 | --per_device_eval_batch_size 32 \
27 | --gradient_accumulation_steps 4 \
28 | --eval_accumulation_steps 4 \
29 | --lr_scheduler_type linear \
30 | --warmup_ratio 0.10 \
31 | --weight_decay 0.01 \
32 | --mlm_probability 0.15 \
33 | --max_seq_length ${MODEL_MAX_LENGTH} \
34 | --line_by_line \
35 | --pad_to_max_length
--------------------------------------------------------------------------------
/running_scripts/fine_tuning/train_quality_mcqa.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-eval"
2 | export PYTHONPATH=.
3 |
4 | MODEL_NAME='allenai/longformer-base-4096'
5 | POOLING_METHOD='last'
6 | MODEL_MAX_LENGTH=4096
7 | MAX_SENTENCES=32
8 |
9 | python evaluation/run_quality_mcqa.py \
10 | --model_name_or_path ${MODEL_NAME} \
11 | --pooling ${POOLING_METHOD} \
12 | --dataset_name data/quality-dataset \
13 | --dataset_config_name quality \
14 | --do_train \
15 | --do_eval \
16 | --do_predict \
17 | --output_dir data/PLMs/${MODEL_NAME}-${POOLING_METHOD}-quality \
18 | --overwrite_output_dir \
19 | --evaluation_strategy epoch \
20 | --save_strategy epoch \
21 | --num_train_epochs 20 \
22 | --load_best_model_at_end \
23 | --metric_for_best_model accuracy_score \
24 | --greater_is_better True \
25 | --save_total_limit 5 \
26 | --learning_rate 1e-5 \
27 | --per_device_train_batch_size 2 \
28 | --per_device_eval_batch_size 2 \
29 | --gradient_accumulation_steps 4 \
30 | --eval_accumulation_steps 4 \
31 | --lr_scheduler_type linear \
32 | --warmup_ratio 0.05 \
33 | --max_seq_length ${MODEL_MAX_LENGTH} \
34 | --max_sentences ${MAX_SENTENCES} \
35 | --pad_to_max_length
--------------------------------------------------------------------------------
/running_scripts/pre_training/train_longformer_base.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-pretrain"
2 | export PYTHONPATH=.
3 |
4 | MODEL_MAX_LENGTH=4096
5 | MAX_SENTENCES=32
6 |
7 | python models/longformer/convert_roberta_to_lf.py --max_sentences ${MAX_SENTENCES} --num_hidden_layers 12
8 |
9 | python language_modelling/run_mlm_stream.py \
10 | --model_name_or_path data/PLMs/longformer-roberta \
11 | --dataset_name c4 \
12 | --dataset_config_name en \
13 | --do_train \
14 | --do_eval \
15 | --output_dir data/PLMs/longformer-roberta-mlm \
16 | --overwrite_output_dir \
17 | --logging_steps 500 \
18 | --evaluation_strategy steps \
19 | --eval_steps 10000 \
20 | --save_strategy steps \
21 | --save_steps 10000 \
22 | --save_total_limit 5 \
23 | --max_steps 50000 \
24 | --learning_rate 1e-4 \
25 | --per_device_train_batch_size 8 \
26 | --per_device_eval_batch_size 8 \
27 | --gradient_accumulation_steps 16 \
28 | --eval_accumulation_steps 16 \
29 | --lr_scheduler_type linear \
30 | --warmup_ratio 0.10 \
31 | --weight_decay 0.01 \
32 | --mlm_probability 0.15 \
33 | --max_seq_length ${MODEL_MAX_LENGTH} \
34 | --max_sentences ${MAX_SENTENCES} \
35 | --min_sequence_length 1024 \
36 | --pad_to_max_length \
37 | --max_eval_samples 100000
--------------------------------------------------------------------------------
/evaluation/multi_label_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import expit
3 |
4 |
5 | def fix_multi_label_scores(original_predictions, original_label_ids, unpad_sequences=False, flatten_sequences=False):
6 |
7 | if flatten_sequences:
8 | predictions = original_predictions.reshape((-1, original_predictions.shape[-1]))
9 | label_ids = original_label_ids.reshape((-1, original_predictions.shape[-1]))
10 | else:
11 | predictions = original_predictions
12 | label_ids = original_label_ids
13 |
14 | if unpad_sequences:
15 | predictions = np.asarray([pred for pred, label in zip(predictions, label_ids) if label[0] != -1])
16 | label_ids = np.asarray([label for label in label_ids if label[0] != -1])
17 |
18 | # Fix gold labels
19 | y_true = np.zeros((len(label_ids), len(label_ids[0]) + 1), dtype=np.int32)
20 | y_true[:, :-1] = label_ids
21 | y_true[:, -1] = (np.sum(label_ids, axis=1) == 0).astype('int32')
22 | # Fix predictions
23 | logits = predictions[0] if isinstance(predictions, tuple) else predictions
24 | preds = (expit(logits) > 0.5).astype('int32')
25 | y_pred = np.zeros((len(label_ids), len(label_ids[0]) + 1), dtype=np.int32)
26 | y_pred[:, :-1] = preds
27 | y_pred[:, -1] = (np.sum(preds, axis=1) == 0).astype('int32')
28 |
29 | return y_true, y_pred
30 |
--------------------------------------------------------------------------------
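A minimal sketch of scoring the (y_true, y_pred) pair returned by fix_multi_label_scores with scikit-learn (already listed in requirements.txt); the toy logits and labels are illustrative only, and the repository root is assumed to be on PYTHONPATH, as in the running scripts:

import numpy as np
from sklearn.metrics import f1_score

from evaluation.multi_label_utils import fix_multi_label_scores

# Toy logits and multi-hot labels for two documents and three classes.
logits = np.array([[2.0, -1.0, 0.5], [-2.0, -1.5, -0.5]])
labels = np.array([[1, 0, 1], [0, 0, 0]])

# fix_multi_label_scores appends an extra "no label" column to both arrays.
y_true, y_pred = fix_multi_label_scores(logits, labels)
print('micro-F1:', f1_score(y_true, y_pred, average='micro', zero_division=0))
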
/running_scripts/pre_training/train_hat.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-pretrain"
2 | export XRT_TPU_CONFIG="localservice;0;localhost:51011"
3 | export PYTHONPATH=.
4 |
5 | LAYOUT='s1'
6 | MODEL_WARMUP_STRATEGY='grouped'
7 | MODEL_MAX_LENGTH=1024
8 | MAX_SENTENCES=8
9 |
10 | python3 models/hat/convert_bert_to_htf.py --layout ${LAYOUT} --max_sentences ${MAX_SENTENCES}
11 |
12 | python3 language_modelling/xla_spawn.py --num_cores=8 language_modelling/run_mlm_stream.py \
13 | --model_name_or_path data/PLMs/hi-transformer-${LAYOUT}-${MODEL_WARMUP_STRATEGY} \
14 | --dataset_name c4 \
15 | --dataset_config_name en \
16 | --do_train \
17 | --do_eval \
18 | --output_dir data/PLMs/hi-transformer-${LAYOUT}-${MODEL_WARMUP_STRATEGY}-mlm \
19 | --overwrite_output_dir \
20 | --logging_steps 500 \
21 | --evaluation_strategy steps \
22 | --eval_steps 10000 \
23 | --save_strategy steps \
24 | --save_steps 10000 \
25 | --save_total_limit 5 \
26 | --max_steps 50000 \
27 | --learning_rate 1e-4 \
28 | --per_device_train_batch_size 4 \
29 | --per_device_eval_batch_size 4 \
30 | --gradient_accumulation_steps 4 \
31 | --eval_accumulation_steps 4 \
32 | --lr_scheduler_type linear \
33 | --warmup_ratio 0.10 \
34 | --weight_decay 0.01 \
35 | --mlm_probability 0.15 \
36 | --max_seq_length ${MODEL_MAX_LENGTH} \
37 | --line_by_line \
38 | --pad_to_max_length
--------------------------------------------------------------------------------
/running_scripts/pre_training/train_hat_base.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="HATs-pretrain"
2 | export XRT_TPU_CONFIG="localservice;0;localhost:51011"
3 | export PYTHONPATH=.
4 |
5 | LAYOUT='p1'
6 | MODEL_MAX_LENGTH=4096
7 | MAX_SENTENCES=32
8 |
9 | python3 models/hat/convert_roberta_to_htf.py --layout ${LAYOUT} --max_sentences ${MAX_SENTENCES}
10 |
11 | python3 language_modelling/xla_spawn.py --num_cores=8 language_modelling/run_mlm_stream.py \
12 | --model_name_or_path data/PLMs/hat-${LAYOUT}-roberta \
13 | --dataset_name c4 \
14 | --dataset_config_name en \
15 | --do_train \
16 | --do_eval \
17 | --output_dir data/PLMs/hi-transformer-${LAYOUT}-roberta-mlm \
18 | --overwrite_output_dir \
19 | --logging_steps 500 \
20 | --evaluation_strategy steps \
21 | --eval_steps 10000 \
22 | --save_strategy steps \
23 | --save_steps 10000 \
24 | --save_total_limit 5 \
25 | --max_steps 50000 \
26 | --learning_rate 1e-4 \
27 | --per_device_train_batch_size 4 \
28 | --per_device_eval_batch_size 4 \
29 | --gradient_accumulation_steps 4 \
30 | --eval_accumulation_steps 4 \
31 | --lr_scheduler_type linear \
32 | --warmup_ratio 0.10 \
33 | --weight_decay 0.01 \
34 | --mlm_probability 0.15 \
35 | --max_seq_length ${MODEL_MAX_LENGTH} \
36 | --max_sentences ${MAX_SENTENCES} \
37 | --min_sequence_length 1024 \
38 | --pad_to_max_length \
39 | --max_eval_samples 100000
--------------------------------------------------------------------------------
/language_modelling/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tqdm
3 | import matplotlib.pyplot as plt
4 |
5 | from datasets import load_dataset
6 | import argparse
7 |
8 |
9 | def main():
10 | """ set default hyperparams in default_hyperparams.py """
11 | parser = argparse.ArgumentParser()
12 |
13 | # Required arguments
14 | parser.add_argument('--dataset_name', default='wikipedia')
15 | parser.add_argument('--dataset_config', default='20200501.en')
16 | config = parser.parse_args()
17 |
18 | # load datasets
19 | dataset = load_dataset(config.dataset_name, config.dataset_config, split='train')
20 |
21 | text_length = []
22 | for text in tqdm.tqdm(dataset['text']):
23 | text_length.append(len(text.split()))
24 |
25 | # reduce to truncated size, max 4096
26 | text_length = [x if x <= 4000 else 4096 for x in text_length]
27 |
28 | print(f'AVG: {np.mean(text_length):.1f} ± {np.std(text_length):.1f}, MAX: {np.max(text_length):.1f}')
29 |
30 | # print stats in percentiles
31 | for min_size in [512, 1024, 2048]:
32 | n_docs = len([1 for x in text_length if x >= min_size])
33 | perc = (n_docs * 100) / len(text_length)
34 | print(f'No of document over {min_size} words: {n_docs}/{len(text_length)} ({perc:.1f}%)')
35 |
36 | # plot document length histogram
37 | plt.hist(text_length, range=(500, 4096), bins=50)
38 | plt.savefig(f'{config.dataset_name}_hist.png')
39 |
40 |
41 | if __name__ == '__main__':
42 | main()
43 |
--------------------------------------------------------------------------------
/evaluation/data_collator.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase
3 |
4 | @dataclass
5 | class DataCollatorForMultiLabelClassification:
6 | tokenizer: PreTrainedTokenizerBase
7 |
8 | def __call__(self, features):
9 | import torch
10 | first = features[0]
11 | batch = {}
12 |
13 | # Special handling for labels.
14 | # Ensure that tensor is created with the correct type
15 | # (it should be automatically the case, but let's make sure of it.)
16 | if "label_ids" in first and first["label_ids"] is not None:
17 |             if isinstance(first["label_ids"], torch.Tensor):
18 | batch["labels"] = torch.stack([f["label_ids"] for f in features])
19 | else:
20 | dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
21 | batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
22 |
23 | # Handling of all other possible keys.
24 | # Again, we will use the first element to figure out which key/values are not None for this model.
25 | for k, v in first.items():
26 | if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
27 | if isinstance(v, torch.Tensor):
28 | batch[k] = torch.stack([f[k] for f in features])
29 | else:
30 | batch[k] = torch.tensor([f[k] for f in features])
31 |
32 | return batch
33 |
--------------------------------------------------------------------------------
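A minimal sketch of driving the collator above from a PyTorch DataLoader; the checkpoint name, texts, and multi-hot label vectors are placeholders:

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from evaluation.data_collator import DataCollatorForMultiLabelClassification

tokenizer = AutoTokenizer.from_pretrained('roberta-base')
texts = ['first document', 'second document']
label_ids = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]  # one multi-hot vector per document

features = []
for text, labels in zip(texts, label_ids):
    encoding = tokenizer(text, truncation=True, padding='max_length', max_length=32)
    features.append({**encoding, 'label_ids': labels})

loader = DataLoader(features, batch_size=2,
                    collate_fn=DataCollatorForMultiLabelClassification(tokenizer))
batch = next(iter(loader))
print(batch['input_ids'].shape, batch['labels'].dtype)  # torch.Size([2, 32]) torch.float32
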
/data/hat/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "HATForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "bos_token_id": 1,
7 | "do_sample": false,
8 | "eos_token_id": 2,
9 | "finetuning_task": null,
10 | "hidden_act": "gelu",
11 | "hidden_dropout_prob": 0.1,
12 | "hidden_size": 256,
13 | "id2label": {
14 | "0": "LABEL_0",
15 | "1": "LABEL_1"
16 | },
17 | "initializer_range": 0.02,
18 | "intermediate_size": 512,
19 | "is_decoder": false,
20 | "label2id": {
21 | "LABEL_0": 0,
22 | "LABEL_1": 1
23 | },
24 | "layer_norm_eps": 1e-12,
25 | "length_penalty": 1.0,
26 | "max_sentence_length": 128,
27 | "max_sentences": 64,
28 | "model_max_length": 8192,
29 | "max_position_embeddings": 128,
30 | "model_type": "hierarchical-transformer",
31 | "num_attention_heads": 8,
32 | "num_beams": 1,
33 | "num_hidden_layers": 6,
34 | "encoder_layout": {
35 | "0": {"sentence_encoder": true, "document_encoder": false},
36 | "1": {"sentence_encoder": true, "document_encoder": false},
37 | "2": {"sentence_encoder": true, "document_encoder": true},
38 | "3": {"sentence_encoder": true, "document_encoder": false},
39 | "4": {"sentence_encoder": true, "document_encoder": false},
40 | "5": {"sentence_encoder": true, "document_encoder": true}},
41 | "num_labels": 2,
42 | "num_return_sequences": 1,
43 | "output_attentions": false,
44 | "output_hidden_states": false,
45 | "output_past": true,
46 | "pad_token_id": 0,
47 | "pruned_heads": {},
48 | "repetition_penalty": 1.0,
49 | "temperature": 1.0,
50 | "top_k": 50,
51 | "top_p": 1.0,
52 | "torchscript": false,
53 | "type_vocab_size": 2,
54 | "use_bfloat16": false,
55 | "vocab_size": 50000,
56 | "parameters": 136350720
57 | }
--------------------------------------------------------------------------------
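The encoder_layout above controls, per layer, whether a document-level encoder runs on top of the sentence-level encoder; a minimal sketch that prints the layout using only the standard library (the path assumes the repository root as working directory):

import json

with open('data/hat/config.json') as f:
    config = json.load(f)

# Sort layer indices numerically and report which blocks each layer contains.
for idx, layer in sorted(config['encoder_layout'].items(), key=lambda kv: int(kv[0])):
    kind = 'sentence + document' if layer['document_encoder'] else 'sentence only'
    print(f'layer {idx}: {kind}')
# In the layout above, layers 2 and 5 also update document-level representations.
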
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # PyCharm files
132 | .idea
133 |
134 | # Log files
135 | logs/
136 |
--------------------------------------------------------------------------------
/language_modelling/xla_spawn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | A simple launcher script for TPU training
16 |
17 | Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py
18 |
19 | ::
20 | >>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE
21 | YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
22 | arguments of your training script)
23 |
24 | """
25 |
26 |
27 | import importlib
28 | import sys
29 | from argparse import REMAINDER, ArgumentParser
30 | from pathlib import Path
31 |
32 | import torch_xla.distributed.xla_multiprocessing as xmp
33 |
34 |
35 | def parse_args():
36 | """
37 | Helper function parsing the command line options
38 | @retval ArgumentParser
39 | """
40 | parser = ArgumentParser(
41 | description=(
42 | "PyTorch TPU distributed training launch "
43 | "helper utility that will spawn up "
44 | "multiple distributed processes"
45 | )
46 | )
47 |
48 | # Optional arguments for the launch helper
49 | parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).")
50 |
51 | # positional
52 | parser.add_argument(
53 | "training_script",
54 | type=str,
55 | help=(
56 | "The full path to the single TPU training "
57 | "program/script to be launched in parallel, "
58 | "followed by all the arguments for the "
59 | "training script"
60 | ),
61 | )
62 |
63 | # rest from the training program
64 | parser.add_argument("training_script_args", nargs=REMAINDER)
65 |
66 | return parser.parse_args()
67 |
68 |
69 | def main():
70 | args = parse_args()
71 |
72 | # Import training_script as a module.
73 | script_fpath = Path(args.training_script)
74 | sys.path.append(str(script_fpath.parent.resolve()))
75 | mod_name = script_fpath.stem
76 | mod = importlib.import_module(mod_name)
77 |
78 | # Patch sys.argv
79 | sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
80 |
81 | xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
82 |
83 |
84 | if __name__ == "__main__":
85 | main()
--------------------------------------------------------------------------------
/language_modelling/train_tokenizer.py:
--------------------------------------------------------------------------------
1 | from tokenizers import models, normalizers, pre_tokenizers, decoders, processors, trainers
2 | from tokenizers import Tokenizer
3 | from datasets import load_dataset
4 | from transformers import PreTrainedTokenizerFast, AutoTokenizer
5 | import argparse
6 |
7 | CUSTOM_TOK_FOLDER = '../data/custom-tokenizer'
8 | hat_FOLDER = '../data/hi-transformer'
9 | ROBERTA_FOLDER = '../data/roberta'
10 |
11 |
12 | def batch_iterator(dataset):
13 | for example in dataset['text']:
14 | yield example
15 |
16 |
17 | def main():
18 | """ set default hyperparams in default_hyperparams.py """
19 | parser = argparse.ArgumentParser()
20 |
21 | # Required arguments
22 |     parser.add_argument('--vocab_size', type=int, default=50000)
23 | config = parser.parse_args()
24 |
25 | # configure tokenizer
26 |     backend_tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
27 | backend_tokenizer.normalizer = normalizers.Lowercase()
28 | backend_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
29 | backend_tokenizer.decoder = decoders.ByteLevel()
30 |     backend_tokenizer.post_processor = processors.RobertaProcessing(sep=("</s>", 2), cls=("<s>", 1),
31 |                                                                     add_prefix_space=True, trim_offsets=True)
32 |
33 | trainer = trainers.BpeTrainer(
34 | vocab_size=config.vocab_size,
35 | min_frequency=2,
36 |         special_tokens=["<pad>", "<s>", "</s>", "<unk>", "<mask>"],
37 | show_progress=True
38 | )
39 |
40 | # load datasets
41 | dataset = load_dataset("multi_eurlex", "en", split='train')
42 |
43 | # train tokenizer
44 | backend_tokenizer.train_from_iterator(trainer=trainer, iterator=batch_iterator(dataset))
45 |
46 | # test tokenizer
47 | tokens = backend_tokenizer.encode('dog ' * 5, add_special_tokens=False)
48 | print('Original Tokenizer: ', tokens.tokens)
49 |
50 | # save tokenizer
51 | new_roberta_tokenizer = PreTrainedTokenizerFast(
52 | tokenizer_object=backend_tokenizer,
53 | model_max_length=512,
54 | # padding_side="Set me if you want",
55 | # truncation_side="Set me if you want",
56 | # model_input_names="Set me if you want",
57 |         bos_token='<s>',
58 |         eos_token='</s>',
59 |         unk_token='<unk>',
60 |         sep_token='</s>',
61 |         pad_token='<pad>',
62 |         cls_token='<s>',
63 |         mask_token='<mask>',
64 | )
65 |
66 | new_hat_tokenizer = PreTrainedTokenizerFast(
67 | tokenizer_object=backend_tokenizer,
68 | model_max_length=8192,
69 | # padding_side="Set me if you want",
70 | # truncation_side="Set me if you want",
71 | # model_input_names="Set me if you want",
72 |         bos_token='<s>',
73 |         eos_token='</s>',
74 |         unk_token='<unk>',
75 |         sep_token='</s>',
76 |         pad_token='<pad>',
77 |         cls_token='<s>',
78 |         mask_token='<mask>',
79 | )
80 |
81 | new_roberta_tokenizer.save_pretrained(CUSTOM_TOK_FOLDER)
82 | new_roberta_tokenizer.save_pretrained(ROBERTA_FOLDER)
83 | new_hat_tokenizer.save_pretrained(hat_FOLDER)
84 |
85 | # re-load tokenizer and test
86 | reloaded_tokenizer = AutoTokenizer.from_pretrained(CUSTOM_TOK_FOLDER)
87 | tokens = reloaded_tokenizer.tokenize('dog ' * 5, add_special_tokens=False)
88 | print('Reloaded Tokenizer: ', tokens)
89 |
90 |
91 | if __name__ == '__main__':
92 | main()
93 |
--------------------------------------------------------------------------------
/language_modelling/text_featurization.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import random
3 |
4 | import nltk
5 | import tqdm
6 | from sklearn.feature_extraction.text import TfidfVectorizer
7 | from sklearn.decomposition import PCA
8 | import warnings
9 | warnings.filterwarnings("ignore")
10 |
11 |
12 | def train_text_featurizer(documents, tokenizer_path='google/bert_uncased_L-6_H-256_A-4', hidden_units=768):
13 |
14 | def tokenize(document: str):
15 | return tokenizer.tokenize(document,
16 | padding=False,
17 | truncation=True,
18 | max_length=1024)
19 |
20 | # init tokenizer
21 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, model_max_length=1024)
22 |
23 | # init tfidf vectorizer
24 | vocab = [(key, value) for (key, value) in tokenizer.vocab.items()]
25 | vocab = sorted(vocab, key=lambda tup: tup[1])
26 | vocab = [key for (key, value) in vocab]
27 | tfidf_vectorizer = TfidfVectorizer(lowercase=False, preprocessor=None, tokenizer=tokenize,
28 | vocabulary=vocab)
29 | pca_solver = PCA(n_components=hidden_units)
30 |
31 | tfidf_scores = tfidf_vectorizer.fit_transform(documents)
32 | print('TFIDF-VECTORIZER DONE!')
33 |
34 | pca_solver.fit(tfidf_scores.toarray())
35 | print('PCA SOLVER DONE!')
36 |
37 | with open('./data/wikipedia-dataset/tifidf_vectorizer.pkl', 'wb') as fin:
38 | pickle.dump(tfidf_vectorizer, fin)
39 | print('TFIDF-VECTORIZER SAVED!')
40 |
41 | with open('./data/wikipedia-dataset/pca_solver.pkl', 'wb') as fin:
42 | pickle.dump(pca_solver, fin)
43 | print('PCA SOLVER SAVED!')
44 |
45 |
46 | def learn_idfs(documents, tokenizer_path='google/bert_uncased_L-6_H-256_A-4'):
47 |
48 | def tokenize(document: str):
49 | return tokenizer.tokenize(document)
50 |
51 | # init tokenizer
52 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, model_max_length=1024)
53 |
54 | # init tfidf vectorizer
55 | vocab = [(key, value) for (key, value) in tokenizer.vocab.items()]
56 | vocab = sorted(vocab, key=lambda tup: tup[1])
57 | vocab = [key for (key, value) in vocab]
58 | tfidf_vectorizer = TfidfVectorizer(lowercase=False, preprocessor=None, tokenizer=tokenize,
59 | vocabulary=vocab)
60 |
61 | tfidf_vectorizer.fit(documents)
62 |
63 | with open('./data/wikipedia-dataset/idf_scores.pkl', 'wb') as file:
64 | pickle.dump(tfidf_vectorizer.idf_, file)
65 |
66 |
67 | def embed_sentences(documents, model_path='all-MiniLM-L6-v2'):
68 | from sentence_transformers import SentenceTransformer
69 |
70 | # Define the model
71 | model = SentenceTransformer(model_path)
72 |
73 | # Start the multi-process pool on all available CUDA devices
74 | pool = model.start_multi_process_pool()
75 |
76 | # Sub-sample sentences
77 | grouped_sentences = []
78 | for document in documents:
79 | doc_sentences = nltk.sent_tokenize(' '.join(document.split()[:1024]))
80 | # Build grouped sentences up to 100 words
81 | temp_sentence = ''
82 | for doc_sentence in doc_sentences:
83 | if len(temp_sentence.split()) + len(doc_sentence.split()) <= 100:
84 | temp_sentence += ' ' + doc_sentence
85 | else:
86 | if len(temp_sentence):
87 | grouped_sentences.append(temp_sentence)
88 | temp_sentence = doc_sentence
89 | if len(temp_sentence):
90 | grouped_sentences.append(temp_sentence)
91 | del documents
92 |
93 | # Compute the embeddings using the multi-process pool
94 | sentence_embeddings = model.encode_multi_process(grouped_sentences, pool)
95 | print("Embeddings computed. Shape:", sentence_embeddings.shape)
96 |
97 | # Optional: Stop the proccesses in the pool
98 | model.stop_multi_process_pool(pool)
99 |
100 | with open('../data/wikipedia-dataset/sentence_embeddings.pkl', 'wb') as file:
101 | pickle.dump(sentence_embeddings, file)
102 |
103 | print('SENTENCE EMBEDDINGS DONE!')
104 |
105 |
106 | if __name__ == '__main__':
107 | from transformers import AutoTokenizer
108 | from datasets import load_dataset
109 |
110 | # load dataset
111 | dataset = load_dataset("lex_glue", "eurlex", split='train')
112 | dataset = dataset['text']
113 | subset = random.sample(range(len(dataset)), k=100)
114 | dataset_small = []
115 | for i in tqdm.tqdm(subset):
116 | dataset_small.append(dataset[i])
117 |
118 | # re-load tokenizer and test
119 | embed_sentences(documents=dataset_small)
120 |
121 |
--------------------------------------------------------------------------------
/data/mimic-dataset/mimic-dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """MIMIC-III."""
16 |
17 |
18 | import json
19 |
20 | import datasets
21 |
22 |
23 | _CITATION = """@article{johnson-mit-2016-mimic-iii,
24 | author={Johnson, Alistair E W and Pollard, Tom J and Shen, Lu and Li-Wei, H Lehman and Feng, Mengling and Ghassemi,
25 | Mohammad and Moody, Benjamin and Szolovits, Peter and Celi, Leo Anthony and Mark, Roger G},
26 | year={2016},
27 | title={{MIMIC-III, a freely accessible critical care database}},
28 | journal={Sci. Data},
29 | volume={3},
30 | url={https://www.nature.com/articles/sdata201635.pdf}
31 | }"""
32 |
33 | _DESCRIPTION = """MIMIC"""
34 |
35 | _HOMEPAGE = "https://physionet.org/content/mimiciii/1.4/"
36 |
37 | _LICENSE = "CC BY-SA (Creative Commons / Attribution-ShareAlike)"
38 |
39 | _LABEL_NAMES = ['001-139', '140-239', '240-279', '280-289', '290-319', '320-389', '390-459', '460-519',
40 | '520-579', '580-629', '630-679', '680-709', '710-739', '740-759', '760-779', '780-799',
41 | '800-999', 'V01-V91', 'E000-E999']
42 |
43 | class MIMIC(datasets.GeneratorBasedBuilder):
44 | """MIMIC"""
45 |
46 | VERSION = datasets.Version("1.1.0")
47 |
48 | BUILDER_CONFIGS = [
49 | datasets.BuilderConfig(
50 | name="mimic", version=VERSION, description="MIMIC"
51 | ),
52 | ]
53 |
54 | DEFAULT_CONFIG_NAME = "mimic"
55 |
56 | def _info(self):
57 | features = datasets.Features(
58 | {
59 | "summary_id": datasets.Value("string"),
60 | "text": datasets.Value("string"),
61 | "labels": datasets.features.Sequence(datasets.ClassLabel(names=_LABEL_NAMES)),
62 | }
63 | )
64 | return datasets.DatasetInfo(
65 | # This is the description that will appear on the datasets page.
66 | description=_DESCRIPTION,
67 | # This defines the different columns of the dataset and their types
68 | features=features, # Here we define them above because they are different between the two configurations
69 | # If there's a common (input, target) tuple from the features,
70 | # specify them here. They'll be used if as_supervised=True in
71 | # builder.as_dataset.
72 | supervised_keys=None,
73 | # Homepage of the dataset for documentation
74 | homepage=_HOMEPAGE,
75 | # License for the dataset if available
76 | license=_LICENSE,
77 | # Citation for the dataset
78 | citation=_CITATION,
79 | )
80 |
81 | def _split_generators(self, dl_manager):
82 | """Returns SplitGenerators."""
83 | data_dir = dl_manager.download_and_extract('mimic.jsonl')
84 | return [
85 | datasets.SplitGenerator(
86 | name=datasets.Split.TRAIN,
87 | # These kwargs will be passed to _generate_examples
88 | gen_kwargs={
89 | "filepath": data_dir,
90 | "split": "train",
91 | },
92 | ),
93 | datasets.SplitGenerator(
94 | name=datasets.Split.TEST,
95 | # These kwargs will be passed to _generate_examples
96 | gen_kwargs={"filepath": data_dir,
97 | "split": "test"},
98 | ),
99 | datasets.SplitGenerator(
100 | name=datasets.Split.VALIDATION,
101 | # These kwargs will be passed to _generate_examples
102 | gen_kwargs={
103 | "filepath": data_dir,
104 | "split": "dev",
105 | },
106 | ),
107 | ]
108 |
109 | def _generate_examples(
110 | self, filepath, split # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
111 | ):
112 | """Yields examples as (key, example) tuples."""
113 |
114 | with open(filepath, encoding="utf-8") as f:
115 | for id_, row in enumerate(f):
116 | data = json.loads(row)
117 | if data['data_type'] == split:
118 | yield id_, {
119 | "summary_id": data["summary_id"],
120 | "text": data['text'],
121 | "labels": data["level_1"],
122 | }
--------------------------------------------------------------------------------
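The fine-tuning script train_mimic_classification.sh loads this builder from the local directory; a minimal sketch, assuming the access-restricted mimic.jsonl file has been placed inside data/mimic-dataset/ and the repository root is the working directory:

from datasets import load_dataset

# The directory holds mimic-dataset.py (this builder) plus the mimic.jsonl source file.
dataset = load_dataset('data/mimic-dataset', 'mimic')
print(dataset)
print(dataset['train'][0]['labels'][:5])  # first few level-1 ICD-9 label ids
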
/data/c4-dataset/c4-dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """English Wikipedia dataset containing cleaned articles."""
18 |
19 | import json
20 | import os
21 |
22 | import datasets
23 |
24 |
25 | logger = datasets.logging.get_logger(__name__)
26 |
27 |
28 | _CITATION = """\
29 | @ONLINE {wikidump,
30 | author = {Wikimedia Foundation},
31 | title = {Wikimedia Downloads},
32 | url = {https://dumps.wikimedia.org}
33 | }
34 | """
35 |
36 | _DESCRIPTION = """\
37 | Wikipedia dataset containing cleaned articles.
38 | The datasets are built from the Wikipedia dump
39 | (https://dumps.wikimedia.org/). Each example
40 | contains the content of one full Wikipedia article with cleaning to strip
41 | markdown and unwanted sections (references, etc.).
42 | """
43 |
44 | _LICENSE = (
45 | "This work is licensed under the Creative Commons Attribution-ShareAlike "
46 | "3.0 Unported License. To view a copy of this license, visit "
47 | "http://creativecommons.org/licenses/by-sa/3.0/ or send a letter to "
48 | "Creative Commons, PO Box 1866, Mountain View, CA 94042, USA."
49 | )
50 |
51 | _VERSION = datasets.Version("1.0.0", "")
52 |
53 |
54 | class C4Config(datasets.BuilderConfig):
55 | """BuilderConfig for Wikipedia."""
56 |
57 | def __init__(self, dump=None, version=_VERSION, **kwargs):
58 | """BuilderConfig for Wikipedia.
59 |
60 | Args:
61 | language: string, the language code for the Wikipedia dump to use.
62 | date: string, date of the Wikipedia dump in YYYYMMDD format. A list of
63 | available dates can be found at https://dumps.wikimedia.org/enwiki/.
64 | **kwargs: keyword arguments forwarded to super.
65 | """
66 | super().__init__(
67 | name=f"{dump}",
68 | description=f"C4 dataset for {dump} dump.",
69 | version=version,
70 | **kwargs,
71 | )
72 | self.dump = dump
73 |
74 |
75 | _DATE = "20220301"
76 |
77 |
78 | class C4(datasets.GeneratorBasedBuilder):
79 | """Wikipedia dataset."""
80 |
81 | # Use mirror (your.org) to avoid download caps.
82 | BUILDER_CONFIG_CLASS = C4Config
83 | BUILDER_CONFIGS = [
84 | C4Config(
85 | dump="eval.en",
86 | ) # pylint:disable=g-complex-comprehension
87 | ]
88 |
89 | def _info(self):
90 | return datasets.DatasetInfo(
91 | description=_DESCRIPTION,
92 | features=datasets.Features(
93 | {
94 | "text": datasets.Value("string"),
95 | }
96 | ),
97 | # No default supervised_keys.
98 | supervised_keys=None,
99 | homepage="https://dumps.wikimedia.org",
100 | citation=_CITATION,
101 | )
102 |
103 | def _split_generators(self, dl_manager):
104 | data_dir = dl_manager.download(os.path.join('./', f'c4.{self.config.dump}.tar.gz'))
105 | return [
106 | datasets.SplitGenerator(
107 | name=datasets.Split.TRAIN,
108 | # These kwargs will be passed to _generate_examples
109 | gen_kwargs={"filepath": 'train.jsonl',
110 | "split": "train",
111 | "files": dl_manager.iter_archive(data_dir)},
112 | ),
113 | datasets.SplitGenerator(
114 | name=datasets.Split.TEST,
115 | # These kwargs will be passed to _generate_examples
116 | gen_kwargs={"filepath": 'test.jsonl',
117 | "split": "test",
118 | "files": dl_manager.iter_archive(data_dir)},
119 | ),
120 | datasets.SplitGenerator(
121 | name=datasets.Split.VALIDATION,
122 | # These kwargs will be passed to _generate_examples
123 | gen_kwargs={"filepath": 'dev.jsonl',
124 | "split": "validation",
125 | "files": dl_manager.iter_archive(data_dir)},
126 | ),
127 | ]
128 |
129 | def _generate_examples(self, filepath, split, files):
130 | """This function returns the examples in the raw (text) form."""
131 | for path, f in files:
132 | if path == filepath:
133 | for id_, row in enumerate(f):
134 | data = json.loads(row)
135 | yield id_, {
136 | "text": data['text'],
137 | }
138 |
--------------------------------------------------------------------------------
/data/quality-dataset/quality-dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """QuALITY: Question Answering with Long Input Texts, Yes!"""
16 |
17 |
18 | import json
19 | import os
20 |
21 | import datasets
22 |
23 |
24 | _CITATION = """@article{pang2021quality,
25 | title={{QuALITY}: Question Answering with Long Input Texts, Yes!},
26 | author={Pang, Richard Yuanzhe and Parrish, Alicia and Joshi, Nitish and Nangia, Nikita and Phang, Jason and Chen,
27 | Angelica and Padmakumar, Vishakh and Ma, Johnny and Thompson, Jana and He, He and Bowman, Samuel R.},
28 | journal={arXiv preprint arXiv:2112.08608},
29 | year={2021}
30 | }"""
31 |
32 | _DESCRIPTION = """QuALITY: Question Answering with Long Input Texts, Yes!"""
33 |
34 | _LICENSE = "CC BY-SA (Creative Commons / Attribution-ShareAlike)"
35 |
36 | _LABEL_NAMES = [f'choice_{i}' for i in range(5)]
37 |
38 |
39 | class QuALITY(datasets.GeneratorBasedBuilder):
40 | """QuALITY: Question Answering with Long Input Texts, Yes!"""
41 |
42 | VERSION = datasets.Version("1.1.0")
43 |
44 | BUILDER_CONFIGS = [
45 | datasets.BuilderConfig(
46 | name="quality", version=VERSION, description="QuALITY: Question Answering with Long Input Texts"
47 | ),
48 | ]
49 |
50 | DEFAULT_CONFIG_NAME = "quality"
51 |
52 | def _info(self):
53 | features = datasets.Features(
54 | {
55 | "article": datasets.Value("string"),
56 | "question": datasets.Value("string"),
57 | "options": datasets.Sequence(datasets.Value("string")),
58 | "label": datasets.ClassLabel(names=_LABEL_NAMES),
59 | }
60 | )
61 | return datasets.DatasetInfo(
62 | # This is the description that will appear on the datasets page.
63 | description=_DESCRIPTION,
64 | # This defines the different columns of the dataset and their types
65 | features=features,  # The features are defined above (this script has a single configuration)
66 | # If there's a common (input, target) tuple from the features,
67 | # specify them here. They'll be used if as_supervised=True in
68 | # builder.as_dataset.
69 | supervised_keys=None,
70 | # Homepage of the dataset for documentation
71 | homepage='https://github.com/nyu-mll/quality',
72 | # License for the dataset if available
73 | license=_LICENSE,
74 | # Citation for the dataset
75 | citation=_CITATION,
76 | )
77 |
78 | def _split_generators(self, dl_manager):
79 | """Returns SplitGenerators."""
80 | data_dir = dl_manager.download_and_extract('QuALITY.v1.0.zip')
81 | return [
82 | datasets.SplitGenerator(
83 | name=datasets.Split.TRAIN,
84 | # These kwargs will be passed to _generate_examples
85 | gen_kwargs={
86 | "filepath": os.path.join(data_dir, 'QuALITY.v1.0.htmlstripped.train'),
87 | "split": "train",
88 | },
89 | ),
90 | datasets.SplitGenerator(
91 | name=datasets.Split.TEST,
92 | # These kwargs will be passed to _generate_examples
93 | gen_kwargs={"filepath": os.path.join(data_dir, 'QuALITY.v1.0.htmlstripped.test'),
94 | "split": "test"},
95 | ),
96 | datasets.SplitGenerator(
97 | name=datasets.Split.VALIDATION,
98 | # These kwargs will be passed to _generate_examples
99 | gen_kwargs={
100 | "filepath": os.path.join(data_dir, 'QuALITY.v1.0.htmlstripped.dev'),
101 | "split": "dev",
102 | },
103 | ),
104 | ]
105 |
106 | def _generate_examples(
107 | self, filepath, split # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
108 | ):
109 | """Yields examples as (key, example) tuples."""
110 |
111 | with open(filepath, encoding="utf-8") as f:
112 | count = 0
113 | for id_, row in enumerate(f):
114 | data = json.loads(row)
115 | for question in data['questions']:
116 | count += 1
117 | yield count, {
118 | 'article': data['article'],
119 | 'question': question['question'],
120 | 'options': question['options'],
121 | 'label': f"choice_{question['gold_label']}"
122 | }
123 |
--------------------------------------------------------------------------------
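
Note: a toy illustration of the flattening performed in `_generate_examples` above; one QuALITY record with N questions becomes N (article, question, options, label) examples. The record content is made up.

import json

record = json.dumps({
    'article': 'Some long article ...',
    'questions': [
        {'question': 'Who?', 'options': ['A', 'B', 'C', 'D'], 'gold_label': 2},
        {'question': 'Why?', 'options': ['A', 'B', 'C', 'D'], 'gold_label': 4},
    ],
})

data = json.loads(record)
examples = [
    {'article': data['article'], 'question': q['question'],
     'options': q['options'], 'label': f"choice_{q['gold_label']}"}
    for q in data['questions']
]
print(len(examples), examples[0]['label'])  # 2 choice_2
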
/models/hat/convert_roberta_to_htf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import warnings
4 | from data import DATA_DIR
5 | from transformers import AutoModelForMaskedLM, AutoTokenizer
6 | from models.hat import HATForMaskedLM, HATConfig, HATTokenizer
7 | warnings.filterwarnings("ignore")
8 |
9 | LAYOUTS = {
10 | 'p1': 'S|S|SD|S|S|SD|S|S|SD|S|S|SD',
11 | 'l1': 'S|S|S|S|S|SD|S|SD|S|SD|S|SD',
12 | 'f12': 'S|S|S|S|S|S|S|S|S|S|S|SD|D|D|D'
13 | }
14 |
15 |
16 | def convert_roberta_to_htf():
17 | '''Convert a RoBERTa checkpoint into a HAT checkpoint.'''
18 | parser = argparse.ArgumentParser()
19 |
20 | # Required arguments
21 | parser.add_argument('--layout', default='p1', choices=['p1', 'l1', 'f12'],
22 | help='S|D encoders layout')
23 | parser.add_argument('--max_sentences', default=32)
24 | config = parser.parse_args()
25 | MAX_SENTENCE_LENGTH = 128
26 | MAX_SENTENCES = int(config.max_sentences)
27 | ENCODER_LAYOUT = {}
28 | for idx, block_pattern in enumerate(LAYOUTS[config.layout].split('|')):
29 | ENCODER_LAYOUT[str(idx)] = {"sentence_encoder": True if 'S' in block_pattern else False,
30 | "document_encoder": True if 'D' in block_pattern else False}
31 |
32 | NUM_HIDDEN_LAYERS = len(ENCODER_LAYOUT.keys())
33 | ROBERTA_CHECKPOINT = 'roberta-base'
34 |
35 | # load pre-trained roberta model and tokenizer
36 | roberta_model = AutoModelForMaskedLM.from_pretrained(ROBERTA_CHECKPOINT)
37 | tokenizer = AutoTokenizer.from_pretrained(ROBERTA_CHECKPOINT, model_max_length=MAX_SENTENCE_LENGTH * MAX_SENTENCES)
38 |
39 | # load dummy config and change specifications
40 | roberta_config = roberta_model.config
41 | htf_config = HATConfig.from_pretrained(f'{DATA_DIR}/hat')
42 | # Text length parameters
43 | htf_config.max_sentence_length = MAX_SENTENCE_LENGTH
44 | htf_config.max_sentences = MAX_SENTENCES
45 | htf_config.max_position_embeddings = MAX_SENTENCE_LENGTH + 2
46 | htf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES)
47 | htf_config.num_hidden_layers = NUM_HIDDEN_LAYERS
48 | # Transformer parameters
49 | htf_config.hidden_size = roberta_config.hidden_size
50 | htf_config.intermediate_size = roberta_config.intermediate_size
51 | htf_config.num_attention_heads = roberta_config.num_attention_heads
52 | htf_config.hidden_act = roberta_config.hidden_act
53 | htf_config.encoder_layout = ENCODER_LAYOUT
54 | # Vocabulary parameters
55 | htf_config.vocab_size = roberta_config.vocab_size
56 | htf_config.pad_token_id = roberta_config.pad_token_id
57 | htf_config.bos_token_id = roberta_config.bos_token_id
58 | htf_config.eos_token_id = roberta_config.eos_token_id
59 | htf_config.type_vocab_size = roberta_config.type_vocab_size
60 |
61 | # load dummy hi-transformer model
62 | htf_model = HATForMaskedLM.from_config(htf_config)
63 |
64 | # copy embeddings
65 | htf_model.hi_transformer.embeddings.position_embeddings.weight.data = roberta_model.roberta.embeddings.position_embeddings.weight[:MAX_SENTENCE_LENGTH+roberta_config.pad_token_id+1]
66 | htf_model.hi_transformer.embeddings.word_embeddings.load_state_dict(roberta_model.roberta.embeddings.word_embeddings.state_dict())
67 | htf_model.hi_transformer.embeddings.token_type_embeddings.load_state_dict(roberta_model.roberta.embeddings.token_type_embeddings.state_dict())
68 | htf_model.hi_transformer.embeddings.LayerNorm.load_state_dict(roberta_model.roberta.embeddings.LayerNorm.state_dict())
69 |
70 | # copy transformer layers
71 | for idx in range(min(NUM_HIDDEN_LAYERS, roberta_config.num_hidden_layers)):
72 | if htf_model.config.encoder_layout[str(idx)]['sentence_encoder']:
73 | htf_model.hi_transformer.encoder.layer[idx].sentence_encoder.load_state_dict(roberta_model.roberta.encoder.layer[idx].state_dict())
74 | if htf_model.config.encoder_layout[str(idx)]['document_encoder']:
75 | htf_model.hi_transformer.encoder.layer[idx].document_encoder.load_state_dict(roberta_model.roberta.encoder.layer[idx].state_dict())
76 | htf_model.hi_transformer.encoder.layer[idx].position_embeddings.weight.data = roberta_model.roberta.embeddings.position_embeddings.weight[1:MAX_SENTENCES+2]
77 |
78 | # copy lm_head
79 | htf_model.lm_head.load_state_dict(roberta_model.lm_head.state_dict())
80 |
81 | # save model
82 | htf_model.save_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-roberta')
83 |
84 | # save tokenizer
85 | tokenizer.save_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-roberta')
86 |
87 | # re-load model
88 | htf_model = HATForMaskedLM.from_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-roberta')
89 | htf_tokenizer = HATTokenizer.from_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-roberta')
90 | print(f'RoBERTa-based HAT model with layout {config.layout} is ready to run!')
91 |
92 | # input_ids = torch.randint(1, 30000, (2, 1024), dtype=torch.long)
93 | # input_ids[:, :: 128] = htf_tokenizer.cls_token_id
94 | # labels = input_ids.clone()
95 | # attention_mask = torch.ones((2, 1024), dtype=torch.int)
96 | # htf_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
97 | # roberta_model(input_ids=input_ids[:, :128], attention_mask=attention_mask[:, :128], labels=labels[:, :128])
98 |
99 |
100 | if __name__ == '__main__':
101 | convert_roberta_to_htf()
102 |
--------------------------------------------------------------------------------
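
Note: an illustration of how a layout string from LAYOUTS expands into the per-layer ENCODER_LAYOUT dictionary built above (shown here for the 'p1' layout).

layout = 'S|S|SD|S|S|SD|S|S|SD|S|S|SD'  # LAYOUTS['p1']
encoder_layout = {
    str(idx): {'sentence_encoder': 'S' in block, 'document_encoder': 'D' in block}
    for idx, block in enumerate(layout.split('|'))
}
print(len(encoder_layout))  # 12 hidden layers
print(encoder_layout['2'])  # {'sentence_encoder': True, 'document_encoder': True}
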
/data/ecthr-arguments-dataset/create_dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import glob
3 | import json
4 | import re
5 | from collections import Counter
6 | import pandas as pd
7 | from ast import literal_eval
8 | USED_LABELS = ['Subsumtion',
9 | 'Vorherige Rechtsprechung des EGMR',
10 | 'Verhältnismäßigkeitsprüfung – Angemessenheit', #Proportionality
11 | 'Entscheidung des EGMR',
12 | 'Verhältnismäßigkeitsprüfung – Rechtsgrundlage', #Proportionality
13 | 'Verhältnismäßigkeitsprüfung – Legitimer Zweck', #Proportionality
14 | 'Konsens der prozessualen Parteien']
15 |
16 | ECTHR_ARG_TYPES = ['Application',
17 | 'Precedent',
18 | 'Proportionality', #Proportionality
19 | 'Decision',
20 | 'Legal Basis', #Proportionality
21 | 'Legitimate Purpose', #Proportionality
22 | 'Non Contestation']
23 |
24 | all_labels = []
25 | article_paragraphs = []
26 | sample_ids = []
27 | article_paragraph_labels = []
28 |
29 |
30 | def fix_labels(labels):
31 | labels = list(set([re.sub('[BI]-', '', label) for label in labels]))
32 | if len(labels) >= 2 and 'O' in labels:
33 | labels.remove('O')
34 | labels = [ECTHR_ARG_TYPES[USED_LABELS.index(label)] for label in labels if label in USED_LABELS]
35 | if len(labels) == 0:
36 | labels = ['O']
37 | return labels
38 |
39 | short_chunks = []
40 |
41 | for subset in ['train', 'val', 'test']:
42 | for filename in glob.glob(f'../mining-legal-arguments/data/{subset}/argType/*.csv'):
43 | with open(filename) as file:
44 | df = pd.read_csv(file, sep='\t', encoding='utf-8')
45 | df['labels'] = df['labels'].map(lambda x: literal_eval(x))
46 | df['tokens'] = df['tokens'].map(lambda x: literal_eval(x))
47 | paragraphs = []
48 | paragraph_labels = []
49 | temp_paragraph = ''
50 | temp_labels = []
51 | for tokens, token_labels in zip(df['tokens'], df['labels']):
52 | paragraph = re.sub(r'([\(\-\[]) ', r'\1', re.sub(r' ([\.\)\,\:\;\'\-\]]|\'s)', r'\1', ' '.join(tokens)))
53 | labels = fix_labels(token_labels)
54 | if len(labels) > 1:
55 | print()
56 | if re.match('(FOR THESE REASONS, THE COURT|for these reasons, the court unanimously)', paragraph):
57 | break
58 | if not re.match('\d{2,}\.', paragraph) and not re.match('[IVX]+\. [A-Z]{2,}', paragraph):
59 | if len(paragraph.split()) <= 5 and re.match('(([A-Z]|\d)\.|\([a-zα-ω]+\))', paragraph):
60 | short_chunks.append(paragraph)
61 | continue
62 | temp_paragraph = temp_paragraph + '\n' + paragraph
63 | temp_labels.extend(labels)
64 | continue
65 | elif len(paragraphs) and len(temp_paragraph):
66 | paragraphs[-1] = paragraphs[-1] + '' + temp_paragraph
67 | paragraph_labels[-1] = list(set(copy.deepcopy(paragraph_labels[-1]) + copy.deepcopy(temp_labels)))
68 | if len(paragraph_labels[-1]) > 1 and 'O' in paragraph_labels[-1]:
69 | paragraph_labels[-1].remove('O')
70 | temp_paragraph = ''
71 | temp_labels = []
72 | if len(paragraph.split()) <= 10 and not re.match('\d{2,}\.', paragraph) and not re.match('[IVX]+\. [A-Z]{2,}', paragraph):
73 | continue
74 | if re.match('[IVX]+\. [A-Z]{2,}', paragraph) and len(paragraphs):
75 | article_paragraphs.append(copy.deepcopy(paragraphs))
76 | article_paragraph_labels.append(copy.deepcopy(paragraph_labels))
77 | sample_ids.append(filename.split('argType/')[1].replace('.csv', ''))
78 | paragraphs = []
79 | paragraph_labels = []
80 | paragraphs.append(paragraph)
81 | paragraph_labels.append(labels)
82 | article_paragraphs.append(copy.deepcopy(paragraphs))
83 | article_paragraph_labels.append(copy.deepcopy(paragraph_labels))
84 | sample_ids.append(filename.split('argType/')[1].replace('.csv', ''))
85 |
86 | article_paragraphs_clean = []
87 | article_paragraph_labels_clean = []
88 | for paragraphs, paragraph_labels in zip(article_paragraphs, article_paragraph_labels):
89 | if len(paragraphs) == len(paragraph_labels):
90 | article_paragraphs_clean.append(copy.deepcopy(paragraphs[:32]))
91 | article_paragraph_labels_clean.append(copy.deepcopy(paragraph_labels[:32]))
92 | all_labels.extend([label for label_group in copy.deepcopy(paragraph_labels[:32]) for label in label_group])
93 |
94 | label_counts = Counter(all_labels)
95 | n_paragraphs = []
96 | for paragraphs in article_paragraphs_clean:
97 | if len(paragraphs) <= 32:
98 | n_paragraphs.append(32)
99 | elif len(paragraphs) <= 64:
100 | n_paragraphs.append(64)
101 | elif len(paragraphs) <= 128:
102 | n_paragraphs.append(128)
103 | else:
104 | n_paragraphs.append('Long')
105 |
106 | par_counts = Counter(n_paragraphs)
107 | print(par_counts.most_common())
108 |
109 | count = 0
110 | with open('ecthr_arguments.jsonl', 'w') as file:
111 | for paragraphs, labels, sample_id in zip(article_paragraphs_clean, article_paragraph_labels_clean, sample_ids):
112 | count += 1
113 | if count <= 900:
114 | data_type = 'train'
115 | elif count <= 1000:
116 | data_type = 'dev'
117 | elif count <= 1100:
118 | data_type = 'test'
119 | else:
120 | break
121 | if labels is None:
122 | print()
123 | else:
124 | for paragraph_labels in labels:
125 | if paragraph_labels is None:
126 | print()
127 | file.write(json.dumps({'case_id': sample_id, 'paragraphs': paragraphs, 'labels': labels, 'data_type': data_type}) + '\n')
128 |
129 |
130 | label_counts = {'train': [], 'dev': [], 'test': []}
131 | with open('ecthr_arguments.jsonl') as file:
132 | for line in file:
133 | data = json.loads(line)
134 | label_counts[data['data_type']].extend([label for par_labels in data['labels'] for label in par_labels])
135 |
136 | for key in label_counts:
137 | print(Counter(label_counts[key]).most_common())
--------------------------------------------------------------------------------
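
Note: a trimmed, self-contained copy of the `fix_labels` normalisation above, using a reduced label map for brevity; it shows how BIO prefixes are stripped, 'O' is dropped when other labels co-occur, and German annotation names are mapped to English argument types.

import re

USED = ['Subsumtion', 'Vorherige Rechtsprechung des EGMR']
TYPES = ['Application', 'Precedent']

def fix_labels(labels):
    labels = list(set(re.sub('[BI]-', '', label) for label in labels))
    if len(labels) >= 2 and 'O' in labels:
        labels.remove('O')
    labels = [TYPES[USED.index(label)] for label in labels if label in USED]
    return labels if labels else ['O']

print(fix_labels(['B-Subsumtion', 'I-Subsumtion', 'O']))  # ['Application']
print(fix_labels(['O']))                                  # ['O']
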
/models/longformer/convert_roberta_to_lf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import copy
5 | import warnings
6 | from data import DATA_DIR
7 | from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig
8 | warnings.filterwarnings("ignore")
9 |
10 |
11 | def convert_roberta_to_lf():
12 | '''Convert a RoBERTa checkpoint into a Longformer checkpoint.'''
13 | parser = argparse.ArgumentParser()
14 |
15 | # Required arguments
16 | parser.add_argument('--max_sentences', default=32)
17 | parser.add_argument('--num_hidden_layers', default=12)
18 | config = parser.parse_args()
19 | MAX_SENTENCE_LENGTH = 128
20 | MAX_SENTENCES = int(config.max_sentences)
21 | NUM_HIDDEN_LAYERS = int(config.num_hidden_layers)
22 | ROBERTA_CHECKPOINT = 'roberta-base'
23 |
24 | # load pre-trained roberta model and tokenizer
25 | roberta_model = AutoModelForMaskedLM.from_pretrained(ROBERTA_CHECKPOINT)
26 | tokenizer = AutoTokenizer.from_pretrained(ROBERTA_CHECKPOINT, model_max_length=MAX_SENTENCE_LENGTH * MAX_SENTENCES)
27 |
28 | # load dummy config and change specifications
29 | roberta_config = roberta_model.config
30 | lf_config = AutoConfig.from_pretrained('allenai/longformer-base-4096')
31 | # Text length parameters
32 | lf_config.max_position_embeddings = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES) + roberta_config.pad_token_id + 2
33 | lf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES)
34 | lf_config.num_hidden_layers = NUM_HIDDEN_LAYERS
35 | # Transformer parameters
36 | lf_config.hidden_size = roberta_config.hidden_size
37 | lf_config.intermediate_size = roberta_config.intermediate_size
38 | lf_config.num_attention_heads = roberta_config.num_attention_heads
39 | lf_config.hidden_act = roberta_config.hidden_act
40 | lf_config.attention_window = [MAX_SENTENCE_LENGTH] * NUM_HIDDEN_LAYERS
41 | # Vocabulary parameters
42 | lf_config.vocab_size = roberta_config.vocab_size
43 | lf_config.pad_token_id = roberta_config.pad_token_id
44 | lf_config.bos_token_id = roberta_config.bos_token_id
45 | lf_config.eos_token_id = roberta_config.eos_token_id
46 | lf_config.cls_token_id = tokenizer.cls_token_id
47 | lf_config.sep_token_id = tokenizer.sep_token_id
48 | lf_config.type_vocab_size = roberta_config.type_vocab_size
49 |
50 | # load dummy longformer model
51 | lf_model = AutoModelForMaskedLM.from_config(lf_config)
52 |
53 | # copy embeddings
54 | k = 2
55 | step = roberta_config.max_position_embeddings - 2
56 | while k < lf_config.max_position_embeddings - 1:
57 | if k + step >= lf_config.max_position_embeddings:
58 | lf_model.longformer.embeddings.position_embeddings.weight.data[k:] = roberta_model.roberta.embeddings.position_embeddings.weight[2:(lf_config.max_position_embeddings + 2 - k)]
59 | else:
60 | lf_model.longformer.embeddings.position_embeddings.weight.data[k:(k + step)] = roberta_model.roberta.embeddings.position_embeddings.weight[2:]
61 | k += step
62 | lf_model.longformer.embeddings.word_embeddings.load_state_dict(roberta_model.roberta.embeddings.word_embeddings.state_dict())
63 | lf_model.longformer.embeddings.token_type_embeddings.load_state_dict(roberta_model.roberta.embeddings.token_type_embeddings.state_dict())
64 | lf_model.longformer.embeddings.LayerNorm.load_state_dict(roberta_model.roberta.embeddings.LayerNorm.state_dict())
65 |
66 | # copy transformer layers
67 | roberta_model.roberta.encoder.layer = roberta_model.roberta.encoder.layer[:NUM_HIDDEN_LAYERS]
68 | for i in range(len(roberta_model.roberta.encoder.layer)):
69 | # generic
70 | lf_model.longformer.encoder.layer[i].intermediate.dense = copy.deepcopy(
71 | roberta_model.roberta.encoder.layer[i].intermediate.dense)
72 | lf_model.longformer.encoder.layer[i].output.dense = copy.deepcopy(
73 | roberta_model.roberta.encoder.layer[i].output.dense)
74 | lf_model.longformer.encoder.layer[i].output.LayerNorm = copy.deepcopy(
75 | roberta_model.roberta.encoder.layer[i].output.LayerNorm)
76 | # attention output
77 | lf_model.longformer.encoder.layer[i].attention.output.dense = copy.deepcopy(
78 | roberta_model.roberta.encoder.layer[i].attention.output.dense)
79 | lf_model.longformer.encoder.layer[i].attention.output.LayerNorm = copy.deepcopy(
80 | roberta_model.roberta.encoder.layer[i].attention.output.LayerNorm)
81 | # local q,k,v
82 | lf_model.longformer.encoder.layer[i].attention.self.query = copy.deepcopy(
83 | roberta_model.roberta.encoder.layer[i].attention.self.query)
84 | lf_model.longformer.encoder.layer[i].attention.self.key = copy.deepcopy(
85 | roberta_model.roberta.encoder.layer[i].attention.self.key)
86 | lf_model.longformer.encoder.layer[i].attention.self.value = copy.deepcopy(
87 | roberta_model.roberta.encoder.layer[i].attention.self.value)
88 | # global q,k,v
89 | lf_model.longformer.encoder.layer[i].attention.self.query_global = copy.deepcopy(
90 | roberta_model.roberta.encoder.layer[i].attention.self.query)
91 | lf_model.longformer.encoder.layer[i].attention.self.key_global = copy.deepcopy(
92 | roberta_model.roberta.encoder.layer[i].attention.self.key)
93 | lf_model.longformer.encoder.layer[i].attention.self.value_global = copy.deepcopy(
94 | roberta_model.roberta.encoder.layer[i].attention.self.value)
95 |
96 | # copy lm_head
97 | lf_model.lm_head.load_state_dict(roberta_model.lm_head.state_dict())
98 |
99 | # check position ids
100 | # batch = tokenizer(['this is a dog', 'this is a cat'], return_tensors='pt')
101 | # lf_model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
102 |
103 | # save model
104 | lf_model.save_pretrained(f'{DATA_DIR}/PLMs/longformer-roberta-{NUM_HIDDEN_LAYERS}')
105 |
106 | # save tokenizer
107 | tokenizer.save_pretrained(f'{DATA_DIR}/PLMs/longformer-roberta-{NUM_HIDDEN_LAYERS}')
108 |
109 | # re-load model
110 | lf_model = AutoModelForMaskedLM.from_pretrained(f'{DATA_DIR}/PLMs/longformer-roberta-{NUM_HIDDEN_LAYERS}')
111 | lf_tokenizer = AutoTokenizer.from_pretrained(f'{DATA_DIR}/PLMs/longformer-roberta-{NUM_HIDDEN_LAYERS}')
112 | # batch = tokenizer(['this is a dog', 'this is a cat'], return_tensors='pt')
113 | # lf_model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
114 | print(f'RoBERTa-based Longformer model is ready to run!')
115 |
116 |
117 | if __name__ == '__main__':
118 | convert_roberta_to_lf()
119 |
--------------------------------------------------------------------------------
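
Note: a standalone sketch of the position-embedding tiling performed above, with plain random tensors standing in for the real weights and smaller, illustrative sizes: RoBERTa's 512 learned positions (after its 2-row offset) are copied block-by-block into a longer position table, with a partial copy for the final block.

import torch

roberta_pos = torch.randn(514, 768)  # RoBERTa: 512 positions + 2 offset rows
lf_pos = torch.zeros(1026, 768)      # target table: 1024 positions + 2 offset rows

k, step = 2, roberta_pos.size(0) - 2
while k < lf_pos.size(0) - 1:
    if k + step >= lf_pos.size(0):
        lf_pos[k:] = roberta_pos[2:(lf_pos.size(0) + 2 - k)]  # partial tail block
    else:
        lf_pos[k:(k + step)] = roberta_pos[2:]                # full 512-row block
    k += step

print(torch.equal(lf_pos[2:514], roberta_pos[2:]))    # True
print(torch.equal(lf_pos[514:], roberta_pos[2:514]))  # True
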
/models/longformer/convert_bert_to_lf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import copy
5 | import warnings
6 | from data import DATA_DIR
7 | from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig
8 | warnings.filterwarnings("ignore")
9 |
10 |
11 | def convert_bert_to_lf():
12 | '''Convert a BERT checkpoint into a Longformer checkpoint.'''
13 | parser = argparse.ArgumentParser()
14 |
15 | # Required arguments
16 | parser.add_argument('--max_sentences', default=8)
17 | config = parser.parse_args()
18 | MAX_SENTENCE_LENGTH = 128
19 | MAX_SENTENCES = int(config.max_sentences)
20 | NUM_HIDDEN_LAYERS = 6
21 | BERT_LAYERS = NUM_HIDDEN_LAYERS
22 | BERT_CHECKPOINT = f'google/bert_uncased_L-{str(BERT_LAYERS)}_H-256_A-4'
23 |
24 | # load pre-trained bert model and tokenizer
25 | bert_model = AutoModelForMaskedLM.from_pretrained(BERT_CHECKPOINT)
26 | tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT, model_max_length=MAX_SENTENCE_LENGTH * MAX_SENTENCES)
27 |
28 | # load dummy config and change specifications
29 | bert_config = bert_model.config
30 | lf_config = AutoConfig.from_pretrained('allenai/longformer-base-4096')
31 | # Text length parameters
32 | lf_config.max_position_embeddings = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES) + bert_config.pad_token_id + 2
33 | lf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES)
34 | lf_config.num_hidden_layers = NUM_HIDDEN_LAYERS
35 | # Transformer parameters
36 | lf_config.hidden_size = bert_config.hidden_size
37 | lf_config.intermediate_size = bert_config.intermediate_size
38 | lf_config.num_attention_heads = bert_config.num_attention_heads
39 | lf_config.hidden_act = bert_config.hidden_act
40 | lf_config.attention_window = [MAX_SENTENCE_LENGTH] * NUM_HIDDEN_LAYERS
41 | # Vocabulary parameters
42 | lf_config.vocab_size = bert_config.vocab_size
43 | lf_config.pad_token_id = bert_config.pad_token_id
44 | lf_config.bos_token_id = bert_config.bos_token_id
45 | lf_config.eos_token_id = bert_config.eos_token_id
46 | lf_config.cls_token_id = tokenizer.cls_token_id
47 | lf_config.sep_token_id = tokenizer.sep_token_id
48 | lf_config.type_vocab_size = bert_config.type_vocab_size
49 |
50 | # load dummy longformer model
51 | lf_model = AutoModelForMaskedLM.from_config(lf_config)
52 |
53 | # copy embeddings
54 | lf_model.longformer.embeddings.position_embeddings.weight.data[0] = torch.zeros((bert_config.hidden_size,))
55 | k = 1
56 | step = bert_config.max_position_embeddings - 1
57 | while k < lf_config.max_position_embeddings - 1:
58 | if k + step >= lf_config.max_position_embeddings:
59 | lf_model.longformer.embeddings.position_embeddings.weight.data[k:] = bert_model.bert.embeddings.position_embeddings.weight[1:(lf_config.max_position_embeddings + 1 - k)]
60 | else:
61 | lf_model.longformer.embeddings.position_embeddings.weight.data[k:(k + step)] = bert_model.bert.embeddings.position_embeddings.weight[1:]
62 | k += step
63 | lf_model.longformer.embeddings.word_embeddings.load_state_dict(bert_model.bert.embeddings.word_embeddings.state_dict())
64 | lf_model.longformer.embeddings.token_type_embeddings.load_state_dict(bert_model.bert.embeddings.token_type_embeddings.state_dict())
65 | lf_model.longformer.embeddings.LayerNorm.load_state_dict(bert_model.bert.embeddings.LayerNorm.state_dict())
66 |
67 | # copy transformer layers
68 | for i in range(len(bert_model.bert.encoder.layer)):
69 | # generic
70 | lf_model.longformer.encoder.layer[i].intermediate.dense = copy.deepcopy(
71 | bert_model.bert.encoder.layer[i].intermediate.dense)
72 | lf_model.longformer.encoder.layer[i].output.dense = copy.deepcopy(
73 | bert_model.bert.encoder.layer[i].output.dense)
74 | lf_model.longformer.encoder.layer[i].output.LayerNorm = copy.deepcopy(
75 | bert_model.bert.encoder.layer[i].output.LayerNorm)
76 | # attention output
77 | lf_model.longformer.encoder.layer[i].attention.output.dense = copy.deepcopy(
78 | bert_model.bert.encoder.layer[i].attention.output.dense)
79 | lf_model.longformer.encoder.layer[i].attention.output.LayerNorm = copy.deepcopy(
80 | bert_model.bert.encoder.layer[i].attention.output.LayerNorm)
81 | # local q,k,v
82 | lf_model.longformer.encoder.layer[i].attention.self.query = copy.deepcopy(
83 | bert_model.bert.encoder.layer[i].attention.self.query)
84 | lf_model.longformer.encoder.layer[i].attention.self.key = copy.deepcopy(
85 | bert_model.bert.encoder.layer[i].attention.self.key)
86 | lf_model.longformer.encoder.layer[i].attention.self.value = copy.deepcopy(
87 | bert_model.bert.encoder.layer[i].attention.self.value)
88 | # global q,k,v
89 | lf_model.longformer.encoder.layer[i].attention.self.query_global = copy.deepcopy(
90 | bert_model.bert.encoder.layer[i].attention.self.query)
91 | lf_model.longformer.encoder.layer[i].attention.self.key_global = copy.deepcopy(
92 | bert_model.bert.encoder.layer[i].attention.self.key)
93 | lf_model.longformer.encoder.layer[i].attention.self.value_global = copy.deepcopy(
94 | bert_model.bert.encoder.layer[i].attention.self.value)
95 |
96 | # copy lm_head
97 | lf_model.lm_head.dense.load_state_dict(bert_model.cls.predictions.transform.dense.state_dict())
98 | lf_model.lm_head.layer_norm.load_state_dict(bert_model.cls.predictions.transform.LayerNorm.state_dict())
99 | lf_model.lm_head.decoder.load_state_dict(bert_model.cls.predictions.decoder.state_dict())
100 | lf_model.lm_head.bias = copy.deepcopy(bert_model.cls.predictions.bias)
101 |
102 | # check position ids
103 | # batch = tokenizer(['this is a dog', 'this is a cat'], return_tensors='pt')
104 | # lf_model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
105 |
106 | # save model
107 | lf_model.save_pretrained(f'{DATA_DIR}/PLMs/longformer')
108 |
109 | # save tokenizer
110 | tokenizer.save_pretrained(f'{DATA_DIR}/PLMs/longformer')
111 |
112 | # re-load model
113 | lf_model = AutoModelForMaskedLM.from_pretrained(f'{DATA_DIR}/PLMs/longformer')
114 | lf_tokenizer = AutoTokenizer.from_pretrained(f'{DATA_DIR}/PLMs/longformer')
115 | # batch = tokenizer(['this is a dog', 'this is a cat'], return_tensors='pt')
116 | # lf_model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
117 | print(f'Longformer model is ready to run!')
118 |
119 |
120 | if __name__ == '__main__':
121 | convert_bert_to_lf()
122 |
--------------------------------------------------------------------------------
/models/hat/convert_bert_to_htf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import copy
5 | import warnings
6 | from data import DATA_DIR
7 | from transformers import AutoModelForMaskedLM, AutoTokenizer
8 | from models.hat import HATForMaskedLM, HATConfig, HATTokenizer
9 | warnings.filterwarnings("ignore")
10 |
11 | LAYOUTS = {
12 | 's1': 'SD|SD|SD|SD|SD|SD',
13 | 's2': 'S|SD|D|S|SD|D|S|SD|D',
14 | 'p1': 'S|SD|S|SD|S|SD|S|SD',
15 | 'p2': 'S|S|SD|S|S|SD|S|S|SD',
16 | 'e1': 'SD|SD|SD|S|S|S|S|S|S',
17 | 'e2': 'S|SD|D|S|SD|D|S|S|S|S',
18 | 'l1': 'S|S|S|S|S|S|SD|SD|SD',
19 | 'l2': 'S|S|S|S|S|SD|D|S|SD|D',
20 | 'b1': 'S|S|SD|D|S|SD|D|S|S|S',
21 | 'b2': 'S|S|SD|SD|SD|S|S|S|S',
22 | 'f12': 'S|S|S|S|S|S|S|S|S|S|S|S',
23 | 'f8': 'S|S|S|S|S|S|S|S',
24 | 'f6': 'S|S|S|S|S|S',
25 | }
26 |
27 |
28 | def convert_bert_to_htf():
29 | '''Convert a BERT checkpoint into a HAT checkpoint.'''
30 | parser = argparse.ArgumentParser()
31 |
32 | # Required arguments
33 | parser.add_argument('--warmup_strategy', default='grouped', choices=['linear', 'grouped', 'random', 'embeds-only', 'none'],
34 | help='linear: S|D encoders are warm-started independently (one-by-one); '
35 | 'grouped: pairs of S|D are warm-started with weights from the very same level; '
36 | 'random: D encoders are not warm-started; '
37 | 'embeds-only: No warm-starting, except embeddings; '
38 | 'none: No warm-starting')
39 | parser.add_argument('--layout', default='s1', choices=['s1', 's2', 'p1', 'p2', 'e1', 'e2',
40 | 'l1', 'l2', 'b1', 'b2', 'f12', 'f8', 'f6'],
41 | help='S|D encoders layout')
42 | parser.add_argument('--max_sentences', default=8)
43 | config = parser.parse_args()
44 | MAX_SENTENCE_LENGTH = 128
45 | MAX_SENTENCES = int(config.max_sentences)
46 | ENCODER_LAYOUT = {}
47 | for idx, block_pattern in enumerate(LAYOUTS[config.layout].split('|')):
48 | ENCODER_LAYOUT[str(idx)] = {"sentence_encoder": True if 'S' in block_pattern else False,
49 | "document_encoder": True if 'D' in block_pattern else False}
50 |
51 | NUM_HIDDEN_LAYERS = len(ENCODER_LAYOUT.keys())
52 | BERT_LAYERS = NUM_HIDDEN_LAYERS if config.warmup_strategy != 'linear' else NUM_HIDDEN_LAYERS*2
53 | BERT_LAYERS = BERT_LAYERS + 1 if BERT_LAYERS % 2 else BERT_LAYERS
54 | BERT_CHECKPOINT = f'google/bert_uncased_L-{str(BERT_LAYERS)}_H-256_A-4'
55 |
56 | # load pre-trained bert model and tokenizer
57 | bert_model = AutoModelForMaskedLM.from_pretrained(BERT_CHECKPOINT)
58 | tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT, model_max_length=MAX_SENTENCE_LENGTH * MAX_SENTENCES)
59 |
60 | # load dummy config and change specifications
61 | bert_config = bert_model.config
62 | htf_config = HATConfig.from_pretrained(f'{DATA_DIR}/hat')
63 | # Text length parameters
64 | htf_config.max_sentence_length = MAX_SENTENCE_LENGTH
65 | htf_config.max_sentences = MAX_SENTENCES
66 | htf_config.max_position_embeddings = MAX_SENTENCE_LENGTH
67 | htf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES)
68 | htf_config.num_hidden_layers = NUM_HIDDEN_LAYERS
69 | # Transformer parameters
70 | htf_config.hidden_size = bert_config.hidden_size
71 | htf_config.intermediate_size = bert_config.intermediate_size
72 | htf_config.num_attention_heads = bert_config.num_attention_heads
73 | htf_config.hidden_act = bert_config.hidden_act
74 | htf_config.encoder_layout = ENCODER_LAYOUT
75 | # Vocabulary parameters
76 | htf_config.vocab_size = bert_config.vocab_size
77 | htf_config.pad_token_id = bert_config.pad_token_id
78 | htf_config.bos_token_id = bert_config.bos_token_id
79 | htf_config.eos_token_id = bert_config.eos_token_id
80 | htf_config.type_vocab_size = bert_config.type_vocab_size
81 |
82 | # load dummy hi-transformer model
83 | htf_model = HATForMaskedLM.from_config(htf_config)
84 |
85 | if config.warmup_strategy != 'none':
86 | # copy embeddings
87 | htf_model.hi_transformer.embeddings.position_embeddings.weight.data[0] = torch.zeros((bert_config.hidden_size,))
88 | htf_model.hi_transformer.embeddings.position_embeddings.weight.data[1:] = bert_model.bert.embeddings.position_embeddings.weight[1:MAX_SENTENCE_LENGTH+htf_config.pad_token_id+1]
89 | htf_model.hi_transformer.embeddings.word_embeddings.load_state_dict(bert_model.bert.embeddings.word_embeddings.state_dict())
90 | htf_model.hi_transformer.embeddings.token_type_embeddings.load_state_dict(bert_model.bert.embeddings.token_type_embeddings.state_dict())
91 | htf_model.hi_transformer.embeddings.LayerNorm.load_state_dict(bert_model.bert.embeddings.LayerNorm.state_dict())
92 |
93 | if config.warmup_strategy != 'embeds-only':
94 | # copy transformer layers
95 | if config.warmup_strategy != 'linear':
96 | for idx in range(NUM_HIDDEN_LAYERS):
97 | if htf_model.config.encoder_layout[str(idx)]['sentence_encoder']:
98 | htf_model.hi_transformer.encoder.layer[idx].sentence_encoder.load_state_dict(bert_model.bert.encoder.layer[idx].state_dict())
99 | if htf_model.config.encoder_layout[str(idx)]['document_encoder']:
100 | if config.warmup_strategy == 'grouped':
101 | htf_model.hi_transformer.encoder.layer[idx].document_encoder.load_state_dict(bert_model.bert.encoder.layer[idx].state_dict())
102 | htf_model.hi_transformer.encoder.layer[idx].position_embeddings.weight.data = bert_model.bert.embeddings.position_embeddings.weight[1:MAX_SENTENCES+2]
103 | else:
104 | for idx, l_idx in enumerate(range(0, NUM_HIDDEN_LAYERS*2, 2)):
105 | if htf_model.config.encoder_layout[str(idx)]['sentence_encoder']:
106 | htf_model.hi_transformer.encoder.layer[idx].sentence_encoder.load_state_dict(bert_model.bert.encoder.layer[l_idx].state_dict())
107 | if htf_model.config.encoder_layout[str(idx)]['document_encoder']:
108 | htf_model.hi_transformer.encoder.layer[idx].document_encoder.load_state_dict(bert_model.bert.encoder.layer[l_idx+1].state_dict())
109 | htf_model.hi_transformer.encoder.layer[idx].position_embeddings.weight.data = bert_model.bert.embeddings.position_embeddings.weight[1:MAX_SENTENCES+2]
110 |
111 | # copy lm_head
112 | htf_model.lm_head.dense.load_state_dict(bert_model.cls.predictions.transform.dense.state_dict())
113 | htf_model.lm_head.layer_norm.load_state_dict(bert_model.cls.predictions.transform.LayerNorm.state_dict())
114 | htf_model.lm_head.decoder.load_state_dict(bert_model.cls.predictions.decoder.state_dict())
115 | htf_model.lm_head.bias = copy.deepcopy(bert_model.cls.predictions.bias)
116 |
117 | # save model
118 | htf_model.save_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-{config.warmup_strategy}')
119 |
120 | # save tokenizer
121 | tokenizer.save_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-{config.warmup_strategy}')
122 |
123 | # re-load model
124 | htf_model = HATForMaskedLM.from_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-{config.warmup_strategy}')
125 | htf_tokenizer = HATTokenizer.from_pretrained(f'{DATA_DIR}/PLMs/hat-{config.layout}-{config.warmup_strategy}')
126 | print(f'HAT model with layout {config.layout} and warm-up strategy {config.warmup_strategy} is ready to run!')
127 |
128 |
129 | if __name__ == '__main__':
130 | convert_bert_to_htf()
131 |
--------------------------------------------------------------------------------
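
Note: an illustrative mapping (not used by the script) of how the 'grouped' and 'linear' warm-up strategies above assign BERT layers to the HAT sentence (S) and document (D) encoders, shown for the 's1' layout (SD|SD|SD|SD|SD|SD).

layout = ['SD'] * 6  # 's1'

# grouped: the S/D pair of HAT layer i is warm-started from the same BERT layer i
grouped = {i: {'S': i, 'D': i} for i in range(len(layout))}
# linear: HAT layer i takes two consecutive BERT layers, one for S and one for D
linear = {i: {'S': 2 * i, 'D': 2 * i + 1} for i in range(len(layout))}

print(grouped[0], grouped[5])  # {'S': 0, 'D': 0} {'S': 5, 'D': 5}
print(linear[0], linear[5])    # {'S': 0, 'D': 1} {'S': 10, 'D': 11}
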
/data/ecthr-arguments-dataset/ecthr-arguments-dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ECtHRArguments"""
16 |
17 | import json
18 | import os
19 | import textwrap
20 |
21 | import datasets
22 |
23 | MAIN_CITATION = """\
24 | @article{Habernal.et.al.2022.arg,
25 | author = {Habernal, Ivan and Faber, Daniel and Recchia, Nicola and
26 | Bretthauer, Sebastian and Gurevych, Iryna and
27 | Spiecker genannt Döhmann, Indra and Burchard, Christoph},
28 | title = {{Mining Legal Arguments in Court Decisions}},
29 | journal = {arXiv preprint},
30 | year = {2022},
31 | doi = {10.48550/arXiv.2208.06178},
32 | }"""
33 |
34 | _DESCRIPTION = """\
35 | The dataset contains approx. 300 cases from the European Court of Human Rights (ECtHR). For each case, the dataset provides
36 | a list of argumentative paragraphs from the case analysis. Spans in each paragraph have been labeled with one or more out
37 | of 13 argument types. We re-formulate this task as a sequential paragraph classification task, where each paragraph is
38 | labelled with one or more labels. The input of the model is the list of paragraphs of a case, and the output is the set
39 | of relevant argument types per paragraph.
40 | """
41 |
42 | ECTHR_ARG_TYPES = ['Application', 'Precedent', 'Proportionality', 'Decision',
43 | 'Legal Basis', 'Legitimate Purpose', 'Non Contestation']
44 |
45 |
46 | class ECtHRArgumentsConfig(datasets.BuilderConfig):
47 | """BuilderConfig for ECtHRArguments."""
48 |
49 | def __init__(
50 | self,
51 | text_column,
52 | label_column,
53 | url,
54 | data_url,
55 | data_file,
56 | citation,
57 | label_classes=None,
58 | multi_label=None,
59 | dev_column="dev",
60 | **kwargs,
61 | ):
62 | """BuilderConfig for ECtHRArguments.
63 |
64 | Args:
65 | text_column: `string`, name of the column in the jsonl file corresponding
66 | to the text
67 | label_column: `string`, name of the column in the jsonl file corresponding
68 | to the label
69 | url: `string`, url for the original project
70 | data_url: `string`, url to download the zip file from
71 | data_file: `string`, filename for data set
72 | citation: `string`, citation for the data set
73 | url: `string`, url for information about the data set
74 | label_classes: `list[string]`, the list of classes if the label is
75 | categorical. If not provided, then the label will be of type
76 | `datasets.Value('float32')`.
77 | multi_label: `boolean`, True if the task is multi-label
78 | dev_column: `string`, name for the development subset
79 | **kwargs: keyword arguments forwarded to super.
80 | """
81 | super(ECtHRArgumentsConfig, self).__init__(version=datasets.Version("1.0.0", ""), **kwargs)
82 | self.text_column = text_column
83 | self.label_column = label_column
84 | self.label_classes = label_classes
85 | self.multi_label = multi_label
86 | self.dev_column = dev_column
87 | self.url = url
88 | self.data_url = data_url
89 | self.data_file = data_file
90 | self.citation = citation
91 |
92 |
93 | class ECtHRArguments(datasets.GeneratorBasedBuilder):
94 | """ECtHR Arguments: argument type identification in ECtHR decisions, cast as sequential paragraph classification. Version 1.0"""
95 |
96 | BUILDER_CONFIGS = [
97 | ECtHRArgumentsConfig(
98 | name="ecthr-arguments-dataset",
99 | description=textwrap.dedent(
100 | """\
101 | The ECtHR Arguments dataset consists of paragraphs from ECtHR case analyses.
102 | Given the paragraphs of a case, the task is to predict the argument types of each paragraph.
103 | """
104 | ),
105 | text_column="paragraphs",
106 | label_column="labels",
107 | label_classes=ECTHR_ARG_TYPES,
108 | multi_label=True,
109 | dev_column="dev",
110 | data_url=f"ecthr_arguments.tar.gz",
111 | data_file="ecthr_arguments.jsonl",
112 | url="https://github.com/trusthlt/mining-legal-arguments",
113 | citation=textwrap.dedent(
114 | """@article{Habernal.et.al.2022.arg,
115 | author = {Habernal, Ivan and Faber, Daniel and Recchia, Nicola and
116 | Bretthauer, Sebastian and Gurevych, Iryna and
117 | Spiecker genannt Döhmann, Indra and Burchard, Christoph},
118 | title = {{Mining Legal Arguments in Court Decisions}},
119 | journal = {arXiv preprint},
120 | year = {2022},
121 | doi = {10.48550/arXiv.2208.06178},
122 | }"""
123 | ),
124 | )
125 | ]
126 |
127 | def _info(self):
128 | features = {"text": datasets.features.Sequence(datasets.Value("string")),
129 | "labels": datasets.features.Sequence(datasets.features.Sequence(datasets.ClassLabel(names=self.config.label_classes)))}
130 | return datasets.DatasetInfo(
131 | description=self.config.description,
132 | features=datasets.Features(features),
133 | homepage=self.config.url,
134 | citation=self.config.citation + "\n" + MAIN_CITATION,
135 | )
136 |
137 | def _split_generators(self, dl_manager):
138 | data_dir = dl_manager.download_and_extract(self.config.data_url)
139 | return [
140 | datasets.SplitGenerator(
141 | name=datasets.Split.TRAIN,
142 | # These kwargs will be passed to _generate_examples
143 | gen_kwargs={"filepath": os.path.join(data_dir, self.config.data_file), "split": "train"},
144 | ),
145 | datasets.SplitGenerator(
146 | name=datasets.Split.TEST,
147 | # These kwargs will be passed to _generate_examples
148 | gen_kwargs={"filepath": os.path.join(data_dir, self.config.data_file), "split": "test"},
149 | ),
150 | datasets.SplitGenerator(
151 | name=datasets.Split.VALIDATION,
152 | # These kwargs will be passed to _generate_examples
153 | gen_kwargs={
154 | "filepath": os.path.join(data_dir, self.config.data_file),
155 | "split": self.config.dev_column,
156 | },
157 | ),
158 | ]
159 |
160 | def _generate_examples(self, filepath, split):
161 | """This function returns the examples in the raw (text) form."""
162 | with open(filepath, "r", encoding="utf-8") as f:
163 | for id_, row in enumerate(f):
164 | data = json.loads(row)
165 | labels = [[label for label in set(par_labels) if label != 'O'] for
166 | par_labels in data[self.config.label_column]]
167 | if data["data_type"] == split:
168 | yield id_, {
169 | "text": data[self.config.text_column],
170 | "labels": labels,
171 | }
--------------------------------------------------------------------------------
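
Note: a minimal loading sketch for the ECtHR arguments builder above. It assumes a script-compatible `datasets` version and that the local `ecthr_arguments.tar.gz` archive (built with create_dataset.py) can be resolved by the download manager; the nested ClassLabel feature maps label ids back to argument-type names.

from datasets import load_dataset

dataset = load_dataset('data/ecthr-arguments-dataset/ecthr-arguments-dataset.py', split='train')

label_feature = dataset.features['labels'].feature.feature    # the inner ClassLabel
first = dataset[0]
print(first['text'][0][:80])                                   # first paragraph of the first case
print([label_feature.int2str(i) for i in first['labels'][0]])  # its argument types
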
/models/big_bird/modeling_big_bird.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch BigBird model."""
16 |
17 | from dataclasses import dataclass
18 | from typing import Optional, Tuple, Union
19 |
20 | import torch
21 | import torch.utils.checkpoint
22 | from torch import nn
23 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, CosineEmbeddingLoss
24 | from transformers.models.big_bird.modeling_big_bird import BigBirdPreTrainedModel, BigBirdModel
25 |
26 | from transformers.modeling_outputs import (
27 | ModelOutput
28 | )
29 |
30 | @dataclass
31 | class SentenceClassifierOutput(ModelOutput):
32 | """
33 | Base class for outputs of sentence classification models.
34 |
35 | Args:
36 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
37 | Classification loss.
38 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
39 | Classification scores (before SoftMax).
40 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
41 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
42 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
43 |
44 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
45 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
46 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
47 | sequence_length)`.
48 | sentence_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
49 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
50 | sequence_length)`.
51 |
52 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
53 | heads.
54 | """
55 |
56 | loss: Optional[Tuple[torch.FloatTensor]] = None
57 | logits: torch.FloatTensor = None
58 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
59 | attentions: Optional[Tuple[torch.FloatTensor]] = None
60 | sentence_attentions: Optional[Tuple[torch.FloatTensor]] = None
61 |
62 |
63 | class BigBirdSentencizer(nn.Module):
64 | def __init__(self, config):
65 | super().__init__()
66 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
67 | self.activation = nn.Tanh()
68 | self.max_sentence_length = config.max_sentence_length
69 |
70 | def forward(self, hidden_states):
71 | sentence_repr_hidden_states = hidden_states[:, ::self.max_sentence_length]
72 | sentence_outputs = self.dense(sentence_repr_hidden_states)
73 | sentence_outputs = self.activation(sentence_outputs)
74 | return sentence_outputs
75 |
76 |
77 | class BigBirdModelForSentenceClassification(BigBirdPreTrainedModel):
78 | _keys_to_ignore_on_load_missing = [r"position_ids"]
79 |
80 | def __init__(self, config):
81 | super().__init__(config)
82 | self.num_labels = config.num_labels
83 | self.config = config
84 |
85 | self.bert = BigBirdModel(config, add_pooling_layer=False)
86 | self.sentencizer = BigBirdSentencizer(config)
87 | classifier_dropout = (
88 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
89 | )
90 | self.dropout = nn.Dropout(classifier_dropout)
91 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
92 |
93 | # Initialize weights and apply final processing
94 | self.post_init()
95 |
96 | @classmethod
97 | def from_config(cls, config):
98 | return cls._from_config(config)
99 |
100 | def forward(
101 | self,
102 | input_ids=None,
103 | attention_mask=None,
104 | token_type_ids=None,
105 | position_ids=None,
106 | inputs_embeds=None,
107 | labels=None,
108 | output_attentions=None,
109 | output_hidden_states=None,
110 | return_dict=None,
111 | ):
112 | r"""
113 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
114 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
115 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
116 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
117 | """
118 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
119 |
120 | outputs = self.bert(
121 | input_ids,
122 | attention_mask=attention_mask,
123 | token_type_ids=token_type_ids,
124 | position_ids=position_ids,
125 | inputs_embeds=inputs_embeds,
126 | output_attentions=output_attentions,
127 | output_hidden_states=output_hidden_states,
128 | return_dict=return_dict,
129 | )
130 | sequence_output = outputs[0]
131 | sentence_outputs = self.sentencizer(sequence_output)
132 | sentence_outputs = self.dropout(sentence_outputs)
133 | logits = self.classifier(sentence_outputs)
134 |
135 | loss = None
136 | if labels is not None:
137 | if self.config.problem_type is None:
138 | if self.num_labels == 1:
139 | self.config.problem_type = "regression"
140 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
141 | self.config.problem_type = "single_label_classification"
142 | else:
143 | self.config.problem_type = "multi_label_classification"
144 |
145 | if self.config.problem_type == "regression":
146 | loss_fct = MSELoss()
147 | if self.num_labels == 1:
148 | loss = loss_fct(logits.view(-1, 1).squeeze(), labels.view(-1).squeeze())
149 | else:
150 | loss = loss_fct(logits.view(-1, 1), labels.view(-1))
151 | elif self.config.problem_type == "single_label_classification":
152 | loss_fct = CrossEntropyLoss()
153 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
154 | elif self.config.problem_type == "multi_label_classification":
155 | loss_fct = BCEWithLogitsLoss()
156 | mask = labels[:, :, 0] != -1
157 | loss = loss_fct(logits[mask], labels[mask])
158 |
159 | if not return_dict:
160 | output = (logits,) + outputs[2:]
161 | return ((loss,) + output) if loss is not None else output
162 |
163 | return SentenceClassifierOutput(
164 | loss=loss,
165 | logits=logits,
166 | hidden_states=outputs.hidden_states,
167 | attentions=outputs.attentions,
168 | sentence_attentions=None
169 | )
170 |
171 |
--------------------------------------------------------------------------------
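
Note: a small sketch of what `BigBirdSentencizer` does above: from a (batch, seq_len, hidden) tensor it keeps one hidden state every `max_sentence_length` tokens (the position of each sentence's special token), so the classifier scores one vector per sentence. Sizes are illustrative.

import torch

batch, max_sentences, max_sentence_length, hidden = 2, 8, 128, 16
hidden_states = torch.randn(batch, max_sentences * max_sentence_length, hidden)

sentence_repr = hidden_states[:, ::max_sentence_length]  # same slicing as in forward()
print(sentence_repr.shape)  # torch.Size([2, 8, 16])
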
/data/contractnli-dataset/contractnli-dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ContractNLI"""
16 |
17 | import json
18 | import os
19 | import textwrap
20 |
21 | import datasets
22 |
23 | MAIN_CITATION = """\
24 | @inproceedings{koreeda-manning-2021-contractnli-dataset,
25 | title = "{C}ontract{NLI}: A Dataset for Document-level Natural Language Inference for Contracts",
26 | author = "Koreeda, Yuta and
27 | Manning, Christopher",
28 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2021",
29 | month = nov,
30 | year = "2021",
31 | address = "Punta Cana, Dominican Republic",
32 | publisher = "Association for Computational Linguistics",
33 | url = "https://aclanthology.org/2021.findings-emnlp.164",
34 | doi = "10.18653/v1/2021.findings-emnlp.164",
35 | pages = "1907--1919",
36 | }"""
37 |
38 | _DESCRIPTION = """\
39 | The ContractNLI dataset consists of Non-Disclosure Agreements (NDAs). All NDAs have been labeled based
40 | on several hypothesis templates as entailment, neutral or contradiction. In this version of the task
41 | (Task B), the input consists of the full document.
42 | """
43 |
44 | LABELS = ["contradiction", "entailment", "neutral"]
45 |
46 |
47 | class ContractNLIConfig(datasets.BuilderConfig):
48 | """BuilderConfig for ContractNLI."""
49 |
50 | def __init__(
51 | self,
52 | text_column,
53 | label_column,
54 | url,
55 | data_url,
56 | data_file,
57 | citation,
58 | label_classes=None,
59 | multi_label=None,
60 | dev_column="dev",
61 | **kwargs,
62 | ):
63 | """BuilderConfig for ContractNLI.
64 |
65 | Args:
66 | text_column: `string`, name of the column in the jsonl file corresponding
67 | to the text
68 | label_column: `string`, name of the column in the jsonl file corresponding
69 | to the label
70 | url: `string`, url for the original project
71 | data_url: `string`, url to download the zip file from
72 | data_file: `string`, filename for data set
73 | citation: `string`, citation for the data set
74 | url: `string`, url for information about the data set
75 | label_classes: `list[string]`, the list of classes if the label is
76 | categorical. If not provided, then the label will be of type
77 | `datasets.Value('float32')`.
78 | multi_label: `boolean`, True if the task is multi-label
79 | dev_column: `string`, name for the development subset
80 | **kwargs: keyword arguments forwarded to super.
81 | """
82 | super(ContractNLIConfig, self).__init__(version=datasets.Version("1.0.0", ""), **kwargs)
83 | self.text_column = text_column
84 | self.label_column = label_column
85 | self.label_classes = label_classes
86 | self.multi_label = multi_label
87 | self.dev_column = dev_column
88 | self.url = url
89 | self.data_url = data_url
90 | self.data_file = data_file
91 | self.citation = citation
92 |
93 |
94 | class ContractNLI(datasets.GeneratorBasedBuilder):
95 | """ContractNLI: Document-level Natural Language Inference for Contracts (Task B, full-document input). Version 1.0"""
96 |
97 | BUILDER_CONFIGS = [
98 | ContractNLIConfig(
99 | name="contractnli",
100 | description=textwrap.dedent(
101 | """\
102 | The ContractNLI dataset consists of Non-Disclosure Agreements (NDAs). All NDAs have been labeled based
103 | on several hypothesis templates as entailment, neutral or contradiction. In this version of the task
104 | (Task B), the input consists of the full document.
105 | """
106 | ),
107 | text_column="premise",
108 | label_column="label",
109 | label_classes=LABELS,
110 | multi_label=False,
111 | dev_column="dev",
112 | data_url="contract_nli.zip",
113 | data_file="contract_nli_long.jsonl",
114 | url="https://stanfordnlp.github.io/contract-nli/",
115 | citation=textwrap.dedent(
116 | """\
117 | @inproceedings{koreeda-manning-2021-contractnli-dataset,
118 | title = "{C}ontract{NLI}: A Dataset for Document-level Natural Language Inference for Contracts",
119 | author = "Koreeda, Yuta and
120 | Manning, Christopher",
121 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2021",
122 | month = nov,
123 | year = "2021",
124 | address = "Punta Cana, Dominican Republic",
125 | publisher = "Association for Computational Linguistics",
126 | url = "https://aclanthology.org/2021.findings-emnlp.164",
127 | doi = "10.18653/v1/2021.findings-emnlp.164",
128 | pages = "1907--1919",
129 | }
130 | }"""
131 | ),
132 | )
133 | ]
134 |
135 | def _info(self):
136 | features = {"premise": datasets.Value("string"), "hypothesis": datasets.Value("string"),
137 | 'label': datasets.ClassLabel(names=LABELS)}
138 | return datasets.DatasetInfo(
139 | description=self.config.description,
140 | features=datasets.Features(features),
141 | homepage=self.config.url,
142 | citation=self.config.citation + "\n" + MAIN_CITATION,
143 | )
144 |
145 | def _split_generators(self, dl_manager):
146 | data_dir = dl_manager.download_and_extract(self.config.data_url)
147 | return [
148 | datasets.SplitGenerator(
149 | name=datasets.Split.TRAIN,
150 | # These kwargs will be passed to _generate_examples
151 | gen_kwargs={"filepath": os.path.join(data_dir, self.config.data_file), "split": "train"},
152 | ),
153 | datasets.SplitGenerator(
154 | name=datasets.Split.TEST,
155 | # These kwargs will be passed to _generate_examples
156 | gen_kwargs={"filepath": os.path.join(data_dir, self.config.data_file), "split": "test"},
157 | ),
158 | datasets.SplitGenerator(
159 | name=datasets.Split.VALIDATION,
160 | # These kwargs will be passed to _generate_examples
161 | gen_kwargs={
162 | "filepath": os.path.join(data_dir, self.config.data_file),
163 | "split": self.config.dev_column,
164 | },
165 | ),
166 | ]
167 |
168 | def _generate_examples(self, filepath, split):
169 | """This function returns the examples in the raw (text) form."""
170 | with open(filepath, "r", encoding="utf-8") as f:
171 | sid = -1
172 | for id_, row in enumerate(f):
173 | data = json.loads(row)
174 | if data["subset"] == split:
175 | for sample in data['hypothesises/labels']:
176 | sid += 1
177 | yield sid, {
178 | "premise": data["premise"],
179 | "hypothesis": sample['hypothesis'],
180 | "label": sample['label'],
181 | }
--------------------------------------------------------------------------------
/models/hat/configuration_hat.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | """ HAT configuration"""
14 | from collections import OrderedDict
15 | from typing import Mapping
16 |
17 | from transformers.onnx import OnnxConfig
18 | from transformers.utils import logging
19 | from transformers import PretrainedConfig
20 |
21 |
22 | logger = logging.get_logger(__name__)
23 |
24 | HAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
25 | "kiddothe2b/hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096/resolve/main/config.json",
26 | "kiddothe2b/adhoc-hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-base-4096/resolve/main/config.json",
27 | }
28 |
29 |
30 | class HATConfig(PretrainedConfig):
31 | r"""
32 | This is the configuration class to store the configuration of a :class:`~transformers.HAT`.
33 | It is used to instantiate a HAT model according to the specified arguments,
34 | defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
35 | to that of the HAT `kiddothe2b/hierarchical-transformer-base-4096
36 |     <https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096>`__ architecture.
37 |
38 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
39 | outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
40 |
41 |
42 | Args:
43 | vocab_size (:obj:`int`, `optional`, defaults to 30522):
44 | Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
45 | :obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or
46 | :class:`~transformers.TFBertModel`.
47 | max_sentences (:obj:`int`, `optional`, defaults to 64):
48 | The maximum number of sentences that this model might ever be used with.
49 | max_sentence_size (:obj:`int`, `optional`, defaults to 128):
50 | The maximum sentence length that this model might ever be used with.
51 | model_max_length (:obj:`int`, `optional`, defaults to 8192):
52 | The maximum sequence length (max_sentences * max_sentence_size) that this model might ever be used with
53 | encoder_layout (:obj:`Dict`):
54 | The sentence/document encoder layout.
55 | hidden_size (:obj:`int`, `optional`, defaults to 768):
56 | Dimensionality of the encoder layers and the pooler layer.
57 | num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
58 | Number of hidden layers in the Transformer encoder.
59 | num_attention_heads (:obj:`int`, `optional`, defaults to 12):
60 | Number of attention heads for each attention layer in the Transformer encoder.
61 | intermediate_size (:obj:`int`, `optional`, defaults to 3072):
62 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
63 | hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
64 | The non-linear activation function (function or string) in the encoder and pooler. If string,
65 | :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
66 | hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
67 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
68 | attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
69 | The dropout ratio for the attention probabilities.
70 | max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
71 | The maximum sequence length that this model might ever be used with. Typically set this to something large
72 | just in case (e.g., 512 or 1024 or 2048).
73 | type_vocab_size (:obj:`int`, `optional`, defaults to 2):
74 | The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or
75 | :class:`~transformers.TFBertModel`.
76 | initializer_range (:obj:`float`, `optional`, defaults to 0.02):
77 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
78 | layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
79 | The epsilon used by the layer normalization layers.
80 | position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
81 | Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
82 | :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
83 | :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
84 |             <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
85 | `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
86 |             <https://arxiv.org/abs/2009.13658>`__.
87 | use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
88 | Whether or not the model should return the last key/values attentions (not used by all models). Only
89 | relevant if ``config.is_decoder=True``.
90 | classifier_dropout (:obj:`float`, `optional`):
91 | The dropout ratio for the classification head.
92 | """
93 | model_type = "hierarchical-transformer"
94 |
95 | def __init__(
96 | self,
97 | vocab_size=30522,
98 | hidden_size=768,
99 | max_sentences=64,
100 | max_sentence_size=128,
101 | model_max_length=8192,
102 | num_hidden_layers=12,
103 | num_attention_heads=12,
104 | intermediate_size=3072,
105 | hidden_act="gelu",
106 | hidden_dropout_prob=0.1,
107 | attention_probs_dropout_prob=0.1,
108 | max_position_embeddings=512,
109 | type_vocab_size=2,
110 | initializer_range=0.02,
111 | layer_norm_eps=1e-12,
112 | pad_token_id=0,
113 | position_embedding_type="absolute",
114 | encoder_layout=None,
115 | use_cache=True,
116 | classifier_dropout=None,
117 | **kwargs
118 | ):
119 | super().__init__(pad_token_id=pad_token_id, **kwargs)
120 |
121 | self.vocab_size = vocab_size
122 | self.hidden_size = hidden_size
123 | self.max_sentences = max_sentences
124 | self.max_sentence_size = max_sentence_size
125 | self.model_max_length = model_max_length
126 | self.encoder_layout = encoder_layout
127 | self.num_hidden_layers = num_hidden_layers
128 | self.num_attention_heads = num_attention_heads
129 | self.hidden_act = hidden_act
130 | self.intermediate_size = intermediate_size
131 | self.hidden_dropout_prob = hidden_dropout_prob
132 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
133 | self.max_position_embeddings = max_position_embeddings
134 | self.type_vocab_size = type_vocab_size
135 | self.initializer_range = initializer_range
136 | self.layer_norm_eps = layer_norm_eps
137 | self.position_embedding_type = position_embedding_type
138 | self.use_cache = use_cache
139 | self.classifier_dropout = classifier_dropout
140 |
141 |
142 | class HATOnnxConfig(OnnxConfig):
143 | @property
144 | def inputs(self) -> Mapping[str, Mapping[int, str]]:
145 | return OrderedDict(
146 | [
147 | ("input_ids", {0: "batch", 1: "sequence"}),
148 | ("attention_mask", {0: "batch", 1: "sequence"}),
149 | ]
150 | )
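151 | 
152 | 
153 | # Example usage (an illustrative sketch): building a config for a 6-layer interleaved (I1)
154 | # model; the `encoder_layout` format follows the examples in the README.
155 | # >>> layout = {str(i): {"sentence_encoder": True, "document_encoder": True} for i in range(6)}
156 | # >>> config = HATConfig(num_hidden_layers=6, encoder_layout=layout,
157 | # ...                    max_sentences=8, max_sentence_size=128)
158 | # >>> config.model_type
159 | # 'hierarchical-transformer'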
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hierarchical Attention Transformers (HATs)
2 |
3 | Implementation of Hierarchical Attention Transformers (HATs), presented in _"An Exploration of Hierarchical Attention Transformers for Efficient Long Document Classification"_ by Chalkidis et al. (2022). HATs use a hierarchical attention scheme that combines segment-wise and cross-segment attention operations. You can think of segments as paragraphs or sentences.
4 |
5 |
6 |
7 |
8 | ## Citation
9 |
10 | If you use HAT in your research, please cite:
11 |
12 | [An Exploration of Hierarchical Attention Transformers for Efficient Long Document Classification](https://arxiv.org/abs/2210.05529). Ilias Chalkidis, Xiang Dai, Manos Fergadiotis, Prodromos Malakasiotis, and Desmond Elliott. 2022. arXiv:2210.05529 (Preprint).
13 |
14 | ```
15 | @misc{chalkidis-etal-2022-hat,
16 | url = {https://arxiv.org/abs/2210.05529},
17 | author = {Chalkidis, Ilias and Dai, Xiang and Fergadiotis, Manos and Malakasiotis, Prodromos and Elliott, Desmond},
18 | title = {An Exploration of Hierarchical Attention Transformers for Efficient Long Document Classification},
19 | publisher = {arXiv},
20 | year = {2022},
21 | }
22 | ```
23 |
24 | ## Implementation Details
25 |
26 | The repository supports several variants of the HAT architecture. The implementation of HAT is built on top of HuggingFace Transformers in PyTorch and is available at `models/hat/modelling_hat.py`. The layout of stacking segment-wise (SW) and cross-segment (CS) encoders is specified in the configuration file via the `encoder_layout` parameter.
27 |
28 |
29 |
30 |
31 | * **_Ad-Hoc (AH):_** An ad-hoc (partially pre-trained) HAT comprises an initial stack of shared $L_{\mathrm{SWE}}$ segment-wise encoders from a pre-trained transformer-based model, followed by $L_{\mathrm{CSE}}$ ad-hoc cross-segment encoders. In this case, the model initially encodes and contextualizes token representations per segment, and then builds higher-order segment-level representations, e.g., the layout below stacks 8 segment-wise blocks followed by 4 cross-segment blocks, i.e., 12 effective transformer blocks (Layout: S/S/S/S/S/S/S/S/D/D/D/D).
32 |
33 | ```json
34 | "encoder_layout": {
35 | "0": {"sentence_encoder": true, "document_encoder": false},
36 | "1": {"sentence_encoder": true, "document_encoder": false},
37 | "2": {"sentence_encoder": true, "document_encoder": false},
38 | "3": {"sentence_encoder": true, "document_encoder": false},
39 | "4": {"sentence_encoder": true, "document_encoder": false},
40 | "5": {"sentence_encoder": true, "document_encoder": false},
41 | "6": {"sentence_encoder": true, "document_encoder": false},
42 | "7": {"sentence_encoder": true, "document_encoder": false},
43 | "8": {"sentence_encoder": false, "document_encoder": true},
44 | "9": {"sentence_encoder": false, "document_encoder": true},
45 | "10": {"sentence_encoder": false, "document_encoder": true},
46 | "11": {"sentence_encoder": false, "document_encoder": true}
47 | }
48 | ```
49 |
50 | * **_Interleaved (I):_** An interleaved HAT comprises a stack of $L_{\mathrm{P}}$ paired segment-wise and cross-segment encoders,
51 | e.g., a 6-layer model has 12 effective transformer blocks (Layout: SD/SD/SD/SD/SD/SD).
52 |
53 | ```json
54 | "encoder_layout": {
55 | "0": {"sentence_encoder": true, "document_encoder": true},
56 | "1": {"sentence_encoder": true, "document_encoder": true},
57 | "2": {"sentence_encoder": true, "document_encoder": true},
58 | "3": {"sentence_encoder": true, "document_encoder": true},
59 | "4": {"sentence_encoder": true, "document_encoder": true},
60 | "5": {"sentence_encoder": true, "document_encoder": true}
61 | }
62 | ```
63 | * **_Early-Contextualization (EC):_** An early-contextualized HAT comprises an initial stack of $L_{\mathrm{P}}$ paired segment-wise and cross-segment encoders, followed by a stack of $L_{\mathrm{SWE}}$ segment-wise encoders. In this case, cross-segment attention (contextualization) is only performed in the initial layers of the model, e.g., a 6-layer model has 8 effective transformer blocks (Layout: SD/SD/S/S/S/S).
64 |
65 | ```json
66 | "encoder_layout": {
67 | "0": {"sentence_encoder": true, "document_encoder": true},
68 | "1": {"sentence_encoder": true, "document_encoder": true},
69 | "2": {"sentence_encoder": true, "document_encoder": false},
70 | "3": {"sentence_encoder": true, "document_encoder": false},
71 | "4": {"sentence_encoder": true, "document_encoder": false},
72 | "5": {"sentence_encoder": true, "document_encoder": false}
73 | }
74 | ```
75 |
76 |
77 | * **_Late-Contextualization (LC):_** A late-contextualized HAT comprises an initial stack of $L_{\mathrm{SWE}}$ segment-wise encoders, followed by a stack of $L_{\mathrm{P}}$ paired segment-wise and cross-segment encoders. In this case, cross-segment attention (contextualization) is only performed in the latter layers of the model, e.g., a 6-layer model has 8 effective transformer blocks (Layout: S/S/S/S/SD/SD).
78 |
79 |
80 | ```json
81 | "encoder_layout": {
82 | "0": {"sentence_encoder": true, "document_encoder": false},
83 | "1": {"sentence_encoder": true, "document_encoder": false},
84 | "2": {"sentence_encoder": true, "document_encoder": false},
85 | "3": {"sentence_encoder": true, "document_encoder": false},
86 | "4": {"sentence_encoder": true, "document_encoder": true},
87 | "5": {"sentence_encoder": true, "document_encoder": true}
88 | }
89 | ```
90 |
91 | In this study, we examine the efficacy of 8 alternative layouts:
92 |
93 | ```json
94 | {
95 |   "I1": "SD|SD|SD|SD|SD|SD",
96 |   "I2": "S|SD|D|S|SD|D|S|SD|D",
97 |   "I3": "S|SD|S|SD|S|SD|S|SD",
98 |   "LC1": "S|S|S|S|S|S|SD|SD|SD",
99 |   "LC2": "S|S|S|S|S|SD|D|S|SD|D",
100 |   "EC1": "S|S|SD|D|S|SD|D|S|S|S",
101 |   "EC2": "S|S|SD|SD|SD|S|S|S|S",
102 |   "AH": "S|S|S|S|S|S|S|S|S|S|S|S"
103 | }
105 | ```
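
Internally, each layout string is expanded into the `encoder_layout` dictionary shown in the examples above. The following minimal sketch mirrors the helper loop used in `models/util/efficiency_metrics_roberta.py`; the function name is only for illustration:

```python
# Minimal sketch: expand a layout string into the `encoder_layout` dictionary
# expected by HATConfig (mirrors the loop in models/util/efficiency_metrics_roberta.py).
def layout_to_encoder_layout(layout: str) -> dict:
    encoder_layout = {}
    for idx, block_pattern in enumerate(layout.split('|')):
        encoder_layout[str(idx)] = {
            "sentence_encoder": 'S' in block_pattern,  # segment-wise (SW) sub-layer
            "document_encoder": 'D' in block_pattern,  # cross-segment (CS) sub-layer
        }
    return encoder_layout


print(layout_to_encoder_layout("S|SD|S|SD|S|SD|S|SD"))  # I3 layout
```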
106 |
107 | ## Available Models on HuggingFace Hub
108 |
109 | | Model Name | Layers | Hidden Units | Attention Heads | Vocab | Parameters |
110 | |---------------------------------------------------------------------------------------------------------------------------------------------|--------|--------------|-----------------|-------|------------|
111 | | [`kiddothe2b/hierarchical-transformer-base-4096`](https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096) | 16 | 768 | 12 | 50K | 152M |
112 | | [`kiddothe2b/longformer-base-4096`](https://huggingface.co/kiddothe2b/longformer-base-4096) | 12 | 768 | 12 | 50K | 152M |
113 | | [`kiddothe2b/adhoc-hierarchical-transformer-base-4096`](https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-base-4096) | 16 | 768 | 12 | 50K | 140M |
114 | | [`kiddothe2b/adhoc-hierarchical-transformer-I1-mini-1024`](https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-I1-mini-1024) | 12 | 256 | 4 | 32K | 18M |
115 | | [`kiddothe2b/adhoc-hierarchical-transformer-I3-mini-1024`](https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-I2-mini-1024) | 12 | 256 | 4 | 32K | 18M |
116 | | [`kiddothe2b/adhoc-hierarchical-transformer-LC1-mini-1024`](https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-LC1-mini-1024) | 12 | 256 | 4 | 32K | 18M |
117 | | [`kiddothe2b/adhoc-hierarchical-transformer-EC2-mini-1024`](https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-EC2-mini-1024) | 12 | 256 | 4 | 32K | 18M |
118 | | [`kiddothe2b/longformer-mini-1024`](https://huggingface.co/kiddothe2b/longformer-mini-1024) | 6 | 256 | 4 | 32K | 14M |
119 |
120 |
121 | ## Requirements
122 |
123 | Make sure that all required packages are installed:
124 |
125 | ```
126 | torch>=1.11.0
127 | transformers>=4.18.0
128 | datasets>=2.0.0
129 | tokenizers>=0.11.0
130 | scikit-learn>=1.0.0
131 | tqdm>=4.62.0
132 | nltk>=3.7.0
133 | ```
134 |
135 | ## How to run experiments?
136 |
137 | You can use the shell scripts provided in the `running_scripts` directory to pre-train new models or fine-tune the ones released.
138 |
139 | Try on Google Colab: https://colab.research.google.com/drive/15feh49wqBshgkcvbO6QypvJoa3dG6P5S?usp=sharing
140 |
141 | ### I still have open questions...
142 |
143 | Please post your question in the [Discussions](https://github.com/coastalcph/hi-transformers/discussions) section or contact the corresponding author via e-mail.
144 |
--------------------------------------------------------------------------------
/models/util/efficiency_metrics_bert_models.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import warnings
3 |
4 | import torch
5 | import time
6 | from transformers import AutoConfig
7 |
8 | import numpy as np
9 | from data import DATA_DIR
10 | from models.hat import HATForMaskedLM, HATConfig, HATForSequenceClassification, \
11 | HATForMultipleChoice
12 | from models.longformer import LongformerForMaskedLM, LongformerModelForSequenceClassification, LongformerForMultipleChoice
13 |
14 | warnings.filterwarnings("ignore")
15 |
16 | LAYOUTS = {
17 | # 'f6': 'S|S|S|S|S|S',
18 | # 'f8': 'S|S|S|S|S|S',
19 | 's1': 'SD|SD|SD|SD|SD|SD',
20 | 's2': 'S|SD|D|S|SD|D|S|SD|D',
21 | 'p1': 'S|SD|S|SD|S|SD|S|SD',
22 | 'p2': 'S|S|SD|S|S|SD|S|S|SD',
23 | 'e1': 'SD|SD|SD|S|S|S|S|S|S',
24 | 'e2': 'S|SD|D|S|SD|D|S|S|S|S',
25 | 'l1': 'S|S|S|S|S|S|SD|SD|SD',
26 | 'l2': 'S|S|S|S|S|SD|D|S|SD|D',
27 | # 'b1': 'S|S|SD|D|S|SD|D|S|S|S',
28 | # 'b2': 'S|S|SD|SD|SD|S|S|S|S',
29 | }
30 |
31 |
32 | TASK_MODEL = {'lm': {'longformer': LongformerForMaskedLM, 'hilm': HATForMaskedLM},
33 | 'doc_cls': {'longformer': LongformerModelForSequenceClassification, 'hilm': HATForSequenceClassification},
34 | 'mc_qa': {'longformer': LongformerForMultipleChoice, 'hilm': HATForMultipleChoice},
35 | }
36 |
37 |
38 | def test_memory_usage(model, steps=40, batch_size=2, seq_length=1024, mode='test', task_type='lm'):
39 | model.to('cuda')
40 | if task_type != 'mc_qa':
41 | input_ids = torch.randint(1, 30000, (batch_size, seq_length), dtype=torch.long).to('cuda')
42 | input_ids[:, :: 128] = model.config.bos_token_id
43 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int).to('cuda')
44 | else:
45 | input_ids = torch.randint(1, 30000, (batch_size, 2, seq_length), dtype=torch.long).to('cuda')
46 |         input_ids[:, :, :: 128] = model.config.bos_token_id
47 | attention_mask = torch.ones((batch_size, 2, seq_length), dtype=torch.int).to('cuda')
48 | if mode == 'train':
49 | if task_type == 'lm':
50 | labels = input_ids.clone()
51 | else:
52 | labels = torch.ones((batch_size, ), dtype=torch.int).long().to('cuda')
53 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
54 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, steps)
55 | max_time = []
56 | max_mem = []
57 | for _ in range(steps):
58 | torch.cuda.reset_peak_memory_stats()
59 | start = time.time()
60 | if mode == 'train':
61 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
62 | loss = outputs.loss
63 |                 loss.backward()
64 | optimizer.step()
65 | lr_scheduler.step()
66 | optimizer.zero_grad()
67 | else:
68 | with torch.no_grad():
69 | outputs = model(input_ids=input_ids, attention_mask=attention_mask)
70 | end = time.time()
71 | total_time = (end - start)
72 | max_time.append(total_time)
73 | max_mem.append(torch.cuda.max_memory_allocated() / 1e9)
74 |
75 | return np.mean(max_mem), np.mean(max_time)
76 |
77 |
78 | def efficiency_metrics():
79 | MAX_SENTENCE_LENGTH = 128
80 | CONFIGS = [{'num_hidden_layers': 6,
81 | 'hidden_size': 256,
82 | 'intermediate_size': 1024,
83 | 'num_attention_heads': 4},
84 | # {'num_hidden_layers': 6,
85 | # 'hidden_size': 768,
86 | # 'intermediate_size': 3072,
87 | # 'num_attention_heads': 12}
88 | ]
89 |
90 | for mode in ['train', 'test']:
91 | print(F'MODE: {mode.upper()}')
92 | for task in ['lm', 'doc_cls', 'mc_qa']:
93 | for CONFIG in CONFIGS:
94 | print('-' * 150)
95 | print(F'TASK: {task.upper()}\t'
96 | F'NUM LAYERS: {CONFIG["num_hidden_layers"]}\t'
97 | F'NUM HIDDEN: {CONFIG["hidden_size"]}\t'
98 | F'ATTENTION HEADS: {CONFIG["num_attention_heads"]}')
99 | print('-' * 150)
100 |
101 | for max_sentences in [8]:
102 | print('-' * 150)
103 | print(F'MAX SEQ LENGTH: {int(max_sentences * MAX_SENTENCE_LENGTH)}')
104 | print('-' * 150)
105 |
106 | lf_config = AutoConfig.from_pretrained('allenai/longformer-base-4096')
107 | lf_config.num_hidden_layers = CONFIG['num_hidden_layers']
108 | # Transformer parameters
109 | lf_config.hidden_size = CONFIG['hidden_size']
110 | lf_config.intermediate_size = CONFIG['intermediate_size']
111 | lf_config.num_attention_heads = CONFIG['num_attention_heads']
112 | # Vocabulary parameters
113 | lf_config.vocab_size = 32000
114 | lf_config.type_vocab_size = 2
115 | lf_config.model_max_length = int(MAX_SENTENCE_LENGTH * max_sentences)
116 | lf_config.max_position_embeddings = int(MAX_SENTENCE_LENGTH * max_sentences) + 2
117 | lf_config.attention_window = [128] * CONFIG['num_hidden_layers']
118 | lf_config.max_sentence_length = MAX_SENTENCE_LENGTH
119 | lf_config.max_sentences = max_sentences
120 | lf_config.cls_token_id = 100
121 | # load dummy longformer model
122 | htf_model = TASK_MODEL[task]['longformer'].from_config(lf_config)
123 | model_total_params = sum(p.numel() for p in htf_model.longformer.parameters() if p.requires_grad)
124 | model_total_params = model_total_params / 1e6
125 | memory_use, time_use = test_memory_usage(htf_model, seq_length=lf_config.model_max_length, mode=mode, task_type=task)
126 | lf_mem_use = copy.deepcopy(memory_use)
127 | lf_time_use = copy.deepcopy(time_use)
128 | print(f'Longformer model has {model_total_params:.1f}M number of parameters '
129 | f'and {memory_use:.2f}GB peak memory use and {time_use:.3f} batch/second!')
130 | print('-' * 150)
131 |
132 | for layout in LAYOUTS:
133 | ENCODER_LAYOUT = {}
134 | for idx, block_pattern in enumerate(LAYOUTS[layout].split('|')):
135 | ENCODER_LAYOUT[str(idx)] = {"sentence_encoder": True if 'S' in block_pattern else False,
136 | "document_encoder": True if 'D' in block_pattern else False}
137 |
138 | # load dummy config and change specifications
139 | htf_config = HATConfig.from_pretrained(f'{DATA_DIR}/hat')
140 | # Text length parameters
141 | htf_config.max_sentence_length = MAX_SENTENCE_LENGTH
142 | htf_config.max_sentences = max_sentences
143 | htf_config.max_position_embeddings = MAX_SENTENCE_LENGTH
144 | htf_config.model_max_length = int(MAX_SENTENCE_LENGTH * max_sentences)
145 | htf_config.num_hidden_layers = len(ENCODER_LAYOUT.keys())
146 | # Transformer parameters
147 | htf_config.hidden_size = CONFIG['hidden_size']
148 | htf_config.intermediate_size = CONFIG['intermediate_size']
149 | htf_config.num_attention_heads = CONFIG['num_attention_heads']
150 | htf_config.encoder_layout = ENCODER_LAYOUT
151 | # Vocabulary parameters
152 | htf_config.vocab_size = 32000
153 | htf_config.type_vocab_size = 2
154 |
155 | # load dummy hat model
156 | htf_model = TASK_MODEL[task]['hilm'].from_config(htf_config)
157 | model_total_params = sum(p.numel() for p in htf_model.hat.parameters() if p.requires_grad)
158 | model_total_params = model_total_params / 1e6
159 | memory_use, time_use = test_memory_usage(htf_model, seq_length=int(MAX_SENTENCE_LENGTH * max_sentences), mode=mode, task_type=task)
160 | mem_gains = (lf_mem_use / memory_use) - 1
161 | time_gains = (lf_time_use / time_use) - 1
162 | print(f'Hi-transformer model with layout {layout} has {model_total_params:.1f}M number of parameters '
163 | f'{memory_use:.2f}GB peak memory use (-{mem_gains*100:.2f}%) and {time_use:.3f} batch/second (-{time_gains*100:.2f}%)!')
164 |
165 |
166 | if __name__ == '__main__':
167 | efficiency_metrics()
168 |
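169 | # Note: test_memory_usage() runs `steps` forward passes (plus backward/optimizer steps in
170 | # 'train' mode) over random dummy batches on a CUDA device and returns the mean peak GPU
171 | # memory (GB) and the mean wall-clock time per step (seconds). efficiency_metrics() then
172 | # compares a dummy Longformer against HAT variants built from the local `{DATA_DIR}/hat` config.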
--------------------------------------------------------------------------------
/models/util/benchmark_original_models.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import time
5 | from transformers import AutoConfig, AutoModelForSequenceClassification
6 | from models.longformer import LongformerModelForSentenceClassification
7 | from models.big_bird import BigBirdModelForSentenceClassification
8 | import numpy as np
9 | warnings.filterwarnings("ignore")
10 |
11 |
12 | def test_memory_usage(model, steps=100, batch_size=2, seq_length=4096, mode='test', task_type='lm'):
13 | model.to('cuda')
14 | if task_type != 'mc_qa':
15 | input_ids = torch.randint(1, 40000, (batch_size, seq_length), dtype=torch.long).to('cuda')
16 | input_ids[:, :: 128] = model.config.bos_token_id
17 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int).to('cuda')
18 | else:
19 | input_ids = torch.randint(1, 40000, (batch_size, 2, seq_length), dtype=torch.long).to('cuda')
20 |         input_ids[:, :, :: 128] = model.config.bos_token_id
21 | attention_mask = torch.ones((batch_size, 2, seq_length), dtype=torch.int).to('cuda')
22 | if mode == 'train':
23 | if task_type == 'lm':
24 | labels = input_ids.clone()
25 | elif task_type == 'sent_cls':
26 | labels = torch.ones((batch_size, 32), dtype=torch.int).long().to('cuda')
27 | else:
28 | labels = torch.ones((batch_size, ), dtype=torch.int).long().to('cuda')
29 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
30 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, steps)
31 | max_time = []
32 | max_mem = []
33 | for _ in range(steps):
34 | torch.cuda.reset_peak_memory_stats()
35 | start = time.time()
36 | if mode == 'train':
37 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
38 | loss = outputs.loss
39 |             loss.backward()
40 | optimizer.step()
41 | lr_scheduler.step()
42 | optimizer.zero_grad()
43 | else:
44 | with torch.no_grad():
45 | outputs = model(input_ids=input_ids, attention_mask=attention_mask)
46 | end = time.time()
47 | total_time = (end - start)
48 | max_time.append(total_time)
49 | max_mem.append(torch.cuda.max_memory_allocated() / 1e9)
50 |
51 | return np.mean(max_mem), np.mean(max_time)
52 |
53 |
54 | def estimate_model_size():
55 | for mode in ['train', 'test']:
56 | print(F'MODE: {mode.upper()}')
57 | for task in ['doc_cls', 'sent_cls']:
58 | MAX_SENTENCE_LENGTH = 128
59 | roberta_config = AutoConfig.from_pretrained('roberta-base')
60 | print('-' * 150)
61 | print(F'TASK: {task.upper()}\t'
62 | F'NUM LAYERS: {roberta_config.num_hidden_layers}\t'
63 | F'NUM HIDDEN: {roberta_config.hidden_size}\t'
64 | F'ATTENTION HEADS: {roberta_config.num_attention_heads}')
65 | print('-' * 150)
66 | MAX_SENTENCES = 32
67 | print('-' * 150)
68 | print(F'MAX SEQ LENGTH: {int(MAX_SENTENCES * MAX_SENTENCE_LENGTH)}')
69 | print('-' * 150)
70 | # load dummy longformer model
71 | lf_config = AutoConfig.from_pretrained('allenai/longformer-base-4096')
72 | lf_config.num_labels = 2
73 | lf_config.max_sentence_length = 128
74 | lf_config.max_sentences = 32
75 | lf_config.cls_token_id = lf_config.bos_token_id
76 | lf_config.sep_token_id = lf_config.eos_token_id
77 | if task == 'doc_cls':
78 | htf_model = AutoModelForSequenceClassification.from_config(lf_config)
79 | else:
80 | htf_model = LongformerModelForSentenceClassification.from_config(lf_config)
81 | model_total_params = sum(p.numel() for p in htf_model.longformer.parameters() if p.requires_grad)
82 | model_total_params = model_total_params / 1e6
83 | memory_use, time_use = test_memory_usage(htf_model, seq_length=4096, mode=mode,
84 | task_type=task)
85 | print(f'Original Longformer (12-layer) model has {model_total_params:.1f}M number of parameters '
86 | f'and {memory_use:.2f}GB peak memory use and {time_use:.3f} batch/second!')
87 | print('-' * 150)
88 |
89 | # load dummy bigbird model
90 | lf_config = AutoConfig.from_pretrained('google/bigbird-roberta-base')
91 | lf_config.num_labels = 2
92 | lf_config.max_sentence_length = 128
93 | lf_config.max_sentences = 32
94 | lf_config.cls_token_id = lf_config.bos_token_id
95 | lf_config.sep_token_id = lf_config.eos_token_id
96 | if task == 'doc_cls':
97 | htf_model = AutoModelForSequenceClassification.from_config(lf_config)
98 | else:
99 | htf_model = BigBirdModelForSentenceClassification.from_config(lf_config)
100 | model_total_params = sum(p.numel() for p in htf_model.bert.parameters() if p.requires_grad)
101 | model_total_params = model_total_params / 1e6
102 | memory_use, time_use = test_memory_usage(htf_model, seq_length=4096, mode=mode,
103 | task_type=task)
104 | print(f'Original BigBird (12-layer) model has {model_total_params:.1f}M number of parameters '
105 | f'and {memory_use:.2f}GB peak memory use and {time_use:.3f} batch/second!')
106 | print('-' * 150)
107 |
108 |
109 | if __name__ == '__main__':
110 | estimate_model_size()
111 |
112 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
113 | # MODE: TRAIN
114 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
115 | # TASK: DOC_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12
116 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
117 | # MAX SEQ LENGTH: 4096
118 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
119 | # Original Longformer (12-layer) model has 148.1M number of parameters and 17.76GB peak memory use and 0.852 batch/second!
120 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
121 | # Original BigBird (12-layer) model has 127.5M number of parameters and 18.84GB peak memory use and 0.795 batch/second!
122 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
123 | # TASK: SENT_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12
124 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
125 | # MAX SEQ LENGTH: 4096
126 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
127 | # Original Longformer (12-layer) model has 148.1M number of parameters and 18.37GB peak memory use and 0.895 batch/second!
128 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
129 | # Original BigBird (12-layer) model has 126.9M number of parameters and 18.84GB peak memory use and 0.795 batch/second!
130 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
131 | # MODE: TEST
132 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
133 | # TASK: DOC_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12
134 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
135 | # MAX SEQ LENGTH: 4096
136 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
137 | # Original Longformer (12-layer) model has 148.1M number of parameters and 1.70GB peak memory use and 0.223 batch/second!
138 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
139 | # Original BigBird (12-layer) model has 127.5M number of parameters and 1.76GB peak memory use and 0.207 batch/second!
140 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
141 | # TASK: SENT_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12
142 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
143 | # MAX SEQ LENGTH: 4096
144 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
145 | # Original Longformer (12-layer) model has 148.1M number of parameters and 1.71GB peak memory use and 0.236 batch/second!
146 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
147 | # Original BigBird (12-layer) model has 126.9M number of parameters and 1.76GB peak memory use and 0.207 batch/second!
148 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/models/util/efficiency_metrics_roberta.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import warnings
3 |
4 | import torch
5 | import time
6 | from transformers import AutoConfig
7 | import numpy as np
8 |
9 | from data import DATA_DIR
10 | from models.hat import HATForMaskedLM, HATConfig, HATForSequenceClassification, \
11 | HATForMultipleChoice, HATModelForSequentialSentenceClassification
12 | from models.longformer import LongformerForMaskedLM, LongformerModelForSequenceClassification, LongformerForMultipleChoice , \
13 | LongformerModelForSentenceClassification
14 | warnings.filterwarnings("ignore")
15 |
16 | LAYOUTS = {
17 | 'f12': 'S|S|S|S|S|S|S|S|S|S|S|SD|D|D|D',
18 | 'p1': 'S|S|SD|S|S|SD|S|S|SD|S|S|SD',
19 | }
20 |
21 | TASK_MODEL = {'lm': {'longformer': LongformerForMaskedLM, 'hilm': HATForMaskedLM},
22 | 'doc_cls': {'longformer': LongformerModelForSequenceClassification, 'hilm': HATForSequenceClassification},
23 | 'mc_qa': {'longformer': LongformerForMultipleChoice, 'hilm': HATForMultipleChoice},
24 | 'sent_cls': {'longformer': LongformerModelForSentenceClassification, 'hilm': HATModelForSequentialSentenceClassification},
25 | }
26 |
27 |
28 | def test_memory_usage(model, steps=40, batch_size=2, seq_length=4096, mode='test', task_type='lm'):
29 | model.to('cuda')
30 | if task_type != 'mc_qa':
31 | input_ids = torch.randint(1, 40000, (batch_size, seq_length), dtype=torch.long).to('cuda')
32 | input_ids[:, :: 128] = model.config.bos_token_id
33 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int).to('cuda')
34 | else:
35 | input_ids = torch.randint(1, 40000, (batch_size, 2, seq_length), dtype=torch.long).to('cuda')
36 |         input_ids[:, :, :: 128] = model.config.bos_token_id
37 | attention_mask = torch.ones((batch_size, 2, seq_length), dtype=torch.int).to('cuda')
38 | if mode == 'train':
39 | if task_type == 'lm':
40 | labels = input_ids.clone()
41 | elif task_type == 'sent_cls':
42 | labels = torch.ones((batch_size, 32), dtype=torch.int).long().to('cuda')
43 | else:
44 | labels = torch.ones((batch_size, ), dtype=torch.int).long().to('cuda')
45 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
46 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, steps)
47 | max_time = []
48 | max_mem = []
49 | for _ in range(steps):
50 | torch.cuda.reset_peak_memory_stats()
51 | start = time.time()
52 | if mode == 'train':
53 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
54 | loss = outputs.loss
55 |             loss.backward()
56 | optimizer.step()
57 | lr_scheduler.step()
58 | optimizer.zero_grad()
59 | else:
60 | with torch.no_grad():
61 | outputs = model(input_ids=input_ids, attention_mask=attention_mask)
62 | end = time.time()
63 | total_time = (end - start)
64 | max_time.append(total_time)
65 | max_mem.append(torch.cuda.max_memory_allocated() / 1e9)
66 |
67 | return np.mean(max_mem), np.mean(max_time)
68 |
69 |
70 | def efficiency_metrics():
71 | for mode in ['train', 'test']:
72 | print(F'MODE: {mode.upper()}')
73 | for task in TASK_MODEL:
74 | MAX_SENTENCE_LENGTH = 128
75 | roberta_config = AutoConfig.from_pretrained('roberta-base')
76 | print('-' * 150)
77 | print(F'TASK: {task.upper()}\t'
78 | F'NUM LAYERS: {roberta_config.num_hidden_layers}\t'
79 | F'NUM HIDDEN: {roberta_config.hidden_size}\t'
80 | F'ATTENTION HEADS: {roberta_config.num_attention_heads}')
81 | print('-' * 150)
82 | MAX_SENTENCES = 32
83 | print('-' * 150)
84 | print(F'MAX SEQ LENGTH: {int(MAX_SENTENCES * MAX_SENTENCE_LENGTH)}')
85 | print('-' * 150)
86 | lf_config = AutoConfig.from_pretrained('allenai/longformer-base-4096')
87 | lf_config.num_hidden_layers = 12
88 | # Transformer parameters
89 | lf_config.hidden_size = roberta_config.hidden_size
90 | lf_config.intermediate_size = roberta_config.intermediate_size
91 | lf_config.num_attention_heads = roberta_config.num_attention_heads
92 | # Vocabulary parameters
93 | lf_config.vocab_size = roberta_config.vocab_size
94 | lf_config.type_vocab_size = 2
95 | lf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES)
96 | lf_config.max_sentence_length = int(MAX_SENTENCE_LENGTH)
97 | lf_config.max_sentences = int(MAX_SENTENCES)
98 | lf_config.max_position_embeddings = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES) + 2
99 | lf_config.attention_window = [128] * roberta_config.num_hidden_layers
100 | lf_config.cls_token_id = 100
101 | lf_config.num_labels = 2
102 | # load dummy longformer model
103 | htf_model = TASK_MODEL[task]['longformer'].from_config(lf_config)
104 | model_total_params = sum(p.numel() for p in htf_model.longformer.parameters() if p.requires_grad)
105 | model_total_params = model_total_params / 1e6
106 | memory_use, time_use = test_memory_usage(htf_model, seq_length=lf_config.model_max_length, mode=mode, task_type=task)
107 | lf_mem_use = copy.deepcopy(memory_use)
108 | lf_time_use = copy.deepcopy(time_use)
109 | print(f'Longformer (12-layer) model has {model_total_params:.1f}M number of parameters '
110 | f'and {memory_use:.2f}GB peak memory use and {time_use:.3f} batch/second!')
111 | print('-' * 150)
112 | for layout in LAYOUTS:
113 | ENCODER_LAYOUT = {}
114 | for idx, block_pattern in enumerate(LAYOUTS[layout].split('|')):
115 | ENCODER_LAYOUT[str(idx)] = {"sentence_encoder": True if 'S' in block_pattern else False,
116 | "document_encoder": True if 'D' in block_pattern else False}
117 |
118 | # load dummy config and change specifications
119 | htf_config = HATConfig.from_pretrained(f'{DATA_DIR}/hi-transformer')
120 | # Text length parameters
121 | htf_config.max_sentence_length = MAX_SENTENCE_LENGTH
122 |                 htf_config.max_sentences = MAX_SENTENCES
123 | htf_config.max_position_embeddings = MAX_SENTENCE_LENGTH
124 | htf_config.model_max_length = int(MAX_SENTENCE_LENGTH * MAX_SENTENCES)
125 | htf_config.num_hidden_layers = len(ENCODER_LAYOUT.keys())
126 | # Transformer parameters
127 | htf_config.hidden_size = roberta_config.hidden_size
128 | htf_config.intermediate_size = roberta_config.intermediate_size
129 | htf_config.num_attention_heads = roberta_config.num_attention_heads
130 | htf_config.encoder_layout = ENCODER_LAYOUT
131 | # Vocabulary parameters
132 | htf_config.vocab_size = roberta_config.vocab_size
133 | htf_config.type_vocab_size = 2
134 |                 htf_config.num_labels = 2
135 | # load dummy hi-transformer model
136 | htf_model = TASK_MODEL[task]['hilm'].from_config(htf_config)
137 | model_total_params = sum(p.numel() for p in htf_model.hat.parameters() if p.requires_grad)
138 | model_total_params = model_total_params / 1e6
139 | memory_use, time_use = test_memory_usage(htf_model, seq_length=int(MAX_SENTENCE_LENGTH * MAX_SENTENCES), mode=mode, task_type=task)
140 | mem_gains = (lf_mem_use / memory_use) - 1
141 | time_gains = (lf_time_use / time_use) - 1
142 | print(f'Hi-transformer model with layout {layout} has {model_total_params:.1f}M number of parameters '
143 | f'{memory_use:.2f}GB peak memory use (-{mem_gains*100:.2f}%) and {time_use:.3f} batch/second (-{time_gains*100:.2f}%)!')
144 |
145 |
146 | if __name__ == '__main__':
147 | efficiency_metrics()
148 |
149 |
150 |
151 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
152 | # MODE: TRAIN
153 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
154 | # TASK: SENT_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12
155 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
156 | # MAX SEQ LENGTH: 4096
157 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
158 | # Longformer (12-layer) model has 148.1M number of parameters and 10.77GB peak memory use and 0.459 batch/second!
159 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
160 | # Hi-transformer model with layout f12 has 152.2M number of parameters 8.96GB peak memory use (-20.24%) and 0.343 batch/second (-33.80%)!
161 | # Hi-transformer model with layout p1 has 152.2M number of parameters 8.97GB peak memory use (-20.11%) and 0.344 batch/second (-33.42%)!
162 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
163 | # MODE: TEST
164 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
165 | # TASK: SENT_CLS NUM LAYERS: 12 NUM HIDDEN: 768 ATTENTION HEADS: 12
166 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
167 | # MAX SEQ LENGTH: 4096
168 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
169 | # Longformer (12-layer) model has 148.1M number of parameters and 0.98GB peak memory use and 0.131 batch/second!
170 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
171 | # Hi-transformer model with layout f12 has 152.2M number of parameters 0.90GB peak memory use (-8.59%) and 0.115 batch/second (-13.96%)!
172 | # Hi-transformer model with layout p1 has 152.2M number of parameters 0.90GB peak memory use (-8.59%) and 0.114 batch/second (-14.45%)!
173 | # ------------------------------------------------------------------------------------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
/models/big_bird/tokenization_big_bird.py:
--------------------------------------------------------------------------------
1 | """Tokenization classes for BigBird."""
2 | import torch
3 | from transformers import AutoTokenizer
4 | from transformers.models.big_bird.configuration_big_bird import BigBirdConfig
5 | from transformers.utils import logging
6 | from nltk import sent_tokenize
7 | logger = logging.get_logger(__name__)
8 |
9 |
10 | class BigbirdTokenizer:
11 | def __init__(self, tokenizer=None):
12 | self._tokenizer = tokenizer
13 | self.config = BigBirdConfig.from_pretrained(self._tokenizer.name_or_path)
14 | # hardcoded values
15 | self.config.max_sentence_size = 128
16 | self.config.max_sentence_length = 128
17 | self.config.max_sentences = 32
18 | self.config.model_max_length = 4096
19 | self._tokenizer.model_max_length = self.model_max_length
20 | self.type2id = {'input_ids': (self._tokenizer.sep_token_id, self._tokenizer.pad_token_id),
21 | 'token_type_ids': (0, 0),
22 | 'attention_mask': (1, 0),
23 | 'special_tokens_mask': (1, -100)}
24 |
25 | @property
26 | def model_max_length(self):
27 | return self.config.model_max_length
28 |
29 | @property
30 | def mask_token(self):
31 | return self._tokenizer.mask_token
32 |
33 | @property
34 | def mask_token_id(self):
35 | return self._tokenizer.mask_token_id
36 |
37 | @property
38 | def pad_token_id(self):
39 | return self._tokenizer.pad_token_id
40 |
41 | @property
42 | def cls_token_id(self):
43 | return self._tokenizer.cls_token_id
44 |
45 | @property
46 | def sep_token_id(self):
47 | return self._tokenizer.sep_token_id
48 |
49 | @property
50 | def vocab(self):
51 | return self._tokenizer.vocab
52 |
53 | def __len__(self):
54 | """
55 | Size of the full vocabulary with the added tokens.
56 | """
57 | return len(self._tokenizer)
58 |
59 | def pad(self, *args, **kwargs):
60 | return self._tokenizer.pad(*args, **kwargs)
61 |
62 | def convert_tokens_to_ids(self, *args, **kwargs):
63 | return self._tokenizer.convert_tokens_to_ids(*args, **kwargs)
64 |
65 | def batch_decode(self, *args, **kwargs):
66 | return self._tokenizer.batch_decode(*args, **kwargs)
67 |
68 | def decode(self, *args, **kwargs):
69 | return self._tokenizer.decode(*args, **kwargs)
70 |
71 | def tokenize(self, text, **kwargs):
72 | return self._tokenizer.tokenize(text, **kwargs)
73 |
74 | def encode(self, text, **kwargs):
75 | input_ids = self._tokenizer.encode_plus(text, add_special_tokens=False, **kwargs)
76 | input_ids = self.chunks(input_ids[: self.model_max_length - self.config.max_sentences],
77 | chunk_size=self.config.max_sentence_length, special_id=self.type2id['input_ids'])
78 |
79 | for idx, _ in enumerate(input_ids):
80 | input_ids[idx][0] = self._tokenizer.cls_token_id
81 |
82 | return input_ids
83 |
84 | def get_special_tokens_mask(self, *args, **kwargs):
85 | return self._tokenizer.get_special_tokens_mask(*args, **kwargs)
86 |
87 | @classmethod
88 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
89 | return cls(tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs))
90 |
91 | def save_pretrained(self, *args, **kwargs):
92 | return self._tokenizer.save_pretrained( *args, **kwargs)
93 |
94 | def __call__(self, texts, **kwargs):
95 | greedy_chunking = kwargs.pop('greedy_chunking', None)
96 | if isinstance(texts[0], list):
97 | batch = self.auto_chunking(texts, **kwargs)
98 | else:
99 | if greedy_chunking:
100 | # fixed uniform chunking
101 | batch = self.uniform_chunking(texts, **kwargs)
102 | else:
103 | # dynamic sentence splitting and grouping
104 | batch = self.sentence_splitting(texts, **kwargs)
105 |
106 | for idx, _ in enumerate(batch['input_ids']):
107 | batch['input_ids'][idx][0] = self._tokenizer.cls_token_id
108 |
109 | if kwargs['padding']:
110 | batch = self.pad(batch,
111 | padding=kwargs['padding'],
112 | max_length=kwargs['max_length'],
113 | pad_to_multiple_of=kwargs['max_length'])
114 |
115 | return batch
116 |
117 | def auto_chunking(self, texts, **kwargs):
118 | batch = {}
119 | for text_idx, text in enumerate(texts):
120 | example_batch = self._tokenizer(text, add_special_tokens=False, **kwargs)
121 | for input_key in example_batch:
122 | key_inputs_list = []
123 | for idx, example in enumerate(example_batch[input_key][:self.config.max_sentences]):
124 | key_inputs_list.append(self.pad_sentence(example, special_id=self.type2id[input_key]))
125 | if isinstance(key_inputs_list[0], list):
126 | key_inputs_list = [token for sentence in key_inputs_list for token in sentence]
127 | else:
128 | key_inputs_list = torch.stack(key_inputs_list)
129 | if input_key in batch:
130 | batch[input_key].append(key_inputs_list)
131 | else:
132 | batch[input_key] = [key_inputs_list]
133 |
134 | return batch
135 |
136 | def uniform_chunking(self, texts, **kwargs):
137 | original_batch = self._tokenizer(texts, add_special_tokens=False, **kwargs)
138 | batch = {input_type: [] for input_type in original_batch}
139 | for input_type in original_batch:
140 | fixed_batch = []
141 | for example in original_batch[input_type]:
142 | fixed_batch.append(self.chunks(example[: self.model_max_length - self.config.max_sentences],
143 | chunk_size=self.config.max_sentence_length,
144 | special_id=self.type2id[input_type]))
145 | batch[input_type] = fixed_batch if isinstance(fixed_batch[0], list) else torch.stack(fixed_batch)
146 | return batch
147 |
148 | def chunks(self, flat_inputs, chunk_size=128, special_id=0):
149 | if isinstance(flat_inputs, list):
150 | return self.list_chunks(flat_inputs, chunk_size, special_id)
151 | else:
152 | return self.tensor_chunks(flat_inputs, chunk_size, special_id)
153 |
154 | def list_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
155 |         """Split flat inputs into fixed-size chunks, prefix each chunk with a special id, and return them as a flat list."""
156 | structured_inputs = [[special_id[0] if sum(flat_inputs[i:i + chunk_size-1]) else special_id[1]]
157 | + flat_inputs[i:i + chunk_size-1] for i in range(0, len(flat_inputs), chunk_size-1)]
158 | return [token_input for sentence_inputs in structured_inputs for token_input in sentence_inputs]
159 |
160 | def tensor_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
161 |         """Split a flat tensor into fixed-size chunks, prefix each chunk with a special id, and return a flattened tensor."""
162 | structured_inputs = torch.stack([torch.cat((torch.tensor([special_id[0] if flat_inputs[i:i + chunk_size-1].sum() else special_id[1]], dtype=torch.int),
163 | flat_inputs[i:i + chunk_size-1])) for i in range(0, len(flat_inputs), chunk_size-1)])
164 | return structured_inputs.reshape(-1)
165 |
166 | def sentence_splitting(self, texts, **kwargs):
167 | fixed_batch = []
168 | doc_out = {}
169 | for text in texts:
170 | # sentence splitting
171 | sentences = sent_tokenize(text)
172 | # tokenization of sentences
173 | sentences = self._tokenizer(sentences, add_special_tokens=False, padding=False, truncation=False)
174 | # sentence grouping - merging short sentences to minimize padding
175 | doc_out = self.sentence_grouping(sentences)
176 | fixed_batch.append(doc_out)
177 | # batchify examples
178 | batch = {input_type: [] for input_type in doc_out}
179 | for input_type in batch:
180 | batch[input_type] = [example[input_type] for example in fixed_batch]
181 | if not isinstance(batch[input_type][0], list):
182 | batch[input_type] = torch.stack(batch[input_type])
183 |
184 | return batch
185 |
186 | def sentence_grouping(self, sentences):
187 | doc_out = {input_type: [] for input_type in sentences}
188 | for input_type in sentences:
189 | tmp_doc = []
190 | tmp_sentence = []
191 | for example in sentences[input_type]:
192 | if len(tmp_doc) >= self.config.max_sentences:
193 | break
194 | if len(tmp_sentence) + len(example) <= self.config.max_sentence_length - 1:
195 | tmp_sentence.extend(example)
196 | else:
197 | tmp_doc.append(self.pad_sentence(tmp_sentence if len(tmp_sentence) else example,
198 | chunk_size=self.config.max_sentence_length,
199 | special_id=self.type2id[input_type]))
200 | tmp_sentence = example if len(tmp_sentence) else example[self.config.max_sentence_length:]
201 | if len(tmp_sentence) and len(tmp_doc) < self.config.max_sentences:
202 | tmp_doc.append(self.pad_sentence(tmp_sentence,
203 | chunk_size=self.config.max_sentence_length,
204 | special_id=self.type2id[input_type]))
205 | doc_out[input_type] = [token for sentence in tmp_doc for token in sentence]
206 | return doc_out
207 |
208 | def pad_sentence(self, flat_input, chunk_size=128, special_id=(0, 0)):
209 | if isinstance(flat_input, list):
210 | return [special_id[0]] + flat_input[:chunk_size-1] + [self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1)
211 | else:
212 | return torch.cat((torch.tensor([special_id[0] if flat_input[:chunk_size-1].sum()
213 | else special_id[1]], dtype=torch.int),
214 | flat_input[:chunk_size-1],
215 | torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
216 | ))
217 |
218 |
219 | if __name__ == "__main__":
220 | tokenizer = BigbirdTokenizer.from_pretrained('roberta-base')
221 | inputs = tokenizer([' '.join(['dog'] * 8192),
222 | ' '.join(['cat'] * 7000),
223 | ' '.join(['mouse'] * 5000)],
224 | padding=True, max_length=8192, truncation=True
225 | )
226 | print()
227 |
--------------------------------------------------------------------------------
/models/longformer/tokenization_longformer.py:
--------------------------------------------------------------------------------
1 | """Tokenization classes for Longformer."""
2 | import torch
3 | from transformers import AutoTokenizer
4 | from transformers.models.longformer.configuration_longformer import LongformerConfig
5 | from transformers.utils import logging
6 | try:
7 | from nltk import sent_tokenize
8 | except ImportError:
9 | raise Exception('NLTK is not installed! Install it with `pip install nltk`...')
10 | logger = logging.get_logger(__name__)
11 |
12 |
13 | class LongformerTokenizer:
14 | def __init__(self, tokenizer=None):
15 | self._tokenizer = tokenizer
16 | self.config = LongformerConfig.from_pretrained(self._tokenizer.name_or_path)
17 | # hardcoded values
18 | self.config.max_sentence_size = 128
19 | self.config.max_sentence_length = 128
20 | self.config.max_sentences = 32
21 | self.config.model_max_length = 4096
22 | self._tokenizer.model_max_length = self.model_max_length
23 | self.type2id = {'input_ids': (self._tokenizer.sep_token_id, self._tokenizer.pad_token_id),
24 | 'token_type_ids': (0, 0),
25 | 'attention_mask': (1, 0),
26 | 'special_tokens_mask': (1, -100)}
27 |
28 | @property
29 | def model_max_length(self):
30 | return self.config.model_max_length
31 |
32 | @property
33 | def mask_token(self):
34 | return self._tokenizer.mask_token
35 |
36 | @property
37 | def mask_token_id(self):
38 | return self._tokenizer.mask_token_id
39 |
40 | @property
41 | def pad_token_id(self):
42 | return self._tokenizer.pad_token_id
43 |
44 | @property
45 | def cls_token_id(self):
46 | return self._tokenizer.cls_token_id
47 |
48 | @property
49 | def sep_token_id(self):
50 | return self._tokenizer.sep_token_id
51 |
52 | @property
53 | def vocab(self):
54 | return self._tokenizer.vocab
55 |
56 | def __len__(self):
57 | """
58 | Size of the full vocabulary with the added tokens.
59 | """
60 | return len(self._tokenizer)
61 |
62 | def pad(self, *args, **kwargs):
63 | return self._tokenizer.pad(*args, **kwargs)
64 |
65 | def convert_tokens_to_ids(self, *args, **kwargs):
66 | return self._tokenizer.convert_tokens_to_ids(*args, **kwargs)
67 |
68 | def batch_decode(self, *args, **kwargs):
69 | return self._tokenizer.batch_decode(*args, **kwargs)
70 |
71 | def decode(self, *args, **kwargs):
72 | return self._tokenizer.decode(*args, **kwargs)
73 |
74 | def tokenize(self, text, **kwargs):
75 | return self._tokenizer.tokenize(text, **kwargs)
76 |
77 | def encode(self, text, **kwargs):
78 | input_ids = self._tokenizer.encode_plus(text, add_special_tokens=False, **kwargs)
79 | input_ids = self.chunks(input_ids[: self.model_max_length - self.config.max_sentences],
80 | chunk_size=self.config.max_sentence_length, special_id=self.type2id['input_ids'])
81 |
82 | for idx, _ in enumerate(input_ids):
83 | input_ids[idx][0] = self._tokenizer.cls_token_id
84 |
85 | return input_ids
86 |
87 | def get_special_tokens_mask(self, *args, **kwargs):
88 | return self._tokenizer.get_special_tokens_mask(*args, **kwargs)
89 |
90 | @classmethod
91 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
92 | return cls(tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs))
93 |
94 | def save_pretrained(self, *args, **kwargs):
95 | return self._tokenizer.save_pretrained( *args, **kwargs)
96 |
97 | def __call__(self, texts, **kwargs):
98 | greedy_chunking = kwargs.pop('greedy_chunking', None)
99 | if isinstance(texts[0], list):
100 | batch = self.auto_chunking(texts, **kwargs)
101 | else:
102 | if greedy_chunking:
103 | # fixed uniform chunking
104 | batch = self.uniform_chunking(texts, **kwargs)
105 | else:
106 | # dynamic sentence splitting and grouping
107 | batch = self.sentence_splitting(texts, **kwargs)
108 |
109 | for idx, _ in enumerate(batch['input_ids']):
110 | batch['input_ids'][idx][0] = self._tokenizer.cls_token_id
111 |
112 | if kwargs['padding']:
113 | batch = self.pad(batch,
114 | padding=kwargs['padding'],
115 | max_length=kwargs['max_length'],
116 | pad_to_multiple_of=kwargs['max_length'])
117 |
118 | return batch
119 |
120 | def auto_chunking(self, texts, **kwargs):
121 | batch = {}
122 | for text_idx, text in enumerate(texts):
123 | example_batch = self._tokenizer(text, add_special_tokens=False, **kwargs)
124 | for input_key in example_batch:
125 | key_inputs_list = []
126 | for idx, example in enumerate(example_batch[input_key][:self.config.max_sentences]):
127 | key_inputs_list.append(self.pad_sentence(example, special_id=self.type2id[input_key]))
128 | if isinstance(key_inputs_list[0], list):
129 | key_inputs_list = [token for sentence in key_inputs_list for token in sentence]
130 | else:
131 | key_inputs_list = torch.stack(key_inputs_list)
132 | if input_key in batch:
133 | batch[input_key].append(key_inputs_list)
134 | else:
135 | batch[input_key] = [key_inputs_list]
136 |
137 | return batch
138 |
139 | def uniform_chunking(self, texts, **kwargs):
140 | original_batch = self._tokenizer(texts, add_special_tokens=False, **kwargs)
141 | batch = {input_type: [] for input_type in original_batch}
142 | for input_type in original_batch:
143 | fixed_batch = []
144 | for example in original_batch[input_type]:
145 | fixed_batch.append(self.chunks(example[: self.model_max_length - self.config.max_sentences],
146 | chunk_size=self.config.max_sentence_length,
147 | special_id=self.type2id[input_type]))
148 | batch[input_type] = fixed_batch if isinstance(fixed_batch[0], list) else torch.stack(fixed_batch)
149 | return batch
150 |
151 | def chunks(self, flat_inputs, chunk_size=128, special_id=0):
152 | if isinstance(flat_inputs, list):
153 | return self.list_chunks(flat_inputs, chunk_size, special_id)
154 | else:
155 | return self.tensor_chunks(flat_inputs, chunk_size, special_id)
156 |
157 | def list_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
158 | """Yield successive n-sized chunks from lst."""
159 | structured_inputs = [[special_id[0] if sum(flat_inputs[i:i + chunk_size-1]) else special_id[1]]
160 | + flat_inputs[i:i + chunk_size-1] for i in range(0, len(flat_inputs), chunk_size-1)]
161 | return [token_input for sentence_inputs in structured_inputs for token_input in sentence_inputs]
162 |
163 | def tensor_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
164 | """Yield successive n-sized chunks from lst."""
165 | structured_inputs = torch.stack([torch.cat((torch.tensor([special_id[0] if flat_inputs[i:i + chunk_size-1].sum() else special_id[1]], dtype=torch.int),
166 | flat_inputs[i:i + chunk_size-1])) for i in range(0, len(flat_inputs), chunk_size-1)])
167 | return structured_inputs.reshape(-1)
168 |
169 | def sentence_splitting(self, texts, **kwargs):
170 | fixed_batch = []
171 | doc_out = {}
172 | for text in texts:
173 | # sentence splitting
174 | sentences = sent_tokenize(text)
175 | # tokenization of sentences
176 | sentences = self._tokenizer(sentences, add_special_tokens=False, padding=False, truncation=False)
177 | # sentence grouping - merging short sentences to minimize padding
178 | doc_out = self.sentence_grouping(sentences)
179 | fixed_batch.append(doc_out)
180 | # batchify examples
181 | batch = {input_type: [] for input_type in doc_out}
182 | for input_type in batch:
183 | batch[input_type] = [example[input_type] for example in fixed_batch]
184 | if not isinstance(batch[input_type][0], list):
185 | batch[input_type] = torch.stack(batch[input_type])
186 |
187 | return batch
188 |
189 | def sentence_grouping(self, sentences):
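   |         """Greedily merge consecutive sentences into blocks of at most max_sentence_length tokens (up to max_sentences blocks) and return them flattened per input type."""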
190 | doc_out = {input_type: [] for input_type in sentences}
191 | for input_type in sentences:
192 | tmp_doc = []
193 | tmp_sentence = []
194 | for example in sentences[input_type]:
195 | if len(tmp_doc) >= self.config.max_sentences:
196 | break
197 | if len(tmp_sentence) + len(example) <= self.config.max_sentence_length - 1:
198 | tmp_sentence.extend(example)
199 | else:
200 | tmp_doc.append(self.pad_sentence(tmp_sentence if len(tmp_sentence) else example,
201 | chunk_size=self.config.max_sentence_length,
202 | special_id=self.type2id[input_type]))
203 | tmp_sentence = example if len(tmp_sentence) else example[self.config.max_sentence_length:]
204 | if len(tmp_sentence) and len(tmp_doc) < self.config.max_sentences:
205 | tmp_doc.append(self.pad_sentence(tmp_sentence,
206 | chunk_size=self.config.max_sentence_length,
207 | special_id=self.type2id[input_type]))
208 | doc_out[input_type] = [token for sentence in tmp_doc for token in sentence]
209 | return doc_out
210 |
211 | def pad_sentence(self, flat_input, chunk_size=128, special_id=(0, 0)):
212 | if isinstance(flat_input, list):
213 | return [special_id[0]] + flat_input[:chunk_size-1] + [self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1)
214 | else:
215 | return torch.cat((torch.tensor([special_id[0] if flat_input[:chunk_size-1].sum()
216 | else special_id[1]], dtype=torch.int),
217 | flat_input[:chunk_size-1],
218 | torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
219 | ))
220 |
221 |
222 | if __name__ == "__main__":
223 | tokenizer = LongformerTokenizer.from_pretrained('roberta-base')
224 | inputs = tokenizer([' '.join(['dog'] * 8192),
225 | ' '.join(['cat'] * 7000),
226 | ' '.join(['mouse'] * 5000)],
227 | padding=True, max_length=8192, truncation=True
228 | )
229 |     print({key: len(value[0]) for key, value in inputs.items()})
230 |
--------------------------------------------------------------------------------
/models/hat/tokenization_hat.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | """Tokenization classes for HAT."""
14 | import torch
15 | from transformers import RobertaTokenizer, BertTokenizer
16 | from .configuration_hat import HATConfig
17 | from transformers.utils import logging
18 | try:
19 |     from nltk import sent_tokenize
20 | except ImportError:
21 |     raise ImportError('NLTK is not installed! Install it with `pip install nltk`.')
22 | logger = logging.get_logger(__name__)
23 |
24 |
25 | class HATTokenizer:
26 | def __init__(self, tokenizer=None):
27 | self._tokenizer = tokenizer
28 | self.config = HATConfig.from_pretrained(self._tokenizer.name_or_path)
29 | self._tokenizer.model_max_length = self.model_max_length
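   |         # per input type: (sentence-initial special id, padding id) used when building fixed-size sentence blocks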
30 | self.type2id = {'input_ids': (self._tokenizer.cls_token_id, self._tokenizer.pad_token_id),
31 | 'token_type_ids': (0, 0),
32 | 'attention_mask': (1, 0),
33 | 'special_tokens_mask': (1, -100)}
34 |
35 | @property
36 | def model_max_length(self):
37 | return self.config.model_max_length
38 |
39 | @property
40 | def mask_token(self):
41 | return self._tokenizer.mask_token
42 |
43 | @property
44 | def mask_token_id(self):
45 | return self._tokenizer.mask_token_id
46 |
47 | @property
48 | def pad_token_id(self):
49 | return self._tokenizer.pad_token_id
50 |
51 | @property
52 | def cls_token_id(self):
53 | return self._tokenizer.cls_token_id
54 |
55 | @property
56 | def sep_token_id(self):
57 | return self._tokenizer.sep_token_id
58 |
59 | @property
60 | def vocab(self):
61 | return self._tokenizer.vocab
62 |
63 | def __len__(self):
64 | """
65 | Size of the full vocabulary with the added tokens.
66 | """
67 | return len(self._tokenizer)
68 |
69 | def pad(self, *args, **kwargs):
70 | return self._tokenizer.pad(*args, **kwargs)
71 |
72 | def convert_tokens_to_ids(self, *args, **kwargs):
73 | return self._tokenizer.convert_tokens_to_ids(*args, **kwargs)
74 |
75 | def batch_decode(self, *args, **kwargs):
76 | return self._tokenizer.batch_decode(*args, **kwargs)
77 |
78 | def decode(self, *args, **kwargs):
79 | return self._tokenizer.decode(*args, **kwargs)
80 |
81 | def tokenize(self, text, **kwargs):
82 | return self._tokenizer.tokenize(text, **kwargs)
83 |
84 | def encode(self, text, **kwargs):
85 |         input_ids = self._tokenizer.encode_plus(text, add_special_tokens=False, **kwargs)['input_ids']
86 | input_ids = self.chunks(input_ids[: self.model_max_length - self.config.max_sentences],
87 | chunk_size=self.config.max_sentence_length, special_id=self.type2id['input_ids'])
88 | return input_ids
89 |
90 | def get_special_tokens_mask(self, *args, **kwargs):
91 | return self._tokenizer.get_special_tokens_mask(*args, **kwargs)
92 |
93 | @classmethod
94 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
95 | try:
96 | tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
97 |         except Exception:
98 | tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
99 | return cls(tokenizer=tokenizer)
100 |
101 | def save_pretrained(self, *args, **kwargs):
102 |         return self._tokenizer.save_pretrained(*args, **kwargs)
103 |
104 | def __call__(self, text, **kwargs):
105 | greedy_chunking = kwargs.pop('greedy_chunking', None)
106 | text_pair = kwargs.pop('text_pair', None)
107 | if isinstance(text[0], list):
108 | batch = self.auto_chunking(text, **kwargs)
109 | elif greedy_chunking:
110 | # fixed uniform chunking
111 | batch = self.uniform_chunking(text, **kwargs)
112 | else:
113 | # dynamic sentence splitting and grouping
114 | batch = self.sentence_splitting(text, **kwargs)
115 |
116 | if text_pair:
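   |             # tokenize the second text (e.g. a question/hypothesis) separately and write it, as one padded
   |             # sentence block, right after each document's last used block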
117 | batch_b = self._tokenizer(text_pair, add_special_tokens=False,
118 | padding=False, truncation=False)
119 | for idx, sample in enumerate(batch['input_ids']):
120 | n_sentences = sum(sample[::self.config.max_sentence_size])
121 | for input_key in batch:
122 | batch[input_key][idx][self.config.max_sentence_size * n_sentences:
123 | self.config.max_sentence_size * (n_sentences + 1)] = \
124 | self.pad_sentence(batch_b[input_key][idx],
125 | special_id=(self.sep_token_id, self.pad_token_id)
126 | if input_key == 'input_ids' else self.type2id[input_key])
127 |
128 | return batch
129 |
130 | def uniform_chunking(self, texts, **kwargs):
131 | original_batch = self._tokenizer(texts, add_special_tokens=False, **kwargs)
132 | batch = {input_type: [] for input_type in original_batch}
133 | for input_type in original_batch:
134 | fixed_batch = []
135 | for example in original_batch[input_type]:
136 | fixed_batch.append(self.chunks(example[: self.model_max_length - self.config.max_sentences],
137 | chunk_size=self.config.max_sentence_length,
138 | special_id=self.type2id[input_type]))
139 | batch[input_type] = fixed_batch if isinstance(fixed_batch[0], list) else torch.stack(fixed_batch)
140 |
141 | if kwargs['padding']:
142 | batch = self.pad(batch,
143 | padding=kwargs['padding'],
144 | max_length=kwargs['max_length'],
145 | pad_to_multiple_of=kwargs['max_length'])
146 |
147 | return batch
148 |
149 | def auto_chunking(self, texts, **kwargs):
150 | batch = {}
151 | for text_idx, text in enumerate(texts):
152 | example_batch = self._tokenizer(text, add_special_tokens=False, **kwargs)
153 | for input_key in example_batch:
154 | key_inputs_list = []
155 | for idx, example in enumerate(example_batch[input_key][:self.config.max_sentences]):
156 | key_inputs_list.append(self.pad_sentence(example, special_id=self.type2id[input_key]))
157 | if isinstance(key_inputs_list[0], list):
158 | key_inputs_list = [token for sentence in key_inputs_list for token in sentence]
159 | else:
160 | key_inputs_list = torch.stack([token for sentence in key_inputs_list for token in sentence])
161 | if input_key in batch:
162 | batch[input_key].append(key_inputs_list)
163 | else:
164 | batch[input_key] = [key_inputs_list]
165 |
166 | if kwargs['padding']:
167 | batch = self.pad(batch,
168 | padding=kwargs['padding'],
169 | max_length=kwargs['max_length'],
170 | pad_to_multiple_of=kwargs['max_length'])
171 |
172 | return batch
173 |
174 | def chunks(self, flat_inputs, chunk_size=128, special_id=0):
175 | if isinstance(flat_inputs, list):
176 | return self.list_chunks(flat_inputs, chunk_size, special_id)
177 | else:
178 | return self.tensor_chunks(flat_inputs, chunk_size, special_id)
179 |
180 | def list_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
181 | """Yield successive n-sized chunks from lst."""
182 | structured_inputs = [[special_id[0] if sum(flat_inputs[i:i + chunk_size-1]) else special_id[1]]
183 | + flat_inputs[i:i + chunk_size-1] for i in range(0, len(flat_inputs), chunk_size-1)]
184 | return [token_input for sentence_inputs in structured_inputs for token_input in sentence_inputs]
185 |
186 | def tensor_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
187 | """Yield successive n-sized chunks from lst."""
188 | structured_inputs = torch.stack([torch.cat((torch.tensor([special_id[0] if flat_inputs[i:i + chunk_size-1].sum() else special_id[1]], dtype=torch.int),
189 | flat_inputs[i:i + chunk_size-1])) for i in range(0, len(flat_inputs), chunk_size-1)])
190 | return structured_inputs.reshape(-1)
191 |
192 | def sentence_splitting(self, texts, **kwargs):
193 | fixed_batch = []
194 | doc_out = {}
195 | for text in texts:
196 | # sentence splitting
197 | sentences = sent_tokenize(text)
198 | # tokenization of sentences
199 | sentences = self._tokenizer(sentences, add_special_tokens=False, padding=False, truncation=False)
200 | # sentence grouping - merging short sentences to minimize padding
201 | doc_out = self.sentence_grouping(sentences)
202 | fixed_batch.append(doc_out)
203 | # batchify examples
204 | batch = {input_type: [] for input_type in doc_out}
205 | for input_type in batch:
206 | batch[input_type] = [example[input_type] for example in fixed_batch]
207 | if not isinstance(batch[input_type][0], list):
208 | batch[input_type] = torch.stack(batch[input_type])
209 |
210 | if kwargs['padding']:
211 | batch = self.pad(batch,
212 | padding=kwargs['padding'],
213 | max_length=kwargs['max_length'],
214 | pad_to_multiple_of=kwargs['max_length'])
215 |
216 | return batch
217 |
218 | def sentence_grouping(self, sentences):
219 | doc_out = {input_type: [] for input_type in sentences}
220 | for input_type in sentences:
221 | tmp_doc = []
222 | tmp_sentence = []
223 | for example in sentences[input_type]:
224 | if len(tmp_doc) >= self.config.max_sentences:
225 | break
226 | if len(tmp_sentence) + len(example) <= self.config.max_sentence_length - 1:
227 | tmp_sentence.extend(example)
228 | else:
229 | tmp_doc.append(self.pad_sentence(tmp_sentence if len(tmp_sentence) else example,
230 | chunk_size=self.config.max_sentence_length,
231 | special_id=self.type2id[input_type]))
232 | tmp_sentence = example if len(tmp_sentence) else example[self.config.max_sentence_length:]
233 | if len(tmp_sentence) and len(tmp_doc) < self.config.max_sentences:
234 | tmp_doc.append(self.pad_sentence(tmp_sentence,
235 | chunk_size=self.config.max_sentence_length,
236 | special_id=self.type2id[input_type]))
237 | doc_out[input_type] = [token for sentence in tmp_doc for token in sentence]
238 | return doc_out
239 |
240 | def pad_sentence(self, flat_input, chunk_size=128, special_id=(0, 0)):
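   |         """Prepend the block-initial special id, then truncate/right-pad the sentence to a fixed chunk_size-token block."""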
241 | if isinstance(flat_input, list):
242 | return [special_id[0]] + flat_input[:chunk_size-1] + [self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1)
243 | else:
244 | return torch.cat((torch.tensor([special_id[0] if flat_input[:chunk_size-1].sum()
245 | else special_id[1]], dtype=torch.int),
246 | flat_input[:chunk_size-1],
247 | torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
248 | ))
249 |
250 | @classmethod
251 | def register_for_auto_class(cls, auto_class="AutoModel"):
252 | """
253 | Register this class with a given auto class. This should only be used for custom models as the ones in the
254 | library are already mapped with an auto class.
255 |
256 | This API is experimental and may have some slight breaking changes in the next releases.
257 |
258 | Args:
259 |             auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`):
260 | The auto class to register this new model with.
261 | """
262 | if not isinstance(auto_class, str):
263 | auto_class = auto_class.__name__
264 |
265 | import transformers.models.auto as auto_module
266 |
267 | if not hasattr(auto_module, auto_class):
268 | raise ValueError(f"{auto_class} is not a valid auto class.")
269 |
270 | cls._auto_class = auto_class
271 |
272 |
--------------------------------------------------------------------------------
/evaluation/run_sequential_sentence_classification.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | """ Finetuning models on the ECtHR Argument Mining dataset (e.g. Bert, RoBERTa, LEGAL-BERT)."""
4 |
5 | import logging
6 | import os
7 | import random
8 | import sys
9 | from dataclasses import dataclass, field
10 | from typing import Optional
11 |
12 | import datasets
13 | from datasets import load_dataset
14 | from sklearn.metrics import f1_score, classification_report
15 | import glob
16 | import shutil
17 |
18 | import transformers
19 | from transformers import (
20 | Trainer,
21 | AutoConfig,
22 | EvalPrediction,
23 | HfArgumentParser,
24 | TrainingArguments,
25 | set_seed,
26 | EarlyStoppingCallback,
27 | )
28 | from transformers.trainer_utils import get_last_checkpoint
29 | from transformers.utils import check_min_version
30 | from transformers.utils.versions import require_version
31 | from data_collator import DataCollatorForMultiLabelClassification
32 | from models.hat import HATModelForSequentialSentenceClassification, HATTokenizer, HATConfig
33 | from models.longformer import LongformerModelForSentenceClassification, LongformerTokenizer
34 | from models.big_bird import BigBirdModelForSentenceClassification, BigbirdTokenizer
35 |
36 |
37 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
38 | check_min_version("4.9.0")
39 |
40 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
41 |
42 | logger = logging.getLogger(__name__)
43 |
44 |
45 | @dataclass
46 | class DataTrainingArguments:
47 | """
48 | Arguments pertaining to what data we are going to input our model for training and eval.
49 |
50 | Using `HfArgumentParser` we can turn this class
51 | into argparse arguments to be able to specify them on
52 | the command line.
53 | """
54 |
55 | max_seq_length: Optional[int] = field(
56 | default=4096,
57 | metadata={
58 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
59 | "than this will be truncated, sequences shorter will be padded."
60 | },
61 | )
62 | max_sentences: int = field(
63 | default=32,
64 | metadata={
65 | "help": "The maximum number of sentences after tokenization. Sequences longer "
66 | "than this will be truncated."
67 | },
68 | )
69 | overwrite_cache: bool = field(
70 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
71 | )
72 | pad_to_max_length: bool = field(
73 | default=True,
74 | metadata={
75 | "help": "Whether to pad all samples to `max_seq_length`. "
76 | "If False, will pad the samples dynamically when batching to the maximum length in the batch."
77 | },
78 | )
79 | max_train_samples: Optional[int] = field(
80 | default=None,
81 | metadata={
82 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
83 | "value if set."
84 | },
85 | )
86 | max_eval_samples: Optional[int] = field(
87 | default=None,
88 | metadata={
89 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
90 | "value if set."
91 | },
92 | )
93 | max_predict_samples: Optional[int] = field(
94 | default=None,
95 | metadata={
96 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
97 | "value if set."
98 | },
99 | )
100 | dataset_name: Optional[str] = field(
101 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
102 | )
103 | dataset_config_name: Optional[str] = field(
104 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
105 | )
106 | task: Optional[str] = field(
107 | default='ecthr_arguments',
108 | metadata={
109 | "help": "Define downstream task"
110 | },
111 | )
112 | server_ip: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
113 | server_port: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
114 |
115 |
116 | @dataclass
117 | class ModelArguments:
118 | """
119 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
120 | """
121 |
122 | model_name_or_path: str = field(
123 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
124 | )
125 | config_name: Optional[str] = field(
126 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
127 | )
128 | tokenizer_name: Optional[str] = field(
129 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
130 | )
131 | cache_dir: Optional[str] = field(
132 | default=None,
133 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
134 | )
135 | use_fast_tokenizer: bool = field(
136 | default=True,
137 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
138 | )
139 | model_revision: str = field(
140 | default="main",
141 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
142 | )
143 | use_auth_token: bool = field(
144 | default=True,
145 | metadata={
146 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
147 | "with private models)."
148 | },
149 | )
150 |
151 |
152 | def main():
153 | # See all possible arguments in src/transformers/training_args.py
154 | # or by passing the --help flag to this script.
155 | # We now keep distinct sets of args, for a cleaner separation of concerns.
156 |
157 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
158 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
159 |
160 | # Setup distant debugging if needed
161 | if data_args.server_ip and data_args.server_port:
162 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
163 | import ptvsd
164 |
165 | print("Waiting for debugger attach")
166 | ptvsd.enable_attach(address=(data_args.server_ip, data_args.server_port), redirect_output=True)
167 | ptvsd.wait_for_attach()
168 |
169 | # Setup logging
170 | logging.basicConfig(
171 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
172 | datefmt="%m/%d/%Y %H:%M:%S",
173 | handlers=[logging.StreamHandler(sys.stdout)],
174 | )
175 |
176 | log_level = training_args.get_process_log_level()
177 | logger.setLevel(log_level)
178 | datasets.utils.logging.set_verbosity(log_level)
179 | transformers.utils.logging.set_verbosity(log_level)
180 | transformers.utils.logging.enable_default_handler()
181 | transformers.utils.logging.enable_explicit_format()
182 |
183 | # Log on each process the small summary:
184 | logger.warning(
185 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
186 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
187 | )
188 | logger.info(f"Training/evaluation parameters {training_args}")
189 |
190 | # Detecting last checkpoint.
191 | last_checkpoint = None
192 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
193 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
194 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
195 | raise ValueError(
196 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
197 | "Use --overwrite_output_dir to overcome."
198 | )
199 | elif last_checkpoint is not None:
200 | logger.info(
201 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
202 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
203 | )
204 |
205 | # Set seed before initializing model.
206 | set_seed(training_args.seed)
207 |
208 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
209 | # download the dataset.
210 |     # Downloading and loading the ECtHR Arguments dataset from the hub.
211 | if training_args.do_train:
212 | train_dataset = load_dataset(
213 | data_args.dataset_name,
214 | data_args.dataset_config_name,
215 | split="train",
216 | data_dir=data_args.dataset_name,
217 | cache_dir=model_args.cache_dir,
218 | )
219 |
220 | if training_args.do_eval:
221 | eval_dataset = load_dataset(
222 | data_args.dataset_name,
223 | data_args.dataset_config_name,
224 | split="validation",
225 | data_dir=data_args.dataset_name,
226 | cache_dir=model_args.cache_dir,
227 | )
228 |
229 | if training_args.do_predict:
230 | predict_dataset = load_dataset(
231 | data_args.dataset_name,
232 | data_args.dataset_config_name,
233 | split="test",
234 | data_dir=data_args.dataset_name,
235 | cache_dir=model_args.cache_dir,
236 | )
237 | # Labels
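   |     # 'labels' is a nested Sequence (one list of label ids per paragraph), hence the double .feature to reach the ClassLabel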
238 | label_list = list(range(train_dataset.features['labels'].feature.feature.num_classes))
239 | label_names = train_dataset.features['labels'].feature.feature.names
240 | num_labels = len(label_list)
241 |
242 | # Load pretrained model and tokenizer
243 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
244 | # download model & vocab.
245 |
246 | if 'hat' in model_args.model_name_or_path:
247 | config = HATConfig.from_pretrained(
248 | model_args.model_name_or_path,
249 | num_labels=num_labels,
250 | finetuning_task="ecthr-args",
251 | cache_dir=model_args.cache_dir,
252 | revision=model_args.model_revision,
253 | )
254 | tokenizer = HATTokenizer.from_pretrained(
255 | model_args.model_name_or_path,
256 | cache_dir=model_args.cache_dir,
257 | use_fast=model_args.use_fast_tokenizer,
258 | revision=model_args.model_revision,
259 | )
260 | model = HATModelForSequentialSentenceClassification.from_pretrained(
261 | model_args.model_name_or_path,
262 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
263 | config=config,
264 | cache_dir=model_args.cache_dir,
265 | revision=model_args.model_revision,
266 | )
267 | elif 'longformer' in model_args.model_name_or_path:
268 | config = AutoConfig.from_pretrained(
269 | model_args.model_name_or_path,
270 | num_labels=num_labels,
271 | finetuning_task="ecthr-args",
272 | cache_dir=model_args.cache_dir,
273 | revision=model_args.model_revision,
274 | )
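   |         # sentence/block settings consumed by the custom tokenizer and classification wrappers
   |         # (e.g. tokenizer.config.max_sentences is read in preprocess_function below)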
275 | config.max_sentence_size = 128
276 | config.max_sentence_length = 128
277 | config.max_sentences = data_args.max_sentences
278 | config.model_max_length = 4096
279 | config.cls_token_id = config.bos_token_id
280 | config.sep_token_id = config.eos_token_id
281 | tokenizer = LongformerTokenizer.from_pretrained(
282 | model_args.model_name_or_path,
283 | cache_dir=model_args.cache_dir,
284 | use_fast=model_args.use_fast_tokenizer,
285 | revision=model_args.model_revision,
286 | )
287 | model = LongformerModelForSentenceClassification.from_pretrained(
288 | model_args.model_name_or_path,
289 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
290 | config=config,
291 | cache_dir=model_args.cache_dir,
292 | revision=model_args.model_revision,
293 | )
294 | elif 'bigbird' in model_args.model_name_or_path:
295 | config = AutoConfig.from_pretrained(
296 | model_args.model_name_or_path,
297 | num_labels=num_labels,
298 | finetuning_task="ecthr-args",
299 | cache_dir=model_args.cache_dir,
300 | revision=model_args.model_revision,
301 | )
302 | config.max_sentence_size = 128
303 | config.max_sentence_length = 128
304 | config.max_sentences = data_args.max_sentences
305 | config.model_max_length = 4096
306 | config.cls_token_id = config.bos_token_id
307 | config.sep_token_id = config.eos_token_id
308 | tokenizer = BigbirdTokenizer.from_pretrained(
309 | model_args.model_name_or_path,
310 | cache_dir=model_args.cache_dir,
311 | use_fast=model_args.use_fast_tokenizer,
312 | revision=model_args.model_revision,
313 | )
314 | model = BigBirdModelForSentenceClassification.from_pretrained(
315 | model_args.model_name_or_path,
316 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
317 | config=config,
318 | cache_dir=model_args.cache_dir,
319 | revision=model_args.model_revision,
320 | )
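   |     else:
   |         # fail fast if the model path matches none of the supported wrappers above;
   |         # otherwise config/tokenizer/model would be undefined and raise an opaque NameError later
   |         raise ValueError(f"Unsupported model '{model_args.model_name_or_path}': "
   |                          f"expected a path containing 'hat', 'longformer' or 'bigbird'.")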
321 |
322 | # Preprocessing the datasets
323 | # Padding strategy
324 | if data_args.pad_to_max_length:
325 | padding = "max_length"
326 | else:
327 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch
328 | padding = False
329 |
335 |
336 | def preprocess_function(examples):
337 | # Tokenize the texts
338 | batch = tokenizer(
339 | examples['text'],
340 | padding=padding,
341 | max_length=data_args.max_seq_length,
342 | truncation=True,
343 | )
344 |
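   |         # build one multi-hot label vector per paragraph; unused sentence slots are filled with -1.0
   |         # and stripped again when computing metrics (see compute_metrics)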
345 | label_ids = []
346 | for idx, labels in enumerate(examples["labels"]):
347 | par_label_ids = []
348 | for par_labels in labels[:tokenizer.config.max_sentences]:
349 | par_label_ids.append([1.0 if label in par_labels else 0.0 for label in label_list])
350 | par_label_ids.extend([[-1.0] * len(label_list)] * (tokenizer.config.max_sentences - len(par_label_ids)))
351 | label_ids.append(par_label_ids)
352 |
353 | batch["label_ids"] = label_ids
354 |
355 | return batch
356 |
357 | if training_args.do_train:
358 | if data_args.max_train_samples is not None:
359 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
360 | with training_args.main_process_first(desc="train dataset map pre-processing"):
361 | train_dataset = train_dataset.map(
362 | preprocess_function,
363 | batched=True,
364 | load_from_cache_file=not data_args.overwrite_cache,
365 | desc="Running tokenizer on train dataset",
366 | )
367 | # Log a few random samples from the training set:
368 | for index in random.sample(range(len(train_dataset)), 3):
369 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
370 |
371 | if training_args.do_eval:
372 | if data_args.max_eval_samples is not None:
373 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
374 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
375 | eval_dataset = eval_dataset.map(
376 | preprocess_function,
377 | batched=True,
378 | load_from_cache_file=not data_args.overwrite_cache,
379 | desc="Running tokenizer on validation dataset",
380 | )
381 |
382 | if training_args.do_predict:
383 | if data_args.max_predict_samples is not None:
384 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
385 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
386 | predict_dataset = predict_dataset.map(
387 | preprocess_function,
388 | batched=True,
389 | load_from_cache_file=not data_args.overwrite_cache,
390 | desc="Running tokenizer on prediction dataset",
391 | )
392 |
393 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
394 | # predictions and label_ids field) and has to return a dictionary string to float.
395 | def compute_metrics(p: EvalPrediction):
396 |         # Unpad (drop the -1-labelled sentence slots) and flatten scores before computing metrics
397 | from multi_label_utils import fix_multi_label_scores
398 | y_true, y_pred = fix_multi_label_scores(p.predictions, p.label_ids,
399 | unpad_sequences=True, flatten_sequences=True)
400 |
401 | # Compute scores
402 | macro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=0)
403 | micro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='micro', zero_division=0)
404 | logger.info(classification_report(y_true=y_true, y_pred=y_pred,
405 | target_names=label_names + ['None'], zero_division=0)+'\n')
406 | return {'macro-f1': macro_f1, 'micro-f1': micro_f1}
407 |
408 | # Initialize our Trainer
409 | trainer = Trainer(
410 | model=model,
411 | args=training_args,
412 | train_dataset=train_dataset if training_args.do_train else None,
413 | eval_dataset=eval_dataset if training_args.do_eval else None,
414 | compute_metrics=compute_metrics,
415 | tokenizer=tokenizer,
416 | data_collator=DataCollatorForMultiLabelClassification(tokenizer),
417 | callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
418 | )
419 |
420 | # Training
421 | if training_args.do_train:
422 | checkpoint = None
423 | if training_args.resume_from_checkpoint is not None:
424 | checkpoint = training_args.resume_from_checkpoint
425 | elif last_checkpoint is not None:
426 | checkpoint = last_checkpoint
427 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
428 | metrics = train_result.metrics
429 | max_train_samples = (
430 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
431 | )
432 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
433 |
434 | trainer.save_model() # Saves the tokenizer too for easy upload
435 |
436 | trainer.log_metrics("train", metrics)
437 | trainer.save_metrics("train", metrics)
438 | trainer.save_state()
439 |
440 | # Evaluation
441 | if training_args.do_eval:
442 | logger.info("*** Evaluate ***")
443 | metrics = trainer.evaluate(eval_dataset=eval_dataset)
444 |
445 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
446 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
447 |
448 | trainer.log_metrics("eval", metrics)
449 | trainer.save_metrics("eval", metrics)
450 |
451 | # Prediction
452 | if training_args.do_predict:
453 | logger.info("*** Predict ***")
454 | predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
455 |
456 | max_predict_samples = (
457 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
458 | )
459 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
460 |
461 | trainer.log_metrics("predict", metrics)
462 | trainer.save_metrics("predict", metrics)
463 |
464 | output_predict_file = os.path.join(training_args.output_dir, "test_predictions.csv")
465 | report_predict_file = os.path.join(training_args.output_dir, "classification_report.txt")
466 | if trainer.is_world_process_zero():
467 | with open(output_predict_file, "w") as writer:
468 | try:
469 | for index, pred_list in enumerate(predictions[0]):
470 | pred_line = '\t'.join([f'{pred:.5f}' for pred in pred_list])
471 | writer.write(f"{index}\t{pred_line}\n")
472 |                 except (TypeError, IndexError):
473 | try:
474 | for index, pred_list in enumerate(predictions):
475 | pred_line = '\t'.join([f'{pred:.5f}' for pred in pred_list])
476 | writer.write(f"{index}\t{pred_line}\n")
477 |                     except (TypeError, IndexError):
478 | pass
479 |
480 | # Clean up checkpoints
481 | checkpoints = [filepath for filepath in glob.glob(f'{training_args.output_dir}/*/') if '/checkpoint' in filepath]
482 | for checkpoint in checkpoints:
483 | shutil.rmtree(checkpoint)
484 |
485 |
486 | if __name__ == "__main__":
487 | main()
488 |
--------------------------------------------------------------------------------