├── __init__.py ├── .github └── FUNDING.yml ├── submission_scripts ├── utils │ └── create_folder_outputs.sh ├── full_sentence │ ├── training_ALR.sh │ ├── training_ALRR.sh │ ├── training_ALSR.sh │ ├── training_ELR.sh │ ├── validation_script.sh │ ├── training_ELR_all_encoder.sh │ ├── training_ALRR_all_encoder.sh │ ├── training_ALSR_all_encoder.sh │ ├── validation_script_ALR_encoder.sh │ ├── validation_script_ALR_decoder.sh │ ├── validation_script_ELR_all_encoder.sh │ ├── validation_script_ALRR_all_encoder.sh │ ├── validation_script_ALSR_all_encoder.sh │ ├── validation_script_ALR_cross_decoder.sh │ ├── training_ALR_all_all.sh │ ├── validation_script_ALR_encoder_decoder.sh │ └── validation_script_ALR_all.sh ├── baseline │ └── training_script.sh └── extraction │ ├── extract.sh │ └── extract_mha.sh ├── .gitignore ├── environment.yml ├── scripts ├── baseline │ ├── prepare_dataset.py │ └── training_script.py ├── extraction │ ├── extract.py │ └── extract_mha.py └── full_sentence │ ├── validation_script.py │ ├── training_ALRR.py │ ├── training_ELR.py │ ├── training_ALSR.py │ └── training_ALR.py ├── LICENCE ├── utils ├── constants.py ├── visualization_utils.py ├── utils.py ├── optimizers_and_distributions.py ├── simulator.py ├── decoding_utils.py └── data_utils.py ├── models └── definitions │ ├── ALRR_FF.py │ ├── ELR_FF.py │ ├── ALSR_FF.py │ └── ALR_FF.py └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | patreon: theaiepiphany 2 | -------------------------------------------------------------------------------- /submission_scripts/utils/create_folder_outputs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DIR="$SCRATCH/pytorch-original-transformer/sbatch_log/$1" 3 | if [ ! 
-d $DIR ];then 4 | mkdir $DIR 5 | fi 6 | model_name=$DIR/$2 7 | mkdir $model_name 8 | mkdir $model_name/evaluation_outputs 9 | mkdir $model_name/training_outputs 10 | -------------------------------------------------------------------------------- /submission_scripts/full_sentence/training_ALR.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=sbatch_log/%j.out 3 | #SBATCH --gpus=1 4 | #SBATCH --mem-per-cpu=16000 5 | #SBATCH --gres=gpumem:20g 6 | #SBATCH --time 250 7 | source $SCRATCH/miniconda3/etc/profile.d/conda.sh 8 | source activate pytorch-transformer 9 | python -u scripts/full_sentence/training_ALR.py "$@" -------------------------------------------------------------------------------- /submission_scripts/full_sentence/training_ALRR.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=sbatch_log/%j.out 3 | #SBATCH --gpus=1 4 | #SBATCH --mem-per-cpu=16000 5 | #SBATCH --gres=gpumem:20g 6 | #SBATCH --time 600 7 | source $SCRATCH/miniconda3/etc/profile.d/conda.sh 8 | source activate pytorch-transformer 9 | python -u scripts/full_sentence/training_ALRR.py "$@" -------------------------------------------------------------------------------- /submission_scripts/full_sentence/training_ALSR.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=sbatch_log/%j.out 3 | #SBATCH --gpus=1 4 | #SBATCH --mem-per-cpu=16000 5 | #SBATCH --gres=gpumem:20g 6 | #SBATCH --time 600 7 | source $SCRATCH/miniconda3/etc/profile.d/conda.sh 8 | source activate pytorch-transformer 9 | python -u scripts/full_sentence/training_ALSR.py "$@" -------------------------------------------------------------------------------- /submission_scripts/full_sentence/training_ELR.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=sbatch_log/%j.out 3 | #SBATCH --gpus=1 4 | #SBATCH --mem-per-cpu=16000 5 | #SBATCH --gres=gpumem:20g 6 | #SBATCH --time 250 7 | source $SCRATCH/miniconda3/etc/profile.d/conda.sh 8 | source activate pytorch-transformer 9 | python -u scripts/full_sentence/training_ELR.py "$@" -------------------------------------------------------------------------------- /submission_scripts/baseline/training_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=sbatch_log/%j.out 3 | #SBATCH --gpus=1 4 | #SBATCH --mem-per-cpu=16000 5 | #SBATCH --gres=gpumem:20g 6 | #SBATCH --time 1000 7 | source $SCRATCH/miniconda3/etc/profile.d/conda.sh 8 | source activate pytorch-transformer 9 | python -u scripts/baseline/training_script.py "$@" 10 | -------------------------------------------------------------------------------- /submission_scripts/full_sentence/validation_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=sbatch_log/%j.out 3 | #SBATCH --gpus=1 4 | #SBATCH --mem-per-cpu=16000 5 | #SBATCH --gres=gpumem:20g 6 | #SBATCH --time 180 7 | source $SCRATCH/miniconda3/etc/profile.d/conda.sh 8 | source activate pytorch-transformer 9 | python -u scripts/full_sentence/validation_script.py "$@" -------------------------------------------------------------------------------- /submission_scripts/extraction/extract.sh: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env bash 2 | #SBATCH --output=sbatch_log/%j.out 3 | #SBATCH --mem-per-cpu=16000 4 | #SBATCH --gres=gpumem:20g 5 | #SBATCH --gpus=1 6 | #SBATCH --time=240 7 | source $SCRATCH/miniconda3/etc/profile.d/conda.sh 8 | source activate pytorch-transformer 9 | python3 -u scripts/extraction/extract.py --path_to_weights $SCRATCH/pytorch-original-transformer/models/binaries/Transformer_None_None_20.pth --batch_size 1400 --dataset_name IWSLT --language_direction E2F --model_name 128emb_20ep -------------------------------------------------------------------------------- /submission_scripts/extraction/extract_mha.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=sbatch_log/%j.out 3 | #SBATCH --gpus=1 4 | #SBATCH --mem-per-cpu=16000 5 | #SBATCH --time 250 6 | source $SCRATCH/miniconda3/etc/profile.d/conda.sh 7 | source activate pytorch-transformer 8 | python -u scripts/extraction/extract_mha.py --batch_size 1400 --dataset_name IWSLT --language_direction E2F --model_name 128emb_20ep --path_to_weights $SCRATCH/pytorch-original-transformer/models/binaries/Transformer_None_None_20.pth --output_path $SCRATCH/pytorch-original-transformer/mha_outputs "$@" 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # PyCharm IDE 2 | .vscode 3 | .idea 4 | __pycache__ 5 | 6 | # This is where torch text datasets like IWSLT and WMT14 will be downloaded to 7 | data/ 8 | data/*.csv 9 | data/iwslt 10 | data/wikitext-2 11 | data/wmt14 12 | 13 | # Jupyter notebook checkpoints 14 | .ipynb_checkpoints 15 | 16 | # Tensorboard log files 17 | runs 18 | 19 | # Models checkpoints and binaries 20 | models/checkpoints 21 | models/binaries 22 | 23 | # Azure ML related 24 | .azureml 25 | submit_to_aml.py 26 | environment_aml.yml 27 | training_script_aml.py 28 | pytorch-original-transformer 29 | 30 | sbatch_log 31 | layer_outputs 32 | mha_outputs -------------------------------------------------------------------------------- /submission_scripts/full_sentence/training_ELR_all_encoder.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/training_ELR_submit_all.sh ?" 5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_L,FFNetwork_M, FFNetwork_XL, FFNetwork_XS, FFNetwork_S]" 8 | exit 9 | fi 10 | 11 | ./submission_scripts/utils/create_folder_outputs.sh ELR $1 12 | for i in {0..5};do 13 | echo sbatch --output=./sbatch_log/ELR/$1/training_outputs/%j.out submission_scripts/full_sentence/training_ELR.sh --num_of_curr_trained_layer $i --substitute_class $1 14 | echo "" 15 | sbatch --output=./sbatch_log/ELR/$1/training_outputs/%j.out submission_scripts/full_sentence/training_ELR.sh --num_of_curr_trained_layer $i --substitute_class $1 16 | done 17 | -------------------------------------------------------------------------------- /submission_scripts/full_sentence/training_ALRR_all_encoder.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/training_ALRR_submit_all.sh ?" 
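# Note: the first positional argument ($1) is the FF substitute class; it is forwarded below as --substitute_class to training_ALRR.sh.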
5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_L,FFNetwork_M, FFNetwork_XL, FFNetwork_XS, FFNetwork_S]" 8 | exit 9 | fi 10 | 11 | ./submission_scripts/utils/create_folder_outputs.sh ALRR $1 12 | for i in {0..5};do 13 | echo sbatch --output=./sbatch_log/ALRR/$1/training_outputs/%j.out submission_scripts/full_sentence/training_ALRR.sh --num_of_curr_trained_layer $i --substitute_class $1 14 | echo "" 15 | sbatch --output=./sbatch_log/ALRR/$1/training_outputs/%j.out submission_scripts/full_sentence/training_ALRR.sh --num_of_curr_trained_layer $i --substitute_class $1 16 | done 17 | -------------------------------------------------------------------------------- /submission_scripts/full_sentence/training_ALSR_all_encoder.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/training_ALSR_submit_all.sh ?" 5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_L,FFNetwork_M, FFNetwork_XL, FFNetwork_XS, FFNetwork_S]" 8 | exit 9 | fi 10 | 11 | ./submission_scripts/utils/create_folder_outputs.sh ALSR $1 12 | for i in {0..5};do 13 | echo sbatch --output=./sbatch_log/ALSR/$1/training_outputs/%j.out submission_scripts/full_sentence/training_ALSR.sh --num_of_curr_trained_layer $i --substitute_class $1 14 | echo "" 15 | sbatch --output=./sbatch_log/ALSR/$1/training_outputs/%j.out submission_scripts/full_sentence/training_ALSR.sh --num_of_curr_trained_layer $i --substitute_class $1 16 | done 17 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: rethinking-attention 2 | channels: 3 | - defaults 4 | - pytorch 5 | dependencies: 6 | - python==3.8.3 7 | - pip==20.0.2 8 | - matplotlib==3.1.3 9 | - pytorch==1.5.0 10 | - torchtext==0.6.0 11 | - numpy==1.20.3 12 | - pip: 13 | - GitPython==3.1.2 14 | - jupyter==1.0.0 15 | - numpy==1.20.3 16 | - nltk==3.5 17 | - seaborn==0.11.0 18 | - spacy==2.3.2 19 | - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz#egg=en_core_web_sm 20 | - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-2.3.0/de_core_news_sm-2.3.0.tar.gz#egg=de_core_news_sm 21 | - https://github.com/explosion/spacy-models/releases/download/fr_core_news_sm-2.3.0/fr_core_news_sm-2.3.0.tar.gz#egg=fr_core_news_sm 22 | - optuna 23 | - datasets 24 | -------------------------------------------------------------------------------- /scripts/baseline/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import pandas as pd 3 | import os 4 | 5 | if not os.path.exists('./data/prepared_data'): 6 | os.makedirs('./data/prepared_data') 7 | 8 | subset_of_data = 'de-en' # options: 'en-fr', 'fr-en', 'en-de', 'de-en' ... 
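# For the default 'de-en' pair, the code below writes ./data/prepared_data/train_de_en.csv, val_de_en.csv and test_de_en.csv, each with one column per language ('de' and 'en').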
9 | src_lang, trg_lang = subset_of_data.split('-') 10 | target_subset = src_lang + '_' + trg_lang 11 | dataset = load_dataset('iwslt2017', pair= subset_of_data, is_multilingual = False, cache_dir='./data') 12 | df = pd.DataFrame(dataset['train']['translation'], columns=[src_lang, trg_lang]) 13 | df.to_csv(f'./data/prepared_data/train_{target_subset}.csv', index=False) 14 | 15 | df = pd.DataFrame(dataset['validation']['translation'], columns=[src_lang, trg_lang]) 16 | df.to_csv(f'./data/prepared_data/val_{target_subset}.csv', index=False) 17 | 18 | df = pd.DataFrame(dataset['test']['translation'], columns=[src_lang, trg_lang]) 19 | df.to_csv(f'./data/prepared_data/test_{target_subset}.csv', index=False) -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Aleksa Gordić 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /submission_scripts/full_sentence/validation_script_ALR_encoder.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/validation_script_ALR_submit_all.sh ?" 5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_L,FFNetwork_M, FFNetwork_XL,FF Network_XS, FFNetwork_S]" 8 | exit 9 | fi 10 | epoch=21 11 | ./submission_scripts/utils/create_folder_outputs.sh ALR $1 12 | 13 | # for i in {0..5}; do 14 | # echo "Substiting layer $i..." 15 | # sbatch --output=../sbatch_log/ALR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path$suffix $SCRATCH/models/checkpoints/ALR/$1/ --epoch$suffix $epoch --substitute_type$suffix ALR --substitute_class$suffix $1 --layers$suffix $i 16 | # done 17 | 18 | echo "Substituting all layers" 19 | sbatch --output=./sbatch_log/ALR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path $SCRATCH/pytorch-original-transformer/models/checkpoints/ALR/$1/ --epoch $epoch --substitute_type ALR --substitute_class $1 20 | -------------------------------------------------------------------------------- /submission_scripts/full_sentence/validation_script_ALR_decoder.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/validation_script_ALR_submit_all.sh ?" 5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_decoder_L,FFNetwork_decoder_M, FFNetwork_decoder_XL,FFNetwork_decoder_XS, FFNetwork_decoder_S]" 8 | exit 9 | fi 10 | epoch=21 11 | ./submission_scripts/utils/create_folder_outputs.sh ALR $1 12 | 13 | # for i in {0..5}; do 14 | # echo "Substiting layer $i..." 15 | # sbatch --output=../sbatch_log/ALR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path$suffix $SCRATCH/models/checkpoints/ALR/$1/ --epoch$suffix $epoch --substitute_type$suffix ALR --substitute_class$suffix $1 --layers$suffix $i 16 | # done 17 | 18 | echo "Substituting all layers" 19 | sbatch --output=./sbatch_log/ALR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path_d $SCRATCH/pytorch-original-transformer/models/checkpoints/ALR/$1/ --epoch_d $epoch --substitute_type_d ALR --substitute_class_d $1 -------------------------------------------------------------------------------- /submission_scripts/full_sentence/validation_script_ELR_all_encoder.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/validation_script_ELR_submit_all.sh ?" 5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_L,FFNetwork_M, FFNetwork_XL,FF Network_XS, FFNetwork_S]" 8 | echo "modify the parameter epoch as you need in the script" 9 | exit 10 | fi 11 | epoch=21 12 | ./submission_scripts/utils/create_folder_outputs.sh ELR $1 13 | 14 | # for i in {0..5}; do 15 | # echo "Substiting layer $i..." 16 | # sbatch --output=../sbatch_log/ELR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path$suffix $SCRATCH/models/checkpoints/ELR/$1/ --epoch$suffix $epoch --substitute_type$suffix ELR --substitute_class$suffix $1 --layers$suffix $i 17 | # done 18 | 19 | echo "Substituting all layers" 20 | sbatch --output=./sbatch_log/ELR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path $SCRATCH/pytorch-original-transformer/models/checkpoints/ELR/$1/ --epoch $epoch --substitute_type ELR --substitute_class $1 21 | -------------------------------------------------------------------------------- /submission_scripts/full_sentence/validation_script_ALRR_all_encoder.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/validation_script_ALR_submit_all.sh ?" 5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_L,FFNetwork_M, FFNetwork_XL,FF Network_XS, FFNetwork_S]" 8 | echo "modify the parameter epoch as you need in the script" 9 | exit 10 | fi 11 | epoch=21 12 | ./submission_scripts/utils/create_folder_outputs.sh ALRR $1 13 | 14 | # for i in {0..5}; do 15 | # echo "Substiting layer $i..." 
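# (This commented-out loop would evaluate one substituted ALRR layer at a time; the active sbatch call at the bottom of the script substitutes all six encoder layers at once.)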
16 | # sbatch --output=../sbatch_log/ALRR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path$suffix $SCRATCH/models/checkpoints/ALRR/$1/ --epoch$suffix $epoch --substitute_type$suffix ALRR --substitute_class$suffix $1 --layers$suffix $i 17 | # done 18 | 19 | echo "Substituting all layers" 20 | sbatch --output=./sbatch_log/ALRR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path $SCRATCH/pytorch-original-transformer/models/checkpoints/ALRR/$1/ --epoch $epoch --substitute_type ALRR --substitute_class $1 21 | -------------------------------------------------------------------------------- /submission_scripts/full_sentence/validation_script_ALSR_all_encoder.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/validation_script_ALSR_submit_all.sh ?" 5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_L,FFNetwork_M, FFNetwork_XL,FF Network_XS, FFNetwork_S]" 8 | echo "modify the parameter epoch as you need in the script" 9 | exit 10 | fi 11 | epoch=21 12 | ./submission_scripts/utils/create_folder_outputs.sh ALSR $1 13 | 14 | # for i in {0..5}; do 15 | # echo "Substiting layer $i..." 16 | # sbatch --output=../sbatch_log/ALSR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path$suffix $SCRATCH/models/checkpoints/ALSR/$1/ --epoch$suffix $epoch --substitute_type$suffix ALSR --substitute_class$suffix $1 --layers$suffix $i 17 | # done 18 | 19 | echo "Substituting all layers" 20 | sbatch --output=./sbatch_log/ALSR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path $SCRATCH/pytorch-original-transformer/models/checkpoints/ALSR/$1/ --epoch $epoch --substitute_type ALSR --substitute_class $1 21 | -------------------------------------------------------------------------------- /submission_scripts/full_sentence/validation_script_ALR_cross_decoder.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/validation_script_ALR_submit_all.sh ?" 5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_cross_decoder_L,FFNetwork_cross_decoder_M, FFNetwork_cross_decoder_XL,FF Network_cross_decoder_XS, FFNetwork_cross_decoder_S]" 8 | exit 9 | fi 10 | epoch=21 11 | ./submission_scripts/utils/create_folder_outputs.sh ALR $1 12 | 13 | # for i in {0..5}; do 14 | # echo "Substiting layer $i..." 15 | # sbatch --output=../sbatch_log/ALR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path$suffix $SCRATCH/models/checkpoints/ALR/$1/ --epoch$suffix $epoch --substitute_type$suffix ALR --substitute_class$suffix $1 --layers$suffix $i 16 | # done 17 | 18 | echo "Substituting all layers" 19 | sbatch --output=./sbatch_log/ALR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path_d_ca $SCRATCH/pytorch-original-transformer/models/checkpoints/ALR/$1/ --epoch_d_ca $epoch --substitute_type_d_ca ALR --substitute_class_d_ca $1 20 | -------------------------------------------------------------------------------- /submission_scripts/full_sentence/training_ALR_all_all.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/training_ALR_all_all.sh ?" 5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_L,FFNetwork_M, FFNetwork_XL, FFNetwork_XS, FFNetwork_S]" 8 | echo " in encoder, set it if the layer you want to train is part of the encoder" 9 | echo -e "\t\t\t [ FFNetwork_decoder_L, FFNetwork_decoder_M, FFNetwork_decoder_XL, FFNetwork_decoder_XS, FFNetwork_decoder_S]" 10 | echo " in decoder, set it if the layer you want to train is part of the self decoder" 11 | echo -e "\t\t\t [ FFNetwork_cross_decoder_L, FFNetwork_cross_decoder_M, FFNetwork_cross_decoder_XL, FFNetwork_cross_decoder_XS, FFNetwork_cross_decoder_S]" 12 | echo " in decoder_ca, set it if the layer you want to train is part of the cross decoder" 13 | exit 14 | fi 15 | 16 | 17 | ./submission_scripts/utils/create_folder_outputs.sh ALR $1 18 | for i in {0..5};do 19 | echo sbatch --output=./sbatch_log/ALR/$1/training_outputs/%j.out submission_scripts/full_sentence/training_ALR.sh --num_of_curr_trained_layer $i --substitute_class $1 --att_replacement $2 20 | echo "" 21 | sbatch --output=./sbatch_log/ALR/$1/training_outputs/%j.out submission_scripts/full_sentence/training_ALR.sh --num_of_curr_trained_layer $i --substitute_class $1 --att_replacement $2 22 | done 23 | -------------------------------------------------------------------------------- /utils/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | BASELINE_MODEL_NUMBER_OF_LAYERS = 6 5 | #BASELINE_MODEL_DIMENSION = 512 6 | BASELINE_MODEL_DIMENSION = 128 7 | BASELINE_MODEL_NUMBER_OF_HEADS = 8 8 | BASELINE_MODEL_DROPOUT_PROB = 0.1 9 | BASELINE_MODEL_LABEL_SMOOTHING_VALUE = 0.1 10 | 11 | 12 | BIG_MODEL_NUMBER_OF_LAYERS = 6 13 | BIG_MODEL_DIMENSION = 1024 14 | BIG_MODEL_NUMBER_OF_HEADS = 16 15 | BIG_MODEL_DROPOUT_PROB = 0.3 16 | BIG_MODEL_LABEL_SMOOTHING_VALUE = 0.1 17 | 18 | 19 | SCRATCH = os.environ.get('SCRATCH') 20 | CHECKPOINTS_SCRATCH = os.path.join(SCRATCH, 'pytorch-original-transformer', 'models', 'checkpoints') 21 | CHECKPOINTS_PATH = CHECKPOINTS_SCRATCH 22 | 23 | 24 | BINARIES_PATH = os.path.join(SCRATCH, 'pytorch-original-transformer', 'models', 'binaries') 25 | DATA_DIR_PATH = os.path.join(os.path.dirname(__file__), os.pardir, 'data') 26 | LAYER_OUTPUT_PATH = os.path.join(SCRATCH, 'pytorch-original-transformer', "layer_outputs") 27 | MHA_OUTPUT_PATH = os.path.join(SCRATCH, 'pytorch-original-transformer', "mha_outputs") 28 | ALR_CHECKPOINT_FORMAT = "ff_network_{0}_layer_{1}.pth" #.format(epoch, layer) 29 | MHA_SEPARATE_CHECKPOINT_FORMAT = "ff_network_{0}_layer_{1}_head{2}.pth" 30 | os.makedirs(CHECKPOINTS_SCRATCH, exist_ok=True) 31 | 32 | os.makedirs(CHECKPOINTS_PATH, exist_ok=True) 33 | os.makedirs(BINARIES_PATH, exist_ok=True) 34 | os.makedirs(DATA_DIR_PATH, exist_ok=True) 35 | os.makedirs(LAYER_OUTPUT_PATH, exist_ok=True) 36 | os.makedirs(MHA_OUTPUT_PATH, exist_ok=True) 37 | os.makedirs(CHECKPOINTS_SCRATCH, exist_ok=True) 38 | 39 | 40 | BOS_TOKEN = '' 41 | EOS_TOKEN = '' 42 | PAD_TOKEN = "" 43 | MAX_LEN = 50 -------------------------------------------------------------------------------- /submission_scripts/full_sentence/validation_script_ALR_encoder_decoder.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/validation_script_ALR_submit_all.sh [--encoder] [--decoder] [--decoder_ca]?" 
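# Note: the encoder substitute class is read from $1 and the decoder substitute class from $3 (see the sbatch call at the bottom of this script).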
5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_L,FFNetwork_M, FFNetwork_XL,FF Network_XS, FFNetwork_S]" 8 | echo "--encoder: set it if the layer you want to substitute encoder layers" 9 | echo -e "\t\t\t [ FFNetwork_decoder_L, FFNetwork_decoder_M, FFNetwork_decoder_XL, FFNetwork_decoder_XS, FFNetwork_decoder_S]" 10 | echo "--decoder: set it if the layer you want to substitute decoder layers" 11 | echo "modify the parameter epoch as you need in the script" 12 | exit 13 | fi 14 | epoch=21 15 | ./submission_scripts/utils/create_folder_outputs.sh ALR $1 16 | 17 | suffix_1=_d 18 | # for i in {0..5}; do 19 | # echo "Substiting layer $i..." 20 | # sbatch --output=../sbatch_log/ALR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path$suffix $SCRATCH/models/checkpoints/ALR/$1/ --epoch$suffix $epoch --substitute_type$suffix ALR --substitute_class$suffix $1 --layers$suffix $i 21 | # done 22 | 23 | echo "Substituting all layers" 24 | sbatch --output=./sbatch_log/ALR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path $SCRATCH/pytorch-original-transformer/models/checkpoints/ALR/$1/ --epoch $epoch --substitute_type ALR --substitute_class $1 --substitute_model_path$suffix_1 $SCRATCH/pytorch-original-transformer/models/checkpoints/ALR/$3/ --epoch$suffix_1 $epoch --substitute_type$suffix_1 ALR --substitute_class$suffix_1 $3 25 | -------------------------------------------------------------------------------- /submission_scripts/full_sentence/validation_script_ALR_all.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | if [ $# == 0 ]; then 3 | echo "" 4 | echo "submission_scripts/full_sentence/validation_script_ALR_submit_all.sh [--encoder] [--decoder] [--decoder_ca]?" 5 | echo 6 | echo "Args:" 7 | echo " in [ FFNetwork_L,FFNetwork_M, FFNetwork_XL,FF Network_XS, FFNetwork_S]" 8 | echo "--encoder: set it if the layer you want to substitute encoder layers" 9 | echo -e "\t\t\t [ FFNetwork_decoder_L, FFNetwork_decoder_M, FFNetwork_decoder_XL, FFNetwork_decoder_XS, FFNetwork_decoder_S]" 10 | echo "--decoder: set it if the layer you want to substitute decoder layers" 11 | echo -e "\t\t\t [ FFNetwork_cross_decoder_L, FFNetwork_cross_decoder_M, FFNetwork_cross_decoder_XL, FFNetwork_cross_decoder_XS, FFNetwork_cross_decoder_S]" 12 | echo "--decoder_ca: set it if the layer you want to substitute decoder cross attention layers" 13 | echo "modify the parameter epoch as you need in the script" 14 | exit 15 | fi 16 | epoch=21 17 | ./submission_scripts/utils/create_folder_outputs.sh ALR $1 18 | 19 | suffix_1=_d 20 | suffix_2=_d_ca 21 | # for i in {0..5}; do 22 | # echo "Substiting layer $i..." 
23 | # sbatch --output=../sbatch_log/ALR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path$suffix $SCRATCH/models/checkpoints/ALR/$1/ --epoch$suffix $epoch --substitute_type$suffix ALR --substitute_class$suffix $1 --layers$suffix $i 24 | # done 25 | 26 | echo "Substituting all layers" 27 | sbatch --output=./sbatch_log/ALR/$1/evaluation_outputs/%j.out submission_scripts/full_sentence/validation_script.sh --substitute_model_path $SCRATCH/pytorch-original-transformer/models/checkpoints/ALR/$1/ --epoch $epoch --substitute_type ALR --substitute_class $1 --substitute_model_path$suffix_1 $SCRATCH/pytorch-original-transformer/models/checkpoints/ALR/$3/ --epoch$suffix_1 $epoch --substitute_type$suffix_1 ALR --substitute_class$suffix_1 $3 --substitute_model_path$suffix_2 $SCRATCH/pytorch-original-transformer/models/checkpoints/ALR/$5/ --epoch$suffix_2 $epoch --substitute_type$suffix_2 ALR --substitute_class$suffix_2 $5 28 | -------------------------------------------------------------------------------- /utils/visualization_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn 3 | 4 | 5 | def plot_attention_heatmap(data, x, y, head_id, ax): 6 | seaborn.heatmap(data, xticklabels=x, yticklabels=y, square=True, vmin=0.0, vmax=1.0, cbar=False, annot=True, fmt=".2f", ax=ax) 7 | ax.set_title(f'MHA head id = {head_id}') 8 | 9 | 10 | def visualize_attention_helper(attention_weights, source_sentence_tokens=None, target_sentence_tokens=None, title=''): 11 | num_columns = 4 12 | num_rows = 2 13 | fig, axs = plt.subplots(num_rows, num_columns, figsize=(20, 10)) # prepare the figure and axes 14 | 15 | assert source_sentence_tokens is not None or target_sentence_tokens is not None, \ 16 | f'Either source or target sentence must be passed in.' 
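# If only one of the two token sequences is given (e.g. encoder self-attention), the same tokens are reused for both axes of the heatmaps.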
17 | 18 | target_sentence_tokens = source_sentence_tokens if target_sentence_tokens is None else target_sentence_tokens 19 | source_sentence_tokens = target_sentence_tokens if source_sentence_tokens is None else source_sentence_tokens 20 | 21 | for head_id, head_attention_weights in enumerate(attention_weights): 22 | row_index = int(head_id / num_columns) 23 | column_index = head_id % num_columns 24 | plot_attention_heatmap(head_attention_weights, source_sentence_tokens, target_sentence_tokens if head_id % num_columns == 0 else [], head_id, axs[row_index, column_index]) 25 | 26 | fig.suptitle(title) 27 | plt.show() 28 | 29 | 30 | def visualize_attention(baseline_transformer, source_sentence_tokens, target_sentence_tokens): 31 | encoder = baseline_transformer.encoder 32 | decoder = baseline_transformer.decoder 33 | 34 | # Remove the end of sentence token as we never attend to it, it's produced at the output and we stop 35 | target_sentence_tokens = target_sentence_tokens[0][:-1] 36 | 37 | # Visualize encoder attention weights 38 | for layer_id, encoder_layer in enumerate(encoder.encoder_layers): 39 | mha = encoder_layer.multi_headed_attention # Every encoder layer has 1 MHA module 40 | 41 | # attention_weights shape = (B, NH, S, S), extract 0th batch and loop over NH (number of heads) MHA heads 42 | # S stands for maximum source token-sequence length 43 | attention_weights = mha.attention_weights.cpu().numpy()[0] 44 | 45 | title = f'Encoder layer {layer_id + 1}' 46 | visualize_attention_helper(attention_weights, source_sentence_tokens, title=title) 47 | 48 | # Visualize decoder attention weights 49 | for layer_id, decoder_layer in enumerate(decoder.decoder_layers): 50 | mha_trg = decoder_layer.trg_multi_headed_attention # Extract the self-attention MHA 51 | mha_src = decoder_layer.src_multi_headed_attention # Extract the source attending MHA 52 | 53 | # attention_weights shape = (B, NH, T, T), T stands for maximum target token-sequence length 54 | attention_weights_trg = mha_trg.attention_weights.cpu().numpy()[0] 55 | # shape = (B, NH, T, S), target token representations create queries and keys/values come from the encoder 56 | attention_weights_src = mha_src.attention_weights.cpu().numpy()[0] 57 | 58 | title = f'Decoder layer {layer_id + 1}, self-attention MHA' 59 | visualize_attention_helper(attention_weights_trg, target_sentence_tokens=target_sentence_tokens, title=title) 60 | 61 | title = f'Decoder layer {layer_id + 1}, source-attending MHA' 62 | visualize_attention_helper(attention_weights_src, source_sentence_tokens, target_sentence_tokens, title) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import time 4 | 5 | 6 | import git 7 | import torch 8 | from nltk.translate.bleu_score import corpus_bleu 9 | 10 | 11 | from .constants import BINARIES_PATH, PAD_TOKEN 12 | from .decoding_utils import greedy_decoding 13 | from .data_utils import get_masks_and_count_tokens_src 14 | 15 | 16 | def get_available_binary_name(): 17 | prefix = 'transformer' 18 | 19 | def valid_binary_name(binary_name): 20 | # First time you see raw f-string? Don't worry the only trick is to double the brackets. 
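# The pattern below matches checkpoint names such as 'transformer_000123.pth' (the prefix followed by six zero-padded digits).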
21 | pattern = re.compile(rf'{prefix}_[0-9]{{6}}\.pth') 22 | return re.fullmatch(pattern, binary_name) is not None 23 | 24 | # Just list the existing binaries so that we don't overwrite them but write to a new one 25 | valid_binary_names = list(filter(valid_binary_name, os.listdir(BINARIES_PATH))) 26 | if len(valid_binary_names) > 0: 27 | last_binary_name = sorted(valid_binary_names)[-1] 28 | new_suffix = int(last_binary_name.split('.')[0][-6:]) + 1 # increment by 1 29 | return f'{prefix}_{str(new_suffix).zfill(6)}.pth' 30 | else: 31 | return f'{prefix}_000000.pth' 32 | 33 | 34 | def get_training_state(training_config, steps_taken, model): 35 | training_state = { 36 | # "commit_hash": git.Repo(search_parent_directories=True).head.object.hexsha, 37 | "dataset_name": training_config['dataset_name'], 38 | "language_direction": training_config['language_direction'], 39 | 40 | "num_of_epochs": training_config['num_of_epochs'], 41 | "batch_size": training_config['batch_size'], 42 | "steps_taken": steps_taken, 43 | 44 | "state_dict": model.state_dict() 45 | } 46 | 47 | return training_state 48 | 49 | 50 | def print_model_metadata(training_state): 51 | header = f'\n{"*"*5} Model training metadata: {"*"*5}' 52 | print(header) 53 | 54 | for key, value in training_state.items(): 55 | if key != 'state_dict': # don't print state_dict it's a bunch of numbers... 56 | if key == 'language_direction': # convert into human readable format 57 | value = 'English to German' if value == 'E2G' else 'German to English' 58 | print(f'{key}: {value}') 59 | print(f'{"*" * len(header)}\n') 60 | 61 | 62 | # Calculate the BLEU-4 score 63 | def calculate_bleu_score(transformer, token_ids_loader, trg_field_processor): 64 | with torch.no_grad(): 65 | pad_token_id = trg_field_processor.vocab.stoi[PAD_TOKEN] 66 | 67 | gt_sentences_corpus = [] 68 | predicted_sentences_corpus = [] 69 | 70 | ts = time.time() 71 | for batch_idx, token_ids_batch in enumerate(token_ids_loader): 72 | src_token_ids_batch, trg_token_ids_batch = token_ids_batch.src, token_ids_batch.trg 73 | if batch_idx % 10 == 0: 74 | print(f'batch={batch_idx}, time elapsed = {time.time()-ts} seconds.') 75 | 76 | # Optimization - compute the source token representations only once 77 | src_mask, _ = get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id) 78 | src_representations_batch = transformer.encode(src_token_ids_batch, src_mask) 79 | 80 | predicted_sentences = greedy_decoding(transformer, src_representations_batch, src_mask, trg_field_processor) 81 | predicted_sentences_corpus.extend(predicted_sentences) # add them to the corpus of translations 82 | 83 | # Get the token and not id version of GT (ground-truth) sentences 84 | trg_token_ids_batch = trg_token_ids_batch.cpu().numpy() 85 | for target_sentence_ids in trg_token_ids_batch: 86 | target_sentence_tokens = [trg_field_processor.vocab.itos[id] for id in target_sentence_ids if id != pad_token_id] 87 | gt_sentences_corpus.append([target_sentence_tokens]) # add them to the corpus of GT translations 88 | 89 | bleu_score = corpus_bleu(gt_sentences_corpus, predicted_sentences_corpus) 90 | print(f'BLEU-4 corpus score = {bleu_score}, corpus length = {len(gt_sentences_corpus)}, time elapsed = {time.time()-ts} seconds.') 91 | return bleu_score 92 | -------------------------------------------------------------------------------- /utils/optimizers_and_distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class 
CustomLRAdamOptimizer: 6 | """ 7 | Linear ramp learning rate for the warm-up number of steps and then start decaying 8 | according to the inverse square root law of the current training step number. 9 | 10 | Check out playground.py for visualization of the learning rate (visualize_custom_lr_adam). 11 | """ 12 | 13 | def __init__(self, optimizer, model_dimension, num_of_warmup_steps, current_step_number = 0): 14 | self.optimizer = optimizer 15 | self.model_size = model_dimension 16 | self.num_of_warmup_steps = num_of_warmup_steps 17 | 18 | self.current_step_number = current_step_number 19 | 20 | def step(self): 21 | self.current_step_number += 1 22 | current_learning_rate = self.get_current_learning_rate() 23 | 24 | for p in self.optimizer.param_groups: 25 | p['lr'] = current_learning_rate 26 | 27 | self.optimizer.step() # apply gradients 28 | 29 | # Check out the formula at Page 7, Chapter 5.3 "Optimizer" and playground.py for visualization 30 | def get_current_learning_rate(self): 31 | # For readability purpose 32 | step = self.current_step_number 33 | warmup = self.num_of_warmup_steps 34 | 35 | return self.model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5)) 36 | 37 | def zero_grad(self): 38 | self.optimizer.zero_grad() 39 | 40 | 41 | class LabelSmoothingDistribution(nn.Module): 42 | """ 43 | Instead of one-hot target distribution set the target word's probability to "confidence_value" (usually 0.9) 44 | and distribute the rest of the "smoothing_value" mass (usually 0.1) over the rest of the vocab. 45 | 46 | Check out playground.py for visualization of how the smooth target distribution looks like compared to one-hot. 47 | """ 48 | 49 | def __init__(self, smoothing_value, pad_token_id, trg_vocab_size, device): 50 | assert 0.0 <= smoothing_value <= 1.0 51 | 52 | super(LabelSmoothingDistribution, self).__init__() 53 | 54 | self.confidence_value = 1.0 - smoothing_value 55 | self.smoothing_value = smoothing_value 56 | 57 | self.pad_token_id = pad_token_id 58 | self.trg_vocab_size = trg_vocab_size 59 | self.device = device 60 | 61 | def forward(self, trg_token_ids_batch): 62 | 63 | batch_size = trg_token_ids_batch.shape[0] 64 | smooth_target_distributions = torch.zeros((batch_size, self.trg_vocab_size), device=self.device) 65 | 66 | # -2 because we are not distributing the smoothing mass over the pad token index and over the ground truth index 67 | # those 2 values will be overwritten by the following 2 lines with confidence_value and 0 (for pad token index) 68 | smooth_target_distributions.fill_(self.smoothing_value / (self.trg_vocab_size - 2)) 69 | 70 | smooth_target_distributions.scatter_(1, trg_token_ids_batch, self.confidence_value) 71 | smooth_target_distributions[:, self.pad_token_id] = 0. 72 | 73 | # If we had a pad token as a target we set the distribution to all 0s instead of smooth labeled distribution 74 | smooth_target_distributions.masked_fill_(trg_token_ids_batch == self.pad_token_id, 0.) 
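# Illustrative example (values not taken from the original code): with trg_vocab_size=5, smoothing_value=0.1, pad_token_id=0 and a non-pad target id of 2, a row ends up as [0.0, 0.0333, 0.9, 0.0333, 0.0333].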
75 | 76 | return smooth_target_distributions 77 | 78 | 79 | class OneHotDistribution(nn.Module): 80 | """ 81 | Create a one hot distribution (feel free to ignore used only in playground.py) 82 | """ 83 | 84 | def __init__(self, pad_token_id, trg_vocab_size): 85 | 86 | super(OneHotDistribution, self).__init__() 87 | 88 | self.pad_token_id = pad_token_id 89 | self.trg_vocab_size = trg_vocab_size 90 | 91 | def forward(self, trg_token_ids_batch): 92 | 93 | batch_size = trg_token_ids_batch.shape[0] 94 | one_hot_distribution = torch.zeros((batch_size, self.trg_vocab_size)) 95 | one_hot_distribution.scatter_(1, trg_token_ids_batch, 1.) 96 | 97 | # If we had a pad token as a target we set the distribution to all 0s instead of one-hot distribution 98 | one_hot_distribution.masked_fill_(trg_token_ids_batch == self.pad_token_id, 0.) 99 | 100 | return one_hot_distribution 101 | -------------------------------------------------------------------------------- /models/definitions/ALRR_FF.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from utils.constants import * 3 | import torch 4 | 5 | class FFNetwork_XS(nn.ModuleList): 6 | def __init__(self, model_dimension=128,sentence_length=MAX_LEN): 7 | super(FFNetwork_XS, self).__init__() 8 | self.sentence_length=sentence_length 9 | self.model_dimension=model_dimension 10 | self.width=self.sentence_length*self.model_dimension 11 | self.layers=list() 12 | widths=[1,256,1] 13 | self.depth=len(widths)-1 14 | self.layers=nn.ModuleList() 15 | for i in range(self.depth): 16 | self.layers.extend([nn.LayerNorm(self.width // widths[i]),nn.Linear(self.width // widths[i], self.width // widths[i+1])]) 17 | if(i dim: batch size x padded sentence length x embedding size 110 | i = torch.from_numpy(np.load(inf)) 111 | m = torch.from_numpy(np.load(maskf)) 112 | m = torch.squeeze(m, dim=1) 113 | m = torch.squeeze(m, dim=1) 114 | o = torch.from_numpy(np.load(outf)) 115 | l = torch.sum(m, dim = 1) 116 | for j in range(i.shape[0]): 117 | if t == "max": 118 | if l[j] <= n: 119 | self.input.append( i[ j, : l[j] ] ) 120 | self.output.append(o[ j, : l[j] ] ) 121 | self.mask.append( m[ j, : l[j] ] ) 122 | else: 123 | if l[j] == n: 124 | self.input.append(i[j, :l[j]]) 125 | self.output.append(o[j, :l[j]]) 126 | except (UnpicklingError, ValueError): 127 | print(f"Finished loading datasets from {input_path} and {output_path}") 128 | print(f"Loaded {len(self.output)} samples in {time.time() - start}s") 129 | finally: 130 | inf.close() 131 | outf.close() 132 | maskf.close() 133 | # self.input = torch.cat(self.input, dim=0) 134 | # self.output = torch.cat(self.output, dim=0) 135 | torch.save(self.input, in_cache) 136 | torch.save(self.output, out_cache) 137 | if t == "max": 138 | # self.mask = torch.cat(self.mask, dim=0) 139 | torch.save(self.mask, mask_cache) 140 | 141 | def __len__(self): 142 | return len(self.input) 143 | 144 | def __getitem__(self, idx): 145 | # if we have exactly the same length, there is no need for padding/masking 146 | if self.t == "exact": 147 | return (self.input[idx], self.output[idx]) 148 | return (self.input[idx], self.output[idx], self.mask[idx]) 149 | 150 | def emb_size(self): 151 | return self.input.shape[1] 152 | 153 | def pad_shape(batch, masks = False): 154 | shape = batch.shape 155 | if masks: 156 | return shape[0],MAX_LEN-shape[1] 157 | return shape[0], MAX_LEN-shape[1], shape[2] 158 | 159 | def collate_batch(batch): 160 | # pad batch to a fixed length 161 | inputs = pad_sequence([x[0] for x 
in batch], batch_first=True, padding_value=0) 162 | outputs = pad_sequence([x[1] for x in batch], batch_first=True, padding_value=0) 163 | masks = pad_sequence([x[2] for x in batch], batch_first=True, padding_value=0) 164 | # pad batch to MAX_LEN 165 | inputs = torch.cat([inputs, torch.zeros(pad_shape(inputs))], dim = 1).to(device) 166 | outputs = torch.cat([outputs, torch.zeros(pad_shape(outputs))], dim = 1).to(device) 167 | masks = torch.cat([masks, torch.zeros(pad_shape(masks, masks = True), dtype=torch.bool)], dim = 1).to(device) 168 | # reshape 169 | masks = torch.repeat_interleave(masks, inputs.shape[-1] ,dim=1) 170 | inputs = torch.reshape(inputs, (inputs.shape[0],inputs.shape[1]*inputs.shape[2])) 171 | outputs = torch.reshape(outputs, (outputs.shape[0],outputs.shape[1]*outputs.shape[2])) 172 | 173 | return inputs, outputs, masks 174 | 175 | if __name__ == "__main__": 176 | parser = argparse.ArgumentParser() 177 | parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=21) 178 | parser.add_argument("--dataset_path", type=str, help='download dataset to this path', default=DATA_PATH) 179 | parser.add_argument("--model_dimension", type=str, help='embedding size', default=128) 180 | parser.add_argument("--num_of_curr_trained_layer", type=str, help='num_of_curr_trained_layer', default=0) 181 | parser.add_argument("--batch_size", type=str, help='batch_size', default=2000) 182 | parser.add_argument("--substitute_class", type = str, help="name of the FF to train defined in models/definitions/ALR.py", required=True) 183 | parser.add_argument("--language_direction", choices=[el.name for el in LanguageDirection], help='which direction to translate', default=LanguageDirection.de_en.name) 184 | 185 | args = parser.parse_args() 186 | # Wrapping training configuration into a dictionary 187 | training_config = dict() 188 | for arg in vars(args): 189 | training_config[arg] = getattr(args, arg) 190 | 191 | training_config["checkpoints_folder"] = os.path.join(CHECKPOINTS_SCRATCH,"ALRR" ,training_config["substitute_class"], f"layer{training_config['num_of_curr_trained_layer']}") 192 | os.makedirs(training_config["checkpoints_folder"], exist_ok = True) 193 | print("Training arguments parsed") 194 | print("Training layer {0}".format(training_config["num_of_curr_trained_layer"])) 195 | training_replacement_FF(training_config) 196 | -------------------------------------------------------------------------------- /scripts/full_sentence/training_ELR.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | 5 | from pickle import UnpicklingError 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader, random_split 10 | from torch.optim import Adam 11 | from torch.nn.utils.rnn import pad_sequence 12 | 13 | 14 | # Local imports 15 | from pathlib import Path 16 | import sys 17 | path_root = Path(__file__).parents[2] 18 | sys.path.append(str(path_root)) 19 | 20 | from utils.constants import ALR_CHECKPOINT_FORMAT, SCRATCH, MAX_LEN,CHECKPOINTS_SCRATCH 21 | import models.definitions.ELR_FF as nets 22 | from utils.data_utils import LanguageDirection 23 | 24 | DATA_PATH=os.path.join(SCRATCH, "pytorch-original-transformer","layer_outputs") 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # checking whether you have a GPU, I hope so! 
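# MAPE below computes mean(|output - target| / |target|); the 1e-32 term only guards against division by zero.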
26 | 27 | def MAPE(target, output): 28 | #Mean Absolute Percentage Error 29 | with torch.no_grad(): 30 | relative_error = torch.abs(output - target) / torch.max(torch.abs(target), torch.ones(output.shape, device = device)*1e-32) 31 | return torch.mean(relative_error) 32 | 33 | def prepare_data(data_path, language_direction, chosen_layer = 0, batch_size = 5, t = "train", dev = False): 34 | if t not in ["train", "test", "val"]: 35 | raise ValueError("ERROR: t must be train, test, or val.") 36 | in_path = os.path.join(data_path,f"128emb_20ep_IWSLT_{language_direction}_ELR_layer{chosen_layer}_inputs_{t}") 37 | out_path = os.path.join(data_path,f"128emb_20ep_IWSLT_{language_direction}_ELR_layer{chosen_layer}_outputs_{t}") 38 | mask_path = os.path.join(data_path,f"128emb_20ep_IWSLT_{language_direction}_masks_{t}") 39 | dataset = AttentionDataset(in_path, out_path, mask_path, MAX_LEN) 40 | if dev: 41 | dataset, _ = dataset = random_split(dataset, [0.2, 0.8]) 42 | return DataLoader(dataset, collate_fn=collate_batch, batch_size= batch_size) 43 | 44 | def training_replacement_FF(params): 45 | FF_net = getattr(nets, params["substitute_class"]) 46 | print(f"Training model: {FF_net}") 47 | model=FF_net().to(device) 48 | model.train(True) 49 | print("FF model created") 50 | lr_optimizer = Adam(model.parameters(), lr=0.001,betas=(0.9, 0.98), eps=1e-9) 51 | print("Preparing data") 52 | data_loader=prepare_data(params['dataset_path'], params['language_direction'], chosen_layer = params['num_of_curr_trained_layer'], batch_size = params["batch_size"]) 53 | mse_loss=nn.MSELoss() 54 | for epoch in range(params['num_of_epochs']): 55 | print("Epoch: ",epoch) 56 | epoch_loss=0 57 | num_embeddings=0 58 | mapes = [] 59 | start = time.time() 60 | for (data,label, mask) in data_loader: 61 | lr_optimizer.zero_grad() 62 | pred=model(data,mask) 63 | with torch.no_grad(): 64 | num_embeddings+=torch.sum(torch.flatten(mask)).item() 65 | loss_normalizer=torch.sum(torch.flatten(mask)).item()/(mask.shape[0]*mask.shape[1]) 66 | loss=mse_loss(label,pred)/loss_normalizer 67 | loss.backward() 68 | lr_optimizer.step() 69 | with torch.no_grad(): 70 | epoch_loss+=loss.item()*torch.sum(torch.flatten(mask)).item() 71 | mapes.append(MAPE(label, pred)) 72 | if(epoch%20==0): 73 | ckpt_model_name = ALR_CHECKPOINT_FORMAT.format(epoch+1, params['num_of_curr_trained_layer']) 74 | torch.save(model.state_dict(), os.path.join(training_config["checkpoints_folder"],ckpt_model_name)) 75 | print(f"Loss per embedding element:{epoch_loss/num_embeddings}, MAPE: {MAPE(label, pred)}, time: {time.time() - start}") 76 | 77 | class AttentionDataset(torch.utils.data.Dataset): 78 | def __init__(self, input_path, output_path, mask_path, n, t = "max"): 79 | print(f"Starting to load datasets from {input_path} and {output_path} and {mask_path}") 80 | start = time.time() 81 | 82 | self.n = n 83 | if t != "max" and t != "exact": 84 | raise ValueError("ERROR: t has to be either 'max' or 'exact'.") 85 | self.t = t 86 | self.input = [] 87 | self.output = [] 88 | if t == "max": 89 | self.mask = [] 90 | mask_cache = f"{mask_path}_fixed_{n}_{t}.cache" 91 | 92 | in_cache = f"{input_path}_fixed_{n}_{t}.cache" 93 | out_cache = f"{output_path}_fixed_{n}_{t}.cache" 94 | 95 | if os.path.exists(in_cache) and os.path.exists(out_cache) and (t == "exact" or os.path.exists(mask_cache)): 96 | self.input = torch.load(in_cache) 97 | self.output = torch.load(out_cache) 98 | if t == "max": 99 | self.mask = torch.load(mask_cache) 100 | print(f"Finished loading mask dataset from cache 
{mask_cache}") 101 | print(f"Finished loading datasets from cache {in_cache} and {out_cache}") 102 | print(f"Loaded {len(self.output)} samples in {time.time() - start}s") 103 | return 104 | 105 | inf = open(input_path, "rb") 106 | outf = open(output_path, "rb") 107 | maskf = open(mask_path, "rb") 108 | try: 109 | while(True): 110 | # i represents one batch of sentences -> dim: batch size x padded sentence length x embedding size 111 | i = torch.from_numpy(np.load(inf)) 112 | m = torch.from_numpy(np.load(maskf)) 113 | m = torch.squeeze(m, dim=1) 114 | m = torch.squeeze(m, dim=1) 115 | o = torch.from_numpy(np.load(outf)) 116 | l = torch.sum(m, dim = 1) 117 | for j in range(i.shape[0]): 118 | if t == "max": 119 | if l[j] <= n: 120 | self.input.append( i[ j, : l[j] ] ) 121 | self.output.append(o[ j, : l[j] ] ) 122 | self.mask.append( m[ j, : l[j] ] ) 123 | else: 124 | if l[j] == n: 125 | self.input.append(i[j, :l[j]]) 126 | self.output.append(o[j, :l[j]]) 127 | except (UnpicklingError, ValueError): 128 | print(f"Finished loading datasets from {input_path} and {output_path}") 129 | print(f"Loaded {len(self.output)} samples in {time.time() - start}s") 130 | finally: 131 | inf.close() 132 | outf.close() 133 | maskf.close() 134 | # self.input = torch.cat(self.input, dim=0) 135 | # self.output = torch.cat(self.output, dim=0) 136 | torch.save(self.input, in_cache) 137 | torch.save(self.output, out_cache) 138 | if t == "max": 139 | # self.mask = torch.cat(self.mask, dim=0) 140 | torch.save(self.mask, mask_cache) 141 | 142 | def __len__(self): 143 | return len(self.input) 144 | 145 | def __getitem__(self, idx): 146 | # if we have exactly the same length, there is no need for padding/masking 147 | if self.t == "exact": 148 | return (self.input[idx], self.output[idx]) 149 | return (self.input[idx], self.output[idx], self.mask[idx]) 150 | 151 | def emb_size(self): 152 | return self.input.shape[1] 153 | 154 | def pad_shape(batch, masks = False): 155 | shape = batch.shape 156 | if masks: 157 | return shape[0],MAX_LEN-shape[1] 158 | return shape[0], MAX_LEN-shape[1], shape[2] 159 | 160 | def collate_batch(batch): 161 | # pad batch to a fixed length 162 | inputs = pad_sequence([x[0] for x in batch], batch_first=True, padding_value=0) 163 | outputs = pad_sequence([x[1] for x in batch], batch_first=True, padding_value=0) 164 | masks = pad_sequence([x[2] for x in batch], batch_first=True, padding_value=0) 165 | # pad batch to MAX_LEN 166 | inputs = torch.cat([inputs, torch.zeros(pad_shape(inputs))], dim = 1).to(device) 167 | outputs = torch.cat([outputs, torch.zeros(pad_shape(outputs))], dim = 1).to(device) 168 | masks = torch.cat([masks, torch.zeros(pad_shape(masks, masks = True), dtype=torch.bool)], dim = 1).to(device) 169 | # reshape 170 | masks = torch.repeat_interleave(masks, inputs.shape[-1] ,dim=1) 171 | inputs = torch.reshape(inputs, (inputs.shape[0],inputs.shape[1]*inputs.shape[2])) 172 | outputs = torch.reshape(outputs, (outputs.shape[0],outputs.shape[1]*outputs.shape[2])) 173 | 174 | return inputs, outputs, masks 175 | 176 | if __name__ == "__main__": 177 | parser = argparse.ArgumentParser() 178 | parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=21) 179 | parser.add_argument("--dataset_path", type=str, help='download dataset to this path', default=DATA_PATH) 180 | parser.add_argument("--model_dimension", type=str, help='embedding size', default=128) 181 | parser.add_argument("--num_of_curr_trained_layer", type=str, help='num_of_curr_trained_layer', default=0) 
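# Example invocation (hypothetical, with a class name taken from submission_scripts/full_sentence/training_ELR_all_encoder.sh): python scripts/full_sentence/training_ELR.py --substitute_class FFNetwork_L --num_of_curr_trained_layer 0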
182 | parser.add_argument("--batch_size", type=str, help='batch_size', default=2000) 183 | parser.add_argument("--substitute_class", type = str, help="name of the FF to train defined in models/definitions/ALR.py", required=True) 184 | parser.add_argument("--language_direction", choices=[el.name for el in LanguageDirection], help='which direction to translate', default=LanguageDirection.de_en.name) 185 | 186 | args = parser.parse_args() 187 | # Wrapping training configuration into a dictionary 188 | training_config = dict() 189 | for arg in vars(args): 190 | training_config[arg] = getattr(args, arg) 191 | 192 | training_config["checkpoints_folder"] = os.path.join(CHECKPOINTS_SCRATCH,"ELR" ,training_config["substitute_class"], f"layer{training_config['num_of_curr_trained_layer']}") 193 | os.makedirs(training_config["checkpoints_folder"], exist_ok = True) 194 | print("Training arguments parsed") 195 | print("Training layer {0}".format(training_config["num_of_curr_trained_layer"])) 196 | training_replacement_FF(training_config) 197 | -------------------------------------------------------------------------------- /scripts/full_sentence/training_ALSR.py: -------------------------------------------------------------------------------- 1 | from pickle import UnpicklingError 2 | import argparse 3 | import os 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader, random_split 10 | from torch.optim import Adam 11 | from torch.nn.utils.rnn import pad_sequence 12 | 13 | 14 | # Local imports 15 | from pathlib import Path 16 | import sys 17 | path_root = Path(__file__).parents[2] 18 | sys.path.append(str(path_root)) 19 | from utils.constants import MHA_SEPARATE_CHECKPOINT_FORMAT, SCRATCH, MAX_LEN, CHECKPOINTS_SCRATCH 20 | import models.definitions.ALSR_FF as nets 21 | from utils.data_utils import LanguageDirection 22 | 23 | DATA_PATH=os.path.join(SCRATCH,"pytorch-original-transformer", "mha_outputs") 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # checking whether you have a GPU, I hope so! 
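# Unlike the ELR/ALRR training scripts, training_replacement_FF below trains a separate replacement FF network for each of the 8 attention heads of the chosen layer.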
25 | 26 | def MAPE(target, output): 27 | #Mean Absolute Percentage Error 28 | with torch.no_grad(): 29 | relative_error = torch.abs(output - target) / torch.max(torch.abs(target), torch.ones(output.shape, device = device)*1e-32) 30 | return torch.mean(relative_error) 31 | 32 | def prepare_data(data_path, language_direction, head = 0, chosen_layer = 0, batch_size = 5, t = "train", dev = False): 33 | if t not in ["train", "test", "val"]: 34 | raise ValueError("ERROR: t must be train, test, or val.") 35 | in_path = os.path.join(data_path, "encoder", f"128emb_20ep_IWSLT_{language_direction}_layer{chosen_layer}_v_inputs_{t}") 36 | out_path = os.path.join(data_path, "encoder", f"128emb_20ep_IWSLT_{language_direction}_layer{chosen_layer}_outputs_{t}") 37 | mask_path = os.path.join(data_path, "encoder", f"128emb_20ep_IWSLT_{language_direction}_masks_{t}") 38 | dataset = SeparateHeadsDataset(in_path, out_path, mask_path, head, MAX_LEN) 39 | print("Training head {0}".format(head)) 40 | if dev: 41 | dataset, _ = dataset = random_split(dataset, [0.2, 0.8]) 42 | return DataLoader(dataset, collate_fn=collate_batch, batch_size= batch_size) 43 | 44 | 45 | def training_replacement_FF(params): 46 | print("Training layer {0}".format(params["num_of_curr_trained_layer"])) 47 | FF_net = getattr(nets, params["substitute_class"]) 48 | for head in range(8): 49 | model=FF_net().to(device) 50 | model.train(True) 51 | print("FF model created") 52 | lr_optimizer = Adam(model.parameters(),betas=(0.9, 0.98), eps=1e-9) 53 | print("Preparing data") 54 | data_loader=prepare_data(params['dataset_path'], params['language_direction'], head=head, chosen_layer = params['num_of_curr_trained_layer'], batch_size = params["batch_size"]) 55 | # TODO: loop over heads, prepare data for the head, train 56 | mse_loss=nn.MSELoss() 57 | # mean_abs_percentage_error = MeanAbsolutePercentageError() 58 | for epoch in range(params['num_of_epochs']): 59 | print("Epoch: ",epoch) 60 | epoch_loss=0 61 | num_embeddings=0 62 | mapes = [] 63 | start = time.time() 64 | for (data,label, mask) in data_loader: 65 | lr_optimizer.zero_grad() 66 | pred=model(data,mask) 67 | with torch.no_grad(): 68 | num_embeddings+=torch.sum(torch.flatten(mask)).item() 69 | loss_normalizer=torch.sum(torch.flatten(mask)).item()/(mask.shape[0]*mask.shape[1]) 70 | loss=mse_loss(label,pred)/loss_normalizer 71 | loss.backward() 72 | lr_optimizer.step() 73 | with torch.no_grad(): 74 | epoch_loss+=loss.item()*torch.sum(torch.flatten(mask)).item() 75 | mapes.append(MAPE(label, pred)) 76 | if (epoch % 20 == 0): 77 | ckpt_model_name = MHA_SEPARATE_CHECKPOINT_FORMAT.format(epoch+1, params['num_of_curr_trained_layer'], head) 78 | torch.save(model.state_dict(), os.path.join(params["checkpoints_folder"], ckpt_model_name)) 79 | print(f"Loss per embedding element:{epoch_loss/num_embeddings}, MAPE: {MAPE(label, pred)}, time: {time.time() - start}") 80 | 81 | class SeparateHeadsDataset(torch.utils.data.Dataset): 82 | # NOTE: added h to specify which head to use 83 | def __init__(self, input_path, output_path, mask_path, h, n, t = "max"): 84 | print(f"Starting to load datasets from {input_path} and {output_path} and {mask_path}") 85 | start = time.time() 86 | 87 | self.n = n 88 | if t != "max" and t != "exact": 89 | raise ValueError("ERROR: t has to be either 'max' or 'exact'.") 90 | self.t = t 91 | self.input = [] 92 | self.output = [] 93 | if t == "max": 94 | self.mask = [] 95 | mask_cache = f"{mask_path}_h_{h}_fixed_{n}_{t}.cache" 96 | 97 | in_cache = 
f"{input_path}_h_{h}_fixed_{n}_{t}.cache" 98 | out_cache = f"{output_path}_h_{h}_fixed_{n}_{t}.cache" 99 | 100 | if os.path.exists(in_cache) and os.path.exists(out_cache) and (t == "exact" or os.path.exists(mask_cache)): 101 | self.input = torch.load(in_cache) 102 | self.output = torch.load(out_cache) 103 | if t == "max": 104 | self.mask = torch.load(mask_cache) 105 | print(f"Finished loading mask dataset from cache {mask_cache}") 106 | print(f"Finished loading datasets from cache {in_cache} and {out_cache}") 107 | print(f"Loaded {len(self.output)} samples in {time.time() - start}s") 108 | return 109 | 110 | inf = open(input_path, "rb") 111 | outf = open(output_path, "rb") 112 | maskf = open(mask_path, "rb") 113 | try: 114 | while(True): 115 | # i represents one batch of sentences -> dim: batch size x padded sentence length x embedding size 116 | i = torch.from_numpy(np.load(inf)) 117 | m = torch.from_numpy(np.load(maskf)) 118 | m = torch.squeeze(m, dim=1) 119 | m = torch.squeeze(m, dim=1) 120 | o = torch.from_numpy(np.load(outf)) 121 | l = torch.sum(m, dim = 1) 122 | for j in range(i.shape[0]): 123 | if t == "max": 124 | if l[j] <= n: 125 | self.input.append(i[j, :l[j]]) 126 | self.output.append(o[j,h,:l[j]]) 127 | self.mask.append(m[j, :l[j]]) 128 | else: 129 | if l[j] == n: 130 | self.input.append(i[j,h ,:l[j]]) 131 | self.output.append(o[j, :l[j]]) 132 | except (UnpicklingError, ValueError): 133 | print(f"Finished loading datasets from {input_path} and {output_path}") 134 | print(f"Loaded {len(self.output)} samples in {time.time() - start}s") 135 | finally: 136 | inf.close() 137 | outf.close() 138 | maskf.close() 139 | # self.input = torch.cat(self.input, dim=0) 140 | # self.output = torch.cat(self.output, dim=0) 141 | torch.save(self.input, in_cache) 142 | torch.save(self.output, out_cache) 143 | if t == "max": 144 | # self.mask = torch.cat(self.mask, dim=0) 145 | torch.save(self.mask, mask_cache) 146 | 147 | def __len__(self): 148 | return len(self.input) 149 | 150 | def __getitem__(self, idx): 151 | # if we have exactly the same length, there is no need for padding/masking 152 | if self.t == "exact": 153 | return (self.input[idx], self.output[idx]) 154 | return (self.input[idx], self.output[idx], self.mask[idx]) 155 | 156 | def emb_size(self): 157 | return self.input.shape[1] 158 | 159 | def pad_shape(batch, masks = False): 160 | shape = batch.shape 161 | if masks: 162 | return shape[0],MAX_LEN-shape[1] 163 | return shape[0], MAX_LEN-shape[1], shape[2] 164 | 165 | def collate_batch(batch): 166 | """Creates a batch given a list of inputs. The output is the concatenation of the outputs from a single head for each word reperesentation in the sentece. Mask has the same shape as the output because the FF_net should multiply outputs*masks after inference. Here there is no need to multiply the inputs by masks because there is no padding related to the batch. The multiplication with the mask is performed in the AttentionSubistute because there some padding might be added when batching. 
167 | 168 | Args: 169 | batch (list): list of tuples (input(S x MD), output(S x HD), batch(S)) 170 | 171 | Returns: 172 | inputs : B x MAX_LEN*MD 173 | outputs: B x MAX_LEN*HD 174 | masks : B x MAX_LEN*HD 175 | """ 176 | # Pad all elements to the same length 177 | inputs = pad_sequence([x[0] for x in batch], batch_first=True, padding_value=0) 178 | outputs = pad_sequence([x[1] for x in batch], batch_first=True, padding_value=0) 179 | masks = pad_sequence([x[2] for x in batch], batch_first=True, padding_value=0) 180 | # print(inputs.shape) 181 | # print(outputs.shape) 182 | # print(masks.shape) 183 | 184 | # Pad to fixed length 185 | inputs = torch.cat([inputs, torch.zeros(pad_shape(inputs))], dim = 1).to(device) 186 | outputs = torch.cat([outputs, torch.zeros(pad_shape(outputs))], dim = 1).to(device) 187 | masks = torch.cat([masks, torch.zeros(pad_shape(masks, masks = True), dtype=torch.bool)], dim = 1).to(device) 188 | 189 | # Reshape concatenating the embeddings for each sentence 190 | masks = torch.repeat_interleave(masks, outputs.shape[-1] ,dim=1) 191 | inputs = torch.reshape(inputs, (inputs.shape[0],inputs.shape[1]*inputs.shape[2])) 192 | outputs = torch.reshape(outputs, (outputs.shape[0],outputs.shape[1]*outputs.shape[2])) 193 | masks = masks.reshape(outputs.shape) 194 | return inputs, outputs, masks 195 | 196 | if __name__ == "__main__": 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=21) 199 | parser.add_argument("--dataset_path", type=str, help='download dataset to this path', default=DATA_PATH) 200 | parser.add_argument("--model_dimension", type=str, help='embedding size', default=128) 201 | parser.add_argument("--batch_size", type=str, help='batch_size', default=2000) 202 | 203 | # Params to set when running the script 204 | parser.add_argument("--num_of_curr_trained_layer", type=str, help='num_of_curr_trained_layer', default=5) 205 | parser.add_argument("--substitute_class", type = str, help="name of the FF to train defined in models/definitions/ALR.py", required=True) 206 | parser.add_argument("--language_direction", choices=[el.name for el in LanguageDirection], help='which direction to translate', default=LanguageDirection.de_en.name) 207 | args = parser.parse_args() 208 | # Wrapping training configuration into a dictionary 209 | training_config = dict() 210 | for arg in vars(args): 211 | training_config[arg] = getattr(args, arg) 212 | print("Training arguments parsed") 213 | training_config["checkpoints_folder"] = os.path.join(CHECKPOINTS_SCRATCH,"ALSR", training_config["substitute_class"], f"layer{training_config['num_of_curr_trained_layer']}") 214 | os.makedirs(training_config["checkpoints_folder"], exist_ok = True) 215 | print(training_config["checkpoints_folder"]) 216 | training_replacement_FF(training_config) 217 | -------------------------------------------------------------------------------- /utils/simulator.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import os 4 | import copy 5 | from pickle import UnpicklingError 6 | 7 | import torch 8 | from torch import nn 9 | 10 | import utils.utils as utils 11 | from utils.constants import * 12 | 13 | # Dataset with a single word and the average of its sentence as input 14 | class SingleWordsInterResultsDataset(torch.utils.data.Dataset): 15 | def __init__(self, index_in, index_out, t, device, ext_pref): 16 | assert(t in ["train", "test", "val"]) 17 | 
assert(ext_pref in ["ELR", "ALR", "ALRR"]) 18 | pref = f"128emb_20ep_IWSLT_E2G" 19 | 20 | mask_path = os.path.join(LAYER_OUTPUT_PATH, f"{pref}_masks_{t}") 21 | input_path = os.path.join(LAYER_OUTPUT_PATH, f"{pref}_{ext_pref}_layer{index_in}_inputs_{t}") 22 | output_path = os.path.join(LAYER_OUTPUT_PATH, f"{pref}_{ext_pref}_layer{index_out}_outputs_{t}") 23 | 24 | self.index_in = index_in 25 | self.index_out = index_out 26 | 27 | print(f"Starting to load datasets from {input_path} and {output_path} and {mask_path}") 28 | start = time.time() 29 | 30 | self.input = [] 31 | self.output = [] 32 | 33 | in_cache = f"{input_path}_single.cache" 34 | out_cache = f"{output_path}_single.cache" 35 | 36 | if os.path.exists(in_cache) and os.path.exists(out_cache): 37 | self.input = torch.load(in_cache, map_location=device) 38 | self.output = torch.load(out_cache, map_location=device) 39 | print(f"Finished loading datasets from cache {in_cache} and {out_cache}") 40 | print(f"Loaded {len(self.output)} samples (flattened) in {time.time() - start}s") 41 | return 42 | 43 | inf = open(input_path, "rb") 44 | outf = open(output_path, "rb") 45 | maskf = open(mask_path, "rb") 46 | 47 | try: 48 | while(True): 49 | # input dimension: B x L x 128 50 | i = torch.from_numpy(np.load(inf)) 51 | # mask dimension: B x 1 x 1 x L 52 | m = torch.from_numpy(np.load(maskf)) 53 | # output dimension: B x L x 128 54 | o = torch.from_numpy(np.load(outf)) 55 | m = m.squeeze(dim=1) 56 | m = m.squeeze(dim=1) 57 | denom = m.sum(dim = 1) 58 | avg = torch.sum(i * m.unsqueeze(2), dim=1) / denom.reshape((-1, 1)) 59 | for j, s in enumerate(i): 60 | inp = torch.cat([s[:denom[j]], avg[j].expand((denom[j], 128))], dim=1) 61 | self.input.append(inp) 62 | out = o[j, :denom[j]] 63 | self.output.append(out) 64 | except (UnpicklingError, ValueError): 65 | print(f"Finished disk access") 66 | print(f"Still need to change the dataset in-memory!") 67 | finally: 68 | inf.close() 69 | outf.close() 70 | maskf.close() 71 | self.input = torch.cat(self.input, dim=0).to(device) 72 | self.output = torch.cat(self.output, dim=0).to(device) 73 | torch.save(self.input, in_cache) 74 | torch.save(self.output, out_cache) 75 | print(f"Loaded {len(self.output)} samples (flattened) in {time.time() - start}s") 76 | 77 | def __len__(self): 78 | return self.input.shape[0] 79 | 80 | def __getitem__(self, idx): 81 | return (self.input[idx], self.output[idx]) 82 | 83 | # Dataset for unchanged access to the data extracted from the transformer (used in sim_all_together.py) 84 | class UnchangedDataset(torch.utils.data.Dataset): 85 | def __init__(self, index_in, index_out, t, device, ext_pref): 86 | assert(t in ["train", "test", "val"]) 87 | assert(ext_pref in ["ELR", "ALR", "ALRR"]) 88 | pref = "128emb_20ep_IWSLT_E2G" 89 | 90 | mask_path = os.path.join(LAYER_OUTPUT_PATH, f"{pref}_masks_{t}") 91 | input_path = os.path.join(LAYER_OUTPUT_PATH, f"{pref}_{ext_pref}_layer{index_in}_inputs_{t}") 92 | output_path = os.path.join(LAYER_OUTPUT_PATH, f"{pref}_{ext_pref}_layer{index_out}_outputs_{t}") 93 | 94 | self.index_in = index_in 95 | self.index_out = index_out 96 | 97 | print(f"Starting to load datasets from {input_path} and {output_path} and {mask_path}") 98 | start = time.time() 99 | 100 | self.input = [] 101 | self.output = [] 102 | self.mask = [] 103 | 104 | inf = open(input_path, "rb") 105 | outf = open(output_path, "rb") 106 | maskf = open(mask_path, "rb") 107 | 108 | try: 109 | while(True): 110 | # input dimension: B x L x 128 111 | i = torch.from_numpy(np.load(inf)).to(device) 
112 | # mask dimension: B x 1 x 1 x L 113 | m = torch.from_numpy(np.load(maskf)).to(device) 114 | # output dimension: B x L x 128 115 | o = torch.from_numpy(np.load(outf)).to(device) 116 | self.input.append(i) 117 | self.output.append(o) 118 | self.mask.append(m) 119 | except (UnpicklingError, ValueError): 120 | pass 121 | finally: 122 | inf.close() 123 | outf.close() 124 | maskf.close() 125 | print(f"Loaded {len(self.output)} samples (flattened) in {time.time() - start}s") 126 | 127 | def __len__(self): 128 | return len(self.input) 129 | 130 | def __getitem__(self, idx): 131 | return (self.input[idx], self.output[idx], self.mask[idx]) 132 | 133 | # basically just getting the average of a sentence, but as a module 134 | class ConvertInput(nn.Module): 135 | def __init__(self): 136 | super(ConvertInput, self).__init__() 137 | 138 | def forward(self, src_representations_batch, src_mask): 139 | # input dimension: B x L x 128 140 | # mask dimension: B x 1 x 1 x L 141 | src_mask = src_mask.squeeze(dim=1) 142 | src_mask = src_mask.squeeze(dim=1) 143 | denom = src_mask.sum(dim = 1) 144 | avg = torch.sum(src_representations_batch * src_mask.unsqueeze(2), dim=1) / denom.reshape((-1, 1)) 145 | return avg 146 | 147 | # The class getting as input the concatenation of a word embedding and the average of the sentence 148 | class AttentionSimulator(nn.Module): 149 | def __init__(self, nr_layers, nr_units): 150 | super(AttentionSimulator, self).__init__() 151 | model_dimension = 128 152 | layers = [nn.BatchNorm1d(2*model_dimension)] 153 | def append_layer(in_dim, out_dim): 154 | layers.append(nn.Sequential(nn.Linear(int(in_dim*model_dimension), int(out_dim*model_dimension)), nn.LeakyReLU())) 155 | 156 | assert(nr_layers >= 1) 157 | if (nr_layers == 1): 158 | append_layer(2, 1) 159 | elif isinstance(nr_units, int): 160 | append_layer(2, nr_units) 161 | for i in range(1, nr_layers-1): 162 | append_layer(nr_units, nr_units) 163 | append_layer(nr_units, 1) 164 | else: 165 | assert(len(nr_units)+1 == nr_layers) 166 | append_layer(2, nr_units[0]) 167 | for i in range(1, nr_layers-1): 168 | append_layer(nr_units[i-1], nr_units[i]) 169 | append_layer(nr_units[-1], 1) 170 | self.sequential = nn.Sequential(*layers) 171 | self.name = f"{nr_layers}_{nr_units}".replace(" ", "") 172 | 173 | def forward(self, x): 174 | return self.sequential(x) 175 | 176 | # Adapter to transform the input we get in a transformer to the one the AttentionSimulator expects, basically just reshaping and getting the average 177 | class SimulatorAdapter(nn.Module): 178 | def __init__(self, attention_simulator): 179 | super(SimulatorAdapter, self).__init__() 180 | self.c = ConvertInput() 181 | self.a = attention_simulator 182 | 183 | def forward(self, src_representations_batch, src_mask): 184 | B = src_representations_batch.shape[0] 185 | L = src_representations_batch.shape[1] 186 | 187 | avg = self.c(src_representations_batch, src_mask) 188 | x = src_representations_batch.reshape((B*L, 128)) 189 | y = avg.unsqueeze(1).expand((-1, L, -1)).reshape((B*L, 128)) 190 | src_representations_batch = self.a(torch.cat((x, y), dim=1)).reshape((B, L, 128)) 191 | return src_representations_batch 192 | 193 | # Class to replace the whole encoder with the AttentionSimulators of the "whole" approach. The last layer should also be trained to the norm of the last encoder layer. 
194 | class MultipleSimulator(nn.Module): 195 | def __init__(self, sims): 196 | super(MultipleSimulator, self).__init__() 197 | self.layers = nn.ModuleList() 198 | for a in sims: 199 | self.layers.append(SimulatorAdapter(a)) 200 | self.name = f"MultipleSimulator_{self.layers[0].a.name}" 201 | 202 | def forward(self, src_embeddings_batch, src_mask): 203 | src_representations_batch = src_embeddings_batch 204 | for layer in self.layers: 205 | src_representations_batch = layer(src_representations_batch, src_mask) 206 | 207 | return src_representations_batch 208 | 209 | def get_batches(data_set, batch_size): 210 | return [(i, min(i+batch_size, len(data_set)-1)) for i in range(0, len(data_set), batch_size)] 211 | 212 | def get_checkpoint_name(model_name, batch_size, index_in, index_out, epoch, ext_pref): 213 | assert(ext_pref in ["ELR", "ALR", "ALRR"]) 214 | inst_name = f"{model_name}_bs{batch_size}_fr{index_in}_to{index_out}_{ext_pref}" 215 | ckpt_model_name = f"{inst_name}_ckpt_epoch_{epoch}.pth" 216 | return ckpt_model_name 217 | 218 | # replace one encoder layer 219 | class EncoderLayerSubstituteJ(nn.Module): 220 | def __init__(self, encoder_layer, new_layer): 221 | super().__init__() 222 | self.sublayer_zero = SublayerSubstituteJ(encoder_layer.sublayers[0], new_layer) 223 | self.sublayers = encoder_layer.sublayers 224 | self.pointwise_net = encoder_layer.pointwise_net 225 | self.model_dimension = encoder_layer.model_dimension 226 | self.mha = encoder_layer.multi_headed_attention 227 | 228 | def forward(self, src_representations_batch, src_mask): 229 | src_representations_batch = self.sublayer_zero(src_representations_batch, src_mask) 230 | src_representations_batch = self.sublayers[1](src_representations_batch, self.pointwise_net) 231 | 232 | return src_representations_batch 233 | 234 | # replace one sublayer 235 | class SublayerSubstituteJ(nn.Module): 236 | def __init__(self, old_sublayer, layer): 237 | super().__init__() 238 | self.norm = copy.deepcopy(old_sublayer.norm) 239 | self.dropout = copy.deepcopy(old_sublayer.dropout) 240 | self.layer = copy.deepcopy(layer) 241 | 242 | def forward(self, representations_batch, src_mask): 243 | # Residual connection between input and sublayer output, details: Page 7, Chapter 5.4 "Regularization", 244 | return representations_batch + self.dropout(self.layer(self.norm(representations_batch), src_mask)) 245 | 246 | class MHAWrap(nn.Module): 247 | def __init__(self, mha): 248 | super().__init__() 249 | self.mha = copy.deepcopy(mha) 250 | 251 | def forward(self, srb, src_mask): 252 | return self.mha(srb, srb, srb, src_mask) 253 | 254 | def restructure_encoder_layers(transformer): 255 | new_layers = nn.ModuleList() 256 | for enc_l in transformer.encoder.encoder_layers: 257 | new_layers.append(EncoderLayerSubstituteJ(enc_l, MHAWrap(enc_l.multi_headed_attention))) 258 | transformer.encoder.encoder_layers = new_layers -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## The Rethinking-attention paper official implementation (AAAI24) :computer: = :rainbow: 2 | This repository contains the code developed for the [Rethinking attention paper](arxiv link) which got accepted at AAAI24 conference. 3 | In this project we tried to replace the self-attention with feed-forward networks (FFN) to evaluate the importance of attention inside the transformer. 
4 |
5 | The developed code builds on top of an open-source implementation by Aleksa Gordic,
6 | [pytorch-original-transformer](https://github.com/gordicaleksa/pytorch-original-transformer),
7 | of the original transformer [Vaswani et al.](https://arxiv.org/abs/1706.03762).
8 |
9 | ## Table of Contents
10 | * [Environment setup](#environment-setup)
11 | * [Code overview](#code-overview)
12 | * [Train baseline transformer](#train-baseline-transformer)
13 | * [Intermediate data extraction](#intermediate-data-extraction)
14 | * [Train replacement FF networks](#full-sentence-approach)
15 | * [Evaluation](#evaluation)
16 |
17 |
18 | ## Environment setup
19 |
20 | 1. Navigate into the project directory: `cd path_to_repo`
21 | 2. Run `conda env create` from the project directory (this will create a brand new conda environment).
22 | 3. Run `conda activate rethinking-attention` (to run scripts from your console), or set the created environment as the interpreter in your IDE.
23 | 4. Run `export SCRATCH=path_to_outputs`, where `path_to_outputs` is the directory where you want the output files to be stored. If you are running this code on the Euler cluster, the variable is already defined.
24 | 5. Execute `./scripts/baseline/prepare_dataset.py` to download the IWSLT2017 dataset used in this project. Choose between different subsets for download: English-to-German (E2G), English-to-French (E2F)...
25 |
26 | In the following, all commands are assumed to be run from the root of the repository.
27 |
28 | ## Code overview
29 |
30 | As previously mentioned, the code was developed on top of an existing implementation of the transformer. Our main contributions to this code reside in
31 | - the folder `./scripts`, which contains scripts for extracting intermediate data, training the different architectures and evaluating them inside the transformer;
32 | - the file `./utils/full_sentence_utils.py`, which contains the classes and functions for substituting FFNs into the transformer.
33 |
34 | The provided code was run on different GPUs, all with between 11 GB and 24 GB of memory.
35 | If your GPU has less memory, try reducing the batch size; however, some of the bigger architectures may still not fit.
36 |
37 | The instructions for running the code are platform-independent. Since we ran the code on a cluster that uses Slurm, the `./submission_scripts` folder contains
38 | the wrapper scripts that were used to submit jobs. If you want to use them, please either adjust the output-path argument or create a folder `./sbatch_log`
39 | which will collect all the program outputs.
40 | The folder `./submission_scripts` mirrors the structure of the `./scripts` folder:
41 | the wrapper script for *scripts/example_script.py* is located at *submission_scripts/example_script.sh*.
42 |
43 | ## Train baseline transformer
44 |
45 | The first step in the pipeline is training the baseline transformer:
46 |
47 | 1. Execute `python3 ./scripts/baseline/training_script.py`
48 |
49 | The script exposes additional arguments that control the training process.
50 |
51 |
52 | ## Intermediate data extraction
53 |
54 | To train our FFNs, we first extract the intermediate values that serve as input and output of the attention modules. To extract the intermediate data (e.g. for the English-to-German dataset), run
55 | 1.
`python3 ./scripts/extraction/extract.py --path_to_weights ./models/binaries/Transformer_None_None_20.pth --batch_size 1400 --dataset_name IWSLT --language_direction en_de --model_name 128emb_20ep --output_path ./mha_outputs`
56 | 2. `python3 ./scripts/extraction/extract_mha.py --path_to_weights ./models/binaries/Transformer_None_None_20.pth --batch_size 1400 --dataset_name IWSLT --language_direction en_de --model_name 128emb_20ep --output_path ./mha_outputs`
57 |
58 | The first script extracts inputs and outputs of
59 | - each encoder layer (identified by *ELR* in the file name),
60 | - each multi-headed attention (MHA) module (identified by *ALR* in the file name),
61 | - each "sublayer zero", which consists of the MHA, the layer normalization and the residual connection (identified by *ALRR* in the file name).
62 |
63 | The second script extracts inputs and outputs of
64 | - each MHA, excluding the final linear layer that mixes the values extracted by each head. This enables learning the output of each head separately, as in the 'separate heads' approach.
65 |
66 | At the end of this section, your main folder should contain one folder *output_layers* with the outputs of the first script and one folder *mha_outputs*
67 | with the outputs of the second script. The first script simultaneously extracts the data for all 3 types of attention replacement and stores it in the same file. These values are used to train FFNs which replace attention at different levels of abstraction.
68 |
69 | ## Train replacement FF networks
70 |
71 | In this step, the FFN takes the concatenated word representations of a sentence as input and produces updated word representations as output in a single pass.
72 | In order to handle input sentences of varying lengths, we pad all sentences to a fixed maximum length and mask the padded values with zeros to
73 | prevent them from influencing the model's inference.
74 |
75 | We tried substituting attention at the following levels of abstraction:
76 | - *Encoder Layer Replacement (ELR)*: replaces the whole encoder layer in the encoder
77 | - *Attention Layer with Residual Replacement (ALRR)*: replaces the MHA and the residual connection
78 | - *Attention Layer Replacement (ALR)*: replaces only the MHA
79 | - *Attention Layer Separate heads Replacement (ALSR)*: replaces the same part as *ALR*, but one FFN is trained for each head.
80 |
81 | The architectures used for each approach are defined in
82 | - `./models/definitions/ALRR_FF.py`
83 | - `./models/definitions/ALR_FF.py`
84 | - `./models/definitions/ALSR_FF.py`
85 | - `./models/definitions/ELR_FF.py`.
86 |
87 | For the final experiment we considered 5 architectures ranging from extra small (XS)
88 | to large (L). This range of parameter counts spans the operating range of the FFNs: the XS network reduces the BLEU score of the transformer, while the BLEU score grows with the number of parameters and saturates at the L network.
89 | Each approach uses a different training script. Each training script contains a data loader responsible for loading the data extracted in the previous step and
90 | creating batches of a fixed length *MAX_LEN* (using padding). Each training script receives as input the name of the substitute class (e.g. `FFNetwork_L`)
91 | and the index of the layer to emulate. The training loop iterates over the training data for a specified maximum number of epochs.
92 | The instructions for running the training scripts are listed below, after a short sketch illustrating the padding-and-masking scheme.
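
A minimal, self-contained sketch of this padding-and-masking idea is shown below. It is **not** code from the repository: the FFN (`toy_ffn`), the value of `MAX_LEN`, and the random inputs/targets are placeholder assumptions; the actual training scripts (e.g. `scripts/full_sentence/training_ALR.py`) build the flattened inputs, masks and the mask-normalized MSE loss in their own data loaders.

```python
# Illustrative sketch of the fixed-length padding + masking scheme (assumed shapes,
# not the repository's actual training code).
import torch
import torch.nn as nn

MAX_LEN = 50      # fixed maximum sentence length (assumption for this sketch)
MODEL_DIM = 128   # embedding size used throughout the repository

# A toy stand-in for one of the FFNetwork_* classes: it consumes the whole
# flattened sentence (MAX_LEN * MODEL_DIM values) and emits the same shape.
toy_ffn = nn.Sequential(
    nn.LayerNorm(MAX_LEN * MODEL_DIM),
    nn.Linear(MAX_LEN * MODEL_DIM, MAX_LEN * MODEL_DIM),
)

# Two sentences of different lengths (7 and 3 tokens).
lengths = [7, 3]
batch = [torch.randn(length, MODEL_DIM) for length in lengths]

# Pad every sentence to MAX_LEN and build a boolean mask marking the real tokens.
inputs = torch.zeros(len(batch), MAX_LEN, MODEL_DIM)
mask = torch.zeros(len(batch), MAX_LEN, dtype=torch.bool)
for i, sentence in enumerate(batch):
    inputs[i, : sentence.shape[0]] = sentence
    mask[i, : sentence.shape[0]] = True

# Single forward pass over the concatenated (flattened) word representations.
flat_in = inputs.reshape(len(batch), MAX_LEN * MODEL_DIM)
outputs = toy_ffn(flat_in).reshape(len(batch), MAX_LEN, MODEL_DIM)

# Zero out padded positions so they cannot contribute to the loss.
outputs = outputs * mask.unsqueeze(-1)
targets = inputs * mask.unsqueeze(-1)   # placeholder targets for the sketch
loss = nn.functional.mse_loss(outputs, targets)
print(outputs.shape, loss.item())
```

The same principle appears in the `collate_batch` functions of the training scripts, where the per-sentence tensors are padded to *MAX_LEN*, flattened, and multiplied by the mask.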
93 |
94 | ### Training `ALRR`
95 |
96 | To train one of the architectures defined in `./models/definitions/ALRR_FF.py` for a specific layer, run:
97 | `python3 ./scripts/full_sentence/training_ALRR.py --num_of_curr_trained_layer [0-5] --substitute_class <substitute_class>`.
98 | For example, to train the network *FFNetwork_L* to substitute layer zero, run
99 | `python3 ./scripts/full_sentence/training_ALRR.py --num_of_curr_trained_layer 0 --substitute_class FFNetwork_L`.
100 |
101 | ### Training `ALSR`
102 |
103 | To train one of the architectures defined in `./models/definitions/ALSR_FF.py` for a specific layer, run:
104 | `python3 ./scripts/full_sentence/training_ALSR.py --num_of_curr_trained_layer [0-5] --substitute_class <substitute_class>`.
105 | For example, to train the network *FFNetwork_L* to substitute layer zero, training one FFN for each of the 8 heads of that layer's MHA, run:
106 | `python3 ./scripts/full_sentence/training_ALSR.py --num_of_curr_trained_layer 0 --substitute_class FFNetwork_L`.
107 |
108 | ### Training `ELR`
109 |
110 | To train one of the architectures defined in `./models/definitions/ELR_FF.py` for a specific layer, run:
111 | `python3 ./scripts/full_sentence/training_ELR.py --num_of_curr_trained_layer [0-5] --substitute_class <substitute_class>`.
112 | For example, to train the network *FFNetwork_L* to substitute layer zero, run
113 | `python3 ./scripts/full_sentence/training_ELR.py --num_of_curr_trained_layer 0 --substitute_class FFNetwork_L`.
114 |
115 |
116 | ### Training `ALR`
117 |
118 | To train one of the architectures defined in `./models/definitions/ALR_FF.py` for a specific layer, run:
119 | `python3 ./scripts/full_sentence/training_ALR.py --num_of_curr_trained_layer [0-5] --substitute_class <substitute_class>`.
120 | For example, to train the network *FFNetwork_L* to substitute the MHA of layer zero, run:
121 | `python3 ./scripts/full_sentence/training_ALR.py --num_of_curr_trained_layer 0 --substitute_class FFNetwork_L`.
122 |
123 | ### Training `ALR` in the decoder
124 | The `ALR` approach was also used to replace self-attention and cross-attention in the decoder. The architectures used in the decoder are denoted by the words *decoder* and *cross_decoder* in the class name.
125 | To train one of these architectures to substitute self-attention in a decoder layer, run
126 | `python3 ./scripts/full_sentence/training_ALR.py --num_of_curr_trained_layer [0-5] --substitute_class FFNetwork_decoder_L --decoder`
127 |
128 | To train one of these architectures to substitute cross-attention in a decoder layer, run
129 | `python3 ./scripts/full_sentence/training_ALR.py --num_of_curr_trained_layer [0-5] --substitute_class FFNetwork_cross_decoder_L --decoder_ca`
130 |
131 | If you are running this code on a cluster that uses Slurm, the script `./submission_scripts/training_ALR_FF_submit_all.sh` can be used to automatically
132 | submit the training of a network for each layer (0-5).
133 | If you use that script, please make sure that the path specified for the program output exists.
134 | The script currently assumes a directory `./sbatch_log` which will collect all the outputs.
135 |
136 | ## Evaluation
137 |
138 | All the networks trained in the previous step can be evaluated using `./scripts/full_sentence/validation_script.py`.
139 | Validation is performed by substituting the trained FFN into the pretrained transformer and computing the BLEU score on the validation data.
140 | The script accepts the following parameters:
141 | - `substitute_type`: type of approach to use for substitution.
Must be in [`ALRR`, `ALR`, `ALSR`, `ELR`, `None`]. If `None`, no substitution takes place;
142 | - `substitute_class`: class that substitutes attention, e.g. *FFNetwork_L*;
143 | - `layers`: list of layers to substitute. If no layer is specified, all layers are substituted;
144 | - `epoch`: epoch checkpoint to use;
145 |
146 | The same four attributes with the suffix `_d` can be used to substitute self-attention in the decoder, while with the suffix `_d_ca` they substitute cross-attention layers. Currently, only `ALR` supports substitution
147 | in the decoder layers.
148 | To run the evaluation script, the following command can be used:
149 | `python3 ./scripts/full_sentence/validation_script.py --substitute_type <substitute_type> --substitute_class <substitute_class> --layers [0-5]* --epoch <epoch>`
150 | For example, to evaluate the performance of *FFNetwork_L* in the `ALR` approach, substituting all layers in the encoder with
151 | the checkpoint at epoch 21, the following command can be used:
152 | `python3 ./scripts/full_sentence/validation_script.py --substitute_type ALR --substitute_class FFNetwork_L --epoch 21`
153 |
--------------------------------------------------------------------------------
/utils/decoding_utils.py:
--------------------------------------------------------------------------------
1 | import enum
2 |
3 |
4 | import torch
5 | import numpy as np
6 |
7 |
8 | from .constants import *
9 | from utils.data_utils import get_masks_and_count_tokens_trg
10 |
11 |
12 | class DecodingMethod(enum.Enum):
13 |     GREEDY = 0,
14 |     BEAM = 1
15 |
16 |
17 | def greedy_decoding(baseline_transformer, src_representations_batch, src_mask, trg_field_processor, max_target_tokens=MAX_LEN):
18 |     """
19 |     Supports batch (decode multiple source sentences) greedy decoding.
20 |
21 |     Decoding could be further optimized to cache old token activations because they can't look ahead and so
22 |     adding a newly predicted token won't change old tokens' activations.
23 |
24 |     Example: we input the start token `<s>` and do a forward pass. We get intermediate activations for `<s>` and, at the output at position
25 |     0, after the linear layer we get e.g. the token `I`. Now we input `<s>`,`I` but `<s>`'s activations will remain
26 |     the same. Similarly, say we now get `am` at output position 1; in the next step we input `<s>`,`I`,`am`, and so `I`'s
27 |     activations will remain the same as it only looks at/attends to itself and to `<s>`, and so forth.
28 |
29 |     """
30 |
31 |     device = next(baseline_transformer.parameters()).device
32 |     pad_token_id = trg_field_processor.vocab.stoi[PAD_TOKEN]
33 |
34 |     # Initial prompt is the beginning/start of the sentence token.
Make it compatible shape with source batch => (B,1) 35 | target_sentences_tokens = [[BOS_TOKEN] for _ in range(src_representations_batch.shape[0])] 36 | trg_token_ids_batch = torch.tensor([[trg_field_processor.vocab.stoi[tokens[0]]] for tokens in target_sentences_tokens], device=device) 37 | 38 | # Set to true for a particular target sentence once it reaches the EOS (end-of-sentence) token 39 | is_decoded = [False] * src_representations_batch.shape[0] 40 | 41 | while True: 42 | trg_mask, _ = get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id) 43 | # Shape = (B*T, V) where T is the current token-sequence length and V target vocab size 44 | predicted_log_distributions = baseline_transformer.decode(trg_token_ids_batch, src_representations_batch, trg_mask, src_mask) 45 | 46 | # Extract only the indices of last token for every target sentence (we take every T-th token) 47 | num_of_trg_tokens = len(target_sentences_tokens[0]) 48 | predicted_log_distributions = predicted_log_distributions[num_of_trg_tokens-1::num_of_trg_tokens] 49 | 50 | # This is the "greedy" part of the greedy decoding: 51 | # We find indices of the highest probability target tokens and discard every other possibility 52 | most_probable_last_token_indices = torch.argmax(predicted_log_distributions, dim=-1).cpu().numpy() 53 | 54 | # Find target tokens associated with these indices 55 | predicted_words = [trg_field_processor.vocab.itos[index] for index in most_probable_last_token_indices] 56 | 57 | for idx, predicted_word in enumerate(predicted_words): 58 | target_sentences_tokens[idx].append(predicted_word) 59 | 60 | if predicted_word == EOS_TOKEN: # once we find EOS token for a particular sentence we flag it 61 | is_decoded[idx] = True 62 | 63 | if all(is_decoded) or num_of_trg_tokens == max_target_tokens: 64 | break 65 | 66 | # Prepare the input for the next iteration (merge old token ids with the new column of most probable token ids) 67 | trg_token_ids_batch = torch.cat((trg_token_ids_batch, torch.unsqueeze(torch.tensor(most_probable_last_token_indices, device=device), 1)), 1) 68 | 69 | # Post process the sentences - remove everything after the EOS token 70 | target_sentences_tokens_post = [] 71 | for target_sentence_tokens in target_sentences_tokens: 72 | try: 73 | target_index = target_sentence_tokens.index(EOS_TOKEN) + 1 74 | except: 75 | target_index = None 76 | 77 | target_sentence_tokens = target_sentence_tokens[:target_index] 78 | target_sentences_tokens_post.append(target_sentence_tokens) 79 | 80 | return target_sentences_tokens_post 81 | 82 | 83 | def get_beam_decoder(translation_config): 84 | """ 85 | Note: this implementation could probably be further optimized I just wanted a decent working version. 86 | 87 | Notes: 88 | 89 | https://arxiv.org/pdf/1609.08144.pdf introduces various heuristics into the beam search algorithm like coverage 90 | penalty, etc. Here I only designed a simple beam search algorithm with length penalty. As the probability of the 91 | sequence is constructed by multiplying the conditional probabilities (which are numbers smaller than 1) the beam 92 | search algorithm will prefer shorter sentences which we compensate for using the length penalty. 
93 | 94 | """ 95 | beam_size = translation_config['beam_size'] 96 | length_penalty_coefficient = translation_config['length_penalty_coefficient'] 97 | 98 | def beam_decoding(baseline_transformer, src_representations_batch, src_mask, trg_field_processor, max_target_tokens=100): 99 | raise Exception('Not yet implemented.') 100 | device = next(baseline_transformer.parameters()).device 101 | pad_token_id = trg_field_processor.vocab.stoi[PAD_TOKEN] 102 | 103 | # Initial prompt is the beginning/start of the sentence token. Make it compatible shape with source batch => (B,1) 104 | batch_size, S, model_dimension = src_representations_batch.shape 105 | target_multiple_hypotheses_tokens = [[BOS_TOKEN] for _ in range(batch_size)] 106 | trg_token_ids_batch = torch.tensor([[trg_field_processor.vocab.stoi[tokens[0]]] for tokens in target_multiple_hypotheses_tokens], device=device) 107 | 108 | # Repeat so that source sentence representations are repeated contiguously, say we have [s1, s2] we want 109 | # [s1, s1, s2, s2] and not [s1, s2, s1, s2] where s1 is single sentence representation with shape=(S, D) 110 | # where S - max source token-sequence length, D - model dimension 111 | src_representations_batch = src_representations_batch.repeat(1, beam_size, 1).view(beam_size*batch_size, -1, model_dimension) 112 | trg_token_ids_batch = trg_token_ids_batch.repeat(beam_size, 1) 113 | 114 | hypotheses_log_probs = torch.zeros((batch_size * beam_size, 1), device=device) 115 | had_eos = [[False] for _ in range(hypotheses_log_probs.shape[0])] 116 | 117 | while True: 118 | trg_mask, _ = get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id) 119 | # Shape = (B*BS*T, V) T - current token-sequence length, V - target vocab size, BS - beam size, B - batch 120 | predicted_log_distributions = baseline_transformer.decode(trg_token_ids_batch, src_representations_batch, trg_mask, src_mask) 121 | 122 | # Extract only the indices of last token for every target sentence (we take every T-th token) 123 | # Shape = (B*BS, V) 124 | num_of_trg_tokens = trg_token_ids_batch.shape[-1] 125 | predicted_log_distributions = predicted_log_distributions[num_of_trg_tokens - 1::num_of_trg_tokens] 126 | 127 | # This time extract beam_size number of highest probability tokens (compare to greedy's arg max) 128 | # Shape = (B*BS, BS) 129 | latest_token_log_probs, most_probable_token_indices = torch.topk(predicted_log_distributions, beam_size, dim=-1, sorted=True) 130 | 131 | # Don't update the hypothesis which had EOS already (pruning) 132 | latest_token_log_probs.masked_fill(torch.tensor(had_eos == True), float("-inf")) 133 | 134 | # Calculate probabilities for every beam hypothesis (since we have log prob we add instead of multiply) 135 | # Shape = (B*BS, BS) 136 | hypotheses_pool_log_probs = hypotheses_log_probs + latest_token_log_probs 137 | # Shape = (B, BS, BS) 138 | most_probable_token_indices = most_probable_token_indices.view(batch_size, beam_size, beam_size) 139 | hypotheses_pool_log_probs = hypotheses_pool_log_probs.view(batch_size, beam_size, beam_size) 140 | # Shape = (B, BS*BS) 141 | hypotheses_pool_log_probs = torch.flatten(hypotheses_pool_log_probs, start_dim=-1) 142 | 143 | # Figure out indices of beam_size most probably hypothesis for every target sentence in the batch 144 | # Shape = (B, BS) 145 | new_hypothesis_log_probs, next_hypothesis_indices = torch.topk(hypotheses_pool_log_probs, beam_size, dim=-1, sorted=True) 146 | 147 | # Create new target ids batch 148 | hypotheses_log_probs_tmp = torch.empty((batch_size * 
beam_size, 1)) 149 | 150 | T = trg_token_ids_batch.shape[-1] 151 | new_trg_token_ids_batch = torch.empty((batch_size * beam_size, T + 1)) 152 | 153 | next_hypothesis_indices = next_hypothesis_indices.cpu().numpy() 154 | # Prepare new hypotheses for the next iteration 155 | for b_idx, indices in enumerate(next_hypothesis_indices): 156 | for h_idx, token_index in indices: 157 | row, column = token_index / beam_size, token_index % beam_size 158 | hypothesis_index = b_idx * beam_size + h_idx 159 | 160 | new_token_id = most_probable_token_indices[b_idx, row, column] 161 | if had_eos[hypothesis_index]: 162 | new_trg_token_ids_batch[hypothesis_index, :-1] = trg_token_ids_batch[hypothesis_index, :] 163 | else: 164 | new_trg_token_ids_batch[hypothesis_index, :-1] = trg_token_ids_batch[b_idx * beam_size + row, :] 165 | new_trg_token_ids_batch[hypothesis_index, -1] = new_token_id 166 | 167 | if had_eos[hypothesis_index]: 168 | hypotheses_log_probs_tmp[hypothesis_index] = hypotheses_log_probs[hypothesis_index] 169 | else: 170 | hypotheses_log_probs_tmp[hypothesis_index] = new_hypothesis_log_probs[hypothesis_index] 171 | 172 | if new_token_id == trg_field_processor.vocab.stoi[EOS_TOKEN]: 173 | had_eos[hypothesis_index] = True 174 | 175 | # Update the current hypothesis probabilities 176 | hypotheses_log_probs = hypotheses_log_probs_tmp 177 | trg_token_ids_batch = new_trg_token_ids_batch 178 | 179 | if all(had_eos) or num_of_trg_tokens == max_target_tokens: 180 | break 181 | 182 | # 183 | # Selection and post-processing 184 | # 185 | 186 | target_multiple_hypotheses_tokens = [] 187 | trg_token_ids_batch_numpy = trg_token_ids_batch.cpu().numpy() 188 | for hypothesis_ids in trg_token_ids_batch_numpy: 189 | target_multiple_hypotheses_tokens.append([trg_field_processor.vocab.itos[token_id] for token_id in hypothesis_ids]) 190 | 191 | # Step 1: Select the most probable hypothesis out of beam_size hypotheses for each target sentence 192 | hypotheses_log_probs = hypotheses_log_probs.view(batch_size, beam_size) 193 | most_probable_hypotheses_indices = torch.argmax(hypotheses_log_probs, dim=-1).cpu().numpy() 194 | target_sentences_tokens = [] 195 | for b_idx, index in enumerate(most_probable_hypotheses_indices): 196 | target_sentences_tokens.append(target_multiple_hypotheses_tokens[b_idx * beam_size + index]) 197 | 198 | # Step 2: Post process the sentences - remove everything after the EOS token 199 | target_sentences_tokens_post = [] 200 | for target_sentence_tokens in target_sentences_tokens: 201 | try: 202 | target_index = target_sentence_tokens.index(EOS_TOKEN) + 1 203 | except: 204 | target_index = None 205 | 206 | target_sentence_tokens = target_sentence_tokens[:target_index] 207 | target_sentences_tokens_post.append(target_sentence_tokens) 208 | 209 | return target_sentences_tokens_post 210 | 211 | return beam_decoding 212 | 213 | -------------------------------------------------------------------------------- /scripts/baseline/training_script.py: -------------------------------------------------------------------------------- 1 | """ 2 | Notes: 3 | * I won't add model checkpoint averaging as mentioned in the paper - it just feels like an arbitrary heuristic 4 | and it won't add anything to the learning experience this repo aims to provide. 
5 | 6 | """ 7 | 8 | 9 | import argparse 10 | import time 11 | 12 | 13 | import torch 14 | from torch import nn 15 | from torch.optim import Adam 16 | 17 | # Handle imports from utils 18 | from pathlib import Path 19 | import sys 20 | path_root = Path(__file__).parents[2] 21 | sys.path.append(str(path_root)) 22 | 23 | from utils.optimizers_and_distributions import CustomLRAdamOptimizer, LabelSmoothingDistribution 24 | from models.definitions.transformer_model import Transformer 25 | from utils.data_utils import get_data_loaders, get_masks_and_count_tokens, get_src_and_trg_batches, DatasetType, LanguageDirection 26 | import utils.utils as utils 27 | from utils.constants import * 28 | from utils.full_sentence_utils import substitute_attention 29 | 30 | # Global vars for logging purposes 31 | num_of_trg_tokens_processed = 0 32 | bleu_scores = [] 33 | global_train_step, global_val_step = [0, 0] 34 | 35 | 36 | # Simple decorator function so that I don't have to pass these arguments every time I call get_train_val_loop 37 | def get_train_val_loop(baseline_transformer, custom_lr_optimizer, kl_div_loss, label_smoothing, pad_token_id, time_start): 38 | 39 | def train_val_loop(is_train, token_ids_loader, epoch): 40 | global num_of_trg_tokens_processed, global_train_step, global_val_step 41 | 42 | if is_train: 43 | baseline_transformer.train() 44 | else: 45 | baseline_transformer.eval() 46 | 47 | device = next(baseline_transformer.parameters()).device 48 | 49 | # 50 | # Main loop - start of the CORE PART 51 | # 52 | for batch_idx, token_ids_batch in enumerate(token_ids_loader): 53 | src_token_ids_batch, trg_token_ids_batch_input, trg_token_ids_batch_gt = get_src_and_trg_batches(token_ids_batch) 54 | src_mask, trg_mask, num_src_tokens, num_trg_tokens = get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch_input, pad_token_id, device) 55 | 56 | # log because the KL loss expects log probabilities (just an implementation detail) 57 | predicted_log_distributions = baseline_transformer(src_token_ids_batch, trg_token_ids_batch_input, src_mask, trg_mask) 58 | smooth_target_distributions = label_smoothing(trg_token_ids_batch_gt) # these are regular probabilities 59 | 60 | if is_train: 61 | custom_lr_optimizer.zero_grad() # clean the trainable weights gradients in the computational graph 62 | 63 | loss = kl_div_loss(predicted_log_distributions, smooth_target_distributions) 64 | 65 | if is_train: 66 | loss.backward() # compute the gradients for every trainable weight in the computational graph 67 | custom_lr_optimizer.step() # apply the gradients to weights 68 | 69 | # End of CORE PART 70 | 71 | # 72 | # Logging and metrics 73 | # 74 | 75 | if is_train: 76 | global_train_step += 1 77 | num_of_trg_tokens_processed += num_trg_tokens 78 | 79 | if training_config['console_log_freq'] is not None and batch_idx % training_config['console_log_freq'] == 0: 80 | print(f'Transformer training: time elapsed= {(time.time() - time_start):.2f} [s] ' 81 | f'| epoch={epoch + 1} | batch= {batch_idx + 1} ' 82 | f'| target tokens/batch= {num_of_trg_tokens_processed / training_config["console_log_freq"]}') 83 | 84 | num_of_trg_tokens_processed = 0 85 | 86 | # Save model checkpoint 87 | if training_config['checkpoint_freq'] is not None and (epoch + 1) % training_config['checkpoint_freq'] == 0 and batch_idx == 0: 88 | ckpt_model_name = f"transformer_ckpt_epoch_{epoch + 1}.pth" 89 | torch.save(utils.get_training_state(training_config, custom_lr_optimizer.current_step_number, baseline_transformer), 
os.path.join(CHECKPOINTS_PATH, ckpt_model_name)) 90 | else: 91 | global_val_step += 1 92 | 93 | return train_val_loop 94 | 95 | 96 | def train_transformer(training_config): 97 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # checking whether you have a GPU, I hope so! 98 | # device = "cpu" 99 | # Step 1: Prepare data loaders 100 | # NOTE: If we wanted to load the pretrained transformer, we would need to first load the entire training data to get the full vocabulary. Then reload the dataset filtering for sentences s.t. S <= MAX_LEN 101 | train_token_ids_loader, val_token_ids_loader, test_token_ids_loader, src_field_processor, trg_field_processor = get_data_loaders( 102 | training_config['dataset_path'], 103 | training_config['language_direction'], 104 | training_config['dataset_name'], 105 | training_config['batch_size'], 106 | device) 107 | 108 | pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN] # pad token id is the same for target as well 109 | src_vocab_size = len(src_field_processor.vocab) 110 | trg_vocab_size = len(trg_field_processor.vocab) 111 | 112 | # Step 2: Prepare the model (original transformer) and push to GPU 113 | baseline_transformer = Transformer( 114 | model_dimension=BASELINE_MODEL_DIMENSION, 115 | src_vocab_size=src_vocab_size, 116 | trg_vocab_size=trg_vocab_size, 117 | number_of_heads=BASELINE_MODEL_NUMBER_OF_HEADS, 118 | number_of_layers=BASELINE_MODEL_NUMBER_OF_LAYERS, 119 | dropout_probability=BASELINE_MODEL_DROPOUT_PROB 120 | ).to(device) 121 | # model_path = os.path.join(BINARIES_PATH, training_config['model_name']) 122 | # model_state = torch.load(model_path) 123 | # baseline_transformer.load_state_dict(model_state["state_dict"], strict=True) 124 | # baseline_transformer.train() 125 | 126 | # reloading the data, filtering sentences of len>50 127 | train_token_ids_loader, val_token_ids_loader, test_token_ids_loader, src_field_processor, trg_field_processor = get_data_loaders( 128 | training_config['dataset_path'], 129 | training_config['language_direction'], 130 | training_config['dataset_name'], 131 | training_config['batch_size'], 132 | device, 133 | max_len_train=MAX_LEN) 134 | 135 | # Step 3: substitute attention 136 | if training_config["substitute_type"] != "None": 137 | substitute_attention(baseline_transformer, 138 | training_config["substitute_class"], 139 | training_config["substitute_model_path"], 140 | training_config["layer"], 141 | training_config["epoch"], 142 | training_config["substitute_type"], 143 | training_config["untrained"]) 144 | else: 145 | print("#"*100) 146 | print("\n\t NO SUBSTITUTION \n") 147 | print("#"*100) 148 | 149 | #baseline_transformer=torch.nn.DataParallel(baseline_transformer,device_ids=list(range(4))) 150 | # Step 3: Prepare other training related utilities 151 | kl_div_loss = nn.KLDivLoss(reduction='batchmean') # gives better BLEU score than "mean" 152 | 153 | # Makes smooth target distributions as opposed to conventional one-hot distributions 154 | # My feeling is that this is a really dummy and arbitrary heuristic but time will tell. 
155 | label_smoothing = LabelSmoothingDistribution(BASELINE_MODEL_LABEL_SMOOTHING_VALUE, pad_token_id, trg_vocab_size, device) 156 | 157 | # Make resuming the training possible 158 | steps_taken = 0 159 | if (training_config['start_point']): 160 | ckpt_model_name = f"transformer_ckpt_epoch_{training_config['start_point']}.pth" 161 | path = os.path.join(CHECKPOINTS_PATH, ckpt_model_name) 162 | print(f"Trying to resume training from starting point {path}.") 163 | if (not os.path.exists(path)): 164 | print(f"Requested starting point {path} does not exist.") 165 | return 166 | else: 167 | training_state = torch.load(path) 168 | assert(training_config['dataset_name'] == training_state['dataset_name']) 169 | assert(training_config['language_direction'] == training_state['language_direction']) 170 | baseline_transformer.load_state_dict(training_state['state_dict']) 171 | steps_taken = training_state['steps_taken'] 172 | print("Loaded state dict, resuming training") 173 | 174 | # Check out playground.py for an intuitive visualization of how the LR changes with time/training steps, easy stuff. 175 | custom_lr_optimizer = CustomLRAdamOptimizer( 176 | Adam(baseline_transformer.parameters(), betas=(0.9, 0.98), eps=1e-9), 177 | BASELINE_MODEL_DIMENSION, 178 | training_config['num_warmup_steps'], 179 | steps_taken 180 | ) 181 | 182 | # The decorator function makes things cleaner since there is a lot of redundancy between the train and val loops 183 | train_val_loop = get_train_val_loop(baseline_transformer, custom_lr_optimizer, kl_div_loss, label_smoothing, pad_token_id, time.time()) 184 | 185 | # Step 4: Start the training 186 | for epoch in range(training_config['start_point'], training_config['num_of_epochs']): 187 | # Training loop 188 | train_val_loop(is_train=True, token_ids_loader=train_token_ids_loader, epoch=epoch) 189 | 190 | # Validation loop 191 | with torch.no_grad(): 192 | train_val_loop(is_train=False, token_ids_loader=val_token_ids_loader, epoch=epoch) 193 | 194 | bleu_score = utils.calculate_bleu_score(baseline_transformer, val_token_ids_loader, trg_field_processor) 195 | 196 | # Save the latest transformer in the binaries directory 197 | model_name = f"Transformer_{training_config['substitute_type']}_{training_config['substitute_class']}_{training_config['num_of_epochs']}.pth" 198 | torch.save(utils.get_training_state(training_config, custom_lr_optimizer.current_step_number, baseline_transformer), os.path.join(BINARIES_PATH, model_name)) 199 | 200 | if __name__ == "__main__": 201 | # 202 | # Fixed args - don't change these unless you have a good reason 203 | # 204 | num_warmup_steps = 4000 205 | # Modifiable args - feel free to play with these (only small subset is exposed by design to avoid cluttering) 206 | # 207 | parser = argparse.ArgumentParser() 208 | # According to the paper I infered that the baseline was trained for ~19 epochs on the WMT-14 dataset and I got 209 | # nice returns up to epoch ~20 on IWSLT as well (nice round number) 210 | parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=10) 211 | # You should adjust this for your particular machine (I have RTX 2080 with 8 GBs of VRAM so 1500 fits nicely!) 
212 | parser.add_argument("--batch_size", type=int, help="target number of tokens in a src/trg batch", default=1500) 213 | 214 | # Data related args 215 | parser.add_argument("--dataset_name", choices=[el.name for el in DatasetType], help='which dataset to use for training', default=DatasetType.IWSLT.name) 216 | parser.add_argument("--language_direction", choices=[el.name for el in LanguageDirection], help='which direction to translate', default=LanguageDirection.de_en.name) 217 | parser.add_argument("--dataset_path", type=str, help='download dataset to this path', default=DATA_DIR_PATH) 218 | parser.add_argument("--model_name", type=str, help="transformer model name", default=r'Transformer_None_None_20.pth') 219 | 220 | # Logging/debugging/checkpoint related (helps a lot with experimentation) 221 | parser.add_argument("--console_log_freq", type=int, help="log to output console (batch) freq", default=10) 222 | parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (epoch) freq", default=20) 223 | parser.add_argument("--start_point", type=int, help="checkpoint model (epoch) where to resume training from", default=0) 224 | parser.add_argument("--substitute_class", type=str, help="class that substitutes attention e.g. FFNetwork_L", choices=["FFNetwork_XS", "FFNetwork_S", "FFNetwork_M", "FFNetwork_L", "FFNetwork_XL",], default="None") 225 | parser.add_argument("--substitute_model_path", type=str, help="path to the substitue of attention. The folder should contain 6 subfolders one for each layer. Inside the FF checkpoints are stored with name: ff_network_{epoch}_layer_{layer}.pth") 226 | parser.add_argument("--layer", help = "If layer is not specified, all layers are substituted", default = None) 227 | parser.add_argument("--epoch", type = int, help="Epoch checkpoint to use.", default=20) 228 | parser.add_argument("--substitute_type", type = str, help="Type of the substitute layer.", choices=["ALRR", "ALR", "ALSR", "None"], default="None") 229 | parser.add_argument("--untrained", type=bool, default = True) 230 | args = parser.parse_args() 231 | # Wrapping training configuration into a dictionary 232 | training_config = dict() 233 | for arg in vars(args): 234 | training_config[arg] = getattr(args, arg) 235 | training_config['num_warmup_steps'] = num_warmup_steps 236 | 237 | # Train the original transformer model 238 | train_transformer(training_config) 239 | -------------------------------------------------------------------------------- /models/definitions/ALR_FF.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from utils.constants import * 3 | import torch 4 | 5 | class FFNetwork_XS(nn.ModuleList): 6 | def __init__(self, model_dimension=128,sentence_length=MAX_LEN): 7 | super(FFNetwork_XS, self).__init__() 8 | self.sentence_length=sentence_length 9 | self.model_dimension=model_dimension 10 | self.width=self.sentence_length*self.model_dimension 11 | self.layers=list() 12 | widths=[1,256,1] 13 | self.depth=len(widths)-1 14 | self.layers=nn.ModuleList() 15 | for i in range(self.depth): 16 | self.layers.extend([nn.LayerNorm(self.width // widths[i]),nn.Linear(self.width // widths[i], self.width // widths[i+1])]) 17 | if(i 2.5s (massive!) 40 | 41 | """ 42 | 43 | @staticmethod 44 | def sort_key(ex): 45 | # What this does is basically it takes a 16-bit binary representation of lengths and interleaves them. 
46 | # Example: lengths len(ex.src)=5 and len(ex.trg)=3 result in f(101, 011)=100111, 7 and 1 in f(111, 001)=101011 47 | # It's basically a heuristic that helps the BucketIterator sort bigger batches first 48 | return interleave_keys(len(ex.src), len(ex.trg)) 49 | 50 | def __init__(self, cache_path, fields, **kwargs): 51 | # save_cache interleaves src and trg examples so here we read the cache file having that format in mind 52 | cached_data = [line.split() for line in open(cache_path, encoding='utf-8')] 53 | 54 | cached_data_src = cached_data[0::2] # Even lines contain source examples 55 | cached_data_trg = cached_data[1::2] # Odd lines contain target examples 56 | 57 | assert len(cached_data_src) == len(cached_data_trg), f'Source and target data should be of the same length.' 58 | 59 | examples = [] 60 | src_dataset_total_number_of_tokens = 0 61 | trg_dataset_total_number_of_tokens = 0 62 | for src_tokenized_data, trg_tokenized_data in zip(cached_data_src, cached_data_trg): 63 | ex = Example() 64 | 65 | setattr(ex, 'src', src_tokenized_data) 66 | setattr(ex, 'trg', trg_tokenized_data) 67 | 68 | examples.append(ex) 69 | 70 | # Update the number of tokens 71 | src_dataset_total_number_of_tokens += len(src_tokenized_data) 72 | trg_dataset_total_number_of_tokens += len(trg_tokenized_data) 73 | 74 | # Print relevant information about the dataset (parsing the cache file name) 75 | filename_parts = os.path.split(cache_path)[1].split('_') 76 | language_name = {'en': 'English', 'de': 'German', 'fr': 'French'} 77 | src_language, trg_language = language_name[filename_parts[0]], language_name[filename_parts[1]] 78 | dataset_name = 'IWSLT' if filename_parts[2] == 'iwslt' else 'WMT-14' 79 | dataset_type = filename_parts[3] 80 | print(f'{dataset_type} dataset ({dataset_name}) has {src_dataset_total_number_of_tokens} tokens in the source language ({src_language}) corpus.') 81 | print(f'{dataset_type} dataset ({dataset_name}) has {trg_dataset_total_number_of_tokens} tokens in the target language ({trg_language}) corpus.') 82 | 83 | # Call the parent class Dataset's constructor 84 | super().__init__(examples, fields, **kwargs) 85 | 86 | class ttextDataset(Dataset): 87 | def __init__(self, dataset, fields, filter_pred): 88 | examples = [] 89 | for dat in dataset: 90 | ex = Example() 91 | 92 | setattr(ex, 'src', dat['de']) 93 | setattr(ex, 'trg', dat['en']) 94 | 95 | examples.append(ex) 96 | 97 | super().__init__(examples, fields, filter_pred) 98 | 99 | 100 | class DatasetWrapper(FastTranslationDataset): 101 | """ 102 | Just a wrapper around the FastTranslationDataset. 
103 | 104 | """ 105 | 106 | @classmethod 107 | def get_train_datasets(cls, train_cache_path, fields, **kwargs): 108 | 109 | train_dataset = cls(train_cache_path, fields, **kwargs) 110 | 111 | return train_dataset 112 | 113 | @classmethod 114 | def get_val_datasets(cls, val_cache_path, fields, **kwargs): 115 | 116 | val_dataset = cls(val_cache_path, fields, **kwargs) 117 | 118 | return val_dataset 119 | 120 | @classmethod 121 | def get_test_dataset(cls, test_cache_path, fields, **kwargs): 122 | 123 | test_dataset = cls(test_cache_path, fields, **kwargs) 124 | 125 | return test_dataset 126 | 127 | 128 | def save_cache(cache_path, dataset): 129 | with open(cache_path, 'w', encoding='utf-8') as cache_file: 130 | # Interleave source and target tokenized examples, source is on even lines, target is on odd lines 131 | for ex in dataset.examples: 132 | #cache_file.write(ex.src + '\n') 133 | #cache_file.write(ex.trg + '\n') 134 | cache_file.write(' '.join(ex.src) + '\n') 135 | cache_file.write(' '.join(ex.trg) + '\n') 136 | # 137 | # End of caching mechanism utilities 138 | # 139 | 140 | 141 | 142 | def get_datasets_and_vocabs(dataset_path, language_direction, use_iwslt=True, use_caching_mechanism=True, fix_length = None, max_len_train = 100): 143 | src_lang, trg_lang = language_direction.split('_') 144 | spacy_de = spacy.load('fr_core_news_sm') 145 | spacy_en = spacy.load('en_core_web_sm') 146 | spacy_fr = spacy.load('fr_core_news_sm') 147 | 148 | def tokenize_de(text): 149 | return [tok.text for tok in spacy_de.tokenizer(text)] 150 | 151 | def tokenize_en(text): 152 | return [tok.text for tok in spacy_en.tokenizer(text)] 153 | 154 | def tokenize_fr(text): 155 | return [tok.text for tok in spacy_fr.tokenizer(text)] 156 | 157 | tokenizers = {'en': tokenize_en, 'de': tokenize_de, 'fr': tokenize_fr} 158 | # batch first set to true as my transformer is expecting that format (that's consistent with the format 159 | # used in computer vision), namely (B, C, H, W) -> batch size, number of channels, height and width 160 | src_tokenizer = tokenizers[src_lang] 161 | trg_tokenizer = tokenizers[trg_lang] 162 | src_field_processor = Field(tokenize=src_tokenizer, pad_token=PAD_TOKEN, batch_first=True, fix_length = fix_length) 163 | trg_field_processor = Field(tokenize=trg_tokenizer, init_token=BOS_TOKEN, eos_token=EOS_TOKEN, pad_token=PAD_TOKEN, batch_first=True,fix_length = fix_length) 164 | 165 | fields = [('src', src_field_processor), ('trg', trg_field_processor)] 166 | max_len = max_len_train # filter out examples that have more than MAX_LEN tokens 167 | filter_pred = lambda x: len(x.src) <= max_len and len(x.trg) <= max_len 168 | filter_val_test = lambda x: len(x.src) <= MAX_LEN and len(x.trg) <= MAX_LEN 169 | 170 | # Only call once the splits function it is super slow as it constantly has to redo the tokenization 171 | prefix = language_direction 172 | prefix += '_iwslt' if use_iwslt else '_wmt14' 173 | train_cache_path = os.path.join(dataset_path, f'{prefix}_train_cache.csv') 174 | val_cache_path = os.path.join(dataset_path, f'{prefix}_val_cache.csv') 175 | test_cache_path = os.path.join(dataset_path, f'{prefix}_test_cache.csv') 176 | 177 | # This simple caching mechanism gave me ~30x speedup on my machine! From ~70s -> ~2.5s! 
178 | ts = time.time() 179 | if not use_caching_mechanism or not (os.path.exists(train_cache_path) and os.path.exists(val_cache_path) and os.path.exists(test_cache_path)): 180 | # dataset objects have a list of examples where example is simply an empty Python Object that has 181 | # .src and .trg attributes which contain a tokenized list of strings (created by tokenize_en and tokenize_de). 182 | # It's that simple, we can consider our datasets as a table with 2 columns 'src' and 'trg' 183 | # each containing fields with tokenized strings from source and target languages 184 | src_ext = '.' + language_direction.split('_')[0] 185 | trg_ext = '.' + language_direction.split('_')[1] 186 | 187 | train_dataset, val_dataset, test_dataset = TabularDataset.splits(path='./data/prepared_data', train=f'train_{language_direction}.csv', validation=f'val_{language_direction}.csv', test=f'test_{language_direction}.csv', format='csv', fields=fields, skip_header=True, filter_pred=filter_pred) 188 | 189 | # dataset_split_fn = datasets.IWSLT.splits if use_iwslt else datasets.WMT14.splits 190 | #train_dataset, val_dataset, test_dataset = dataset_split_fn( 191 | # exts=(src_ext, trg_ext), 192 | # fields=fields, 193 | # root=dataset_path, 194 | # filter_pred=filter_pred 195 | # ) 196 | 197 | save_cache(train_cache_path, train_dataset) 198 | save_cache(val_cache_path, val_dataset) 199 | save_cache(test_cache_path, test_dataset) 200 | 201 | # it's actually better to load from cache as we'll get rid of '\xa0', '\xa0 ' and '\x85' unicode characters 202 | # which we don't need and which SpaCy unfortunately includes as tokens. 203 | train_dataset = DatasetWrapper.get_train_datasets(train_cache_path,fields,filter_pred=filter_pred) 204 | val_dataset = DatasetWrapper.get_val_datasets( val_cache_path, fields, filter_pred=filter_val_test) 205 | test_dataset = DatasetWrapper.get_test_dataset(test_cache_path, fields, filter_pred=filter_val_test) 206 | 207 | print(f'Time it took to prepare the data: {time.time() - ts:3f} seconds.') 208 | 209 | MIN_FREQ = 2 210 | # __getattr__ implementation in the base Dataset class enables us to call .src on Dataset objects even though 211 | # we only have a list of examples in the Dataset object and the example itself had .src attribute. 212 | # Implementation will yield examples and call .src/.trg attributes on them (and those contain tokenized lists) 213 | src_field_processor.build_vocab(train_dataset.src, min_freq=MIN_FREQ) 214 | trg_field_processor.build_vocab(train_dataset.trg, min_freq=MIN_FREQ) 215 | return train_dataset, val_dataset, test_dataset, src_field_processor, trg_field_processor 216 | 217 | 218 | global longest_src_sentence, longest_trg_sentence 219 | 220 | 221 | def batch_size_fn(new_example, count, sofar): 222 | """ 223 | If we use this function in the BucketIterator the batch_size is no longer the number of examples/sentences 224 | in a batch but a number of tokens in a batch - which allows us to max out VRAM on a given GPU. 225 | 226 | Example: if we don't use this function and we set batch size to say 10 we will sometimes end up with 227 | a tensor of size (10, 100) because the longest sentence had a size of 100 tokens but other times we'll end 228 | up with a size of (10, 5) because the longest sentence had only 5 tokens! 229 | 230 | With this function what we do is we specify that source and target tensors can't go over a certain number 231 | of tokens like 1000. 
So usually either source or target tensors will contain around 1000 tokens and 232 | in the worst case both will be really close to 1000 tokens each. If that is still below the max VRAM available on 233 | the system, we're using the max potential of our GPU w.r.t. VRAM. 234 | 235 | Note: to understand this function you unfortunately would probably have to dig deeper into torchtext's 236 | source code. 237 | 238 | """ 239 | global longest_src_sentence, longest_trg_sentence 240 | 241 | if count == 1: 242 | longest_src_sentence = 0 243 | longest_trg_sentence = 0 244 | 245 | longest_src_sentence = max(longest_src_sentence, len(new_example.src)) 246 | # 2 because of start/end of sentence tokens (BOS and EOS) 247 | longest_trg_sentence = max(longest_trg_sentence, len(new_example.trg) + 2) 248 | 249 | num_of_tokens_in_src_tensor = count * longest_src_sentence 250 | num_of_tokens_in_trg_tensor = count * longest_trg_sentence 251 | return max(num_of_tokens_in_src_tensor, num_of_tokens_in_trg_tensor) 252 | 253 | 254 | # https://github.com/pytorch/text/issues/536#issuecomment-719945594 <- there is a "bug" in BucketIterator i.e. its 255 | # description is misleading as it won't group examples of similar length unless you set sort_within_batch to True! 256 | def get_data_loaders(dataset_path, language_direction, dataset_name, batch_size, device, max_len_train = 100): 257 | train_dataset, val_dataset, test_dataset, src_field_processor, trg_field_processor = get_datasets_and_vocabs(dataset_path, language_direction, dataset_name == DatasetType.IWSLT.name, max_len_train = max_len_train) 258 | train_token_ids_loader, val_token_ids_loader, test_token_ids_loader = BucketIterator.splits( 259 | datasets=(train_dataset, val_dataset, test_dataset), 260 | batch_size=batch_size, 261 | device=device, 262 | sort_within_batch=True, # this part is really important otherwise we won't group similar length sentences 263 | batch_size_fn=batch_size_fn # this helps us max out GPU's VRAM 264 | ) 265 | 266 | return train_token_ids_loader, val_token_ids_loader, test_token_ids_loader, src_field_processor, trg_field_processor 267 | 268 | 269 | def get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id): 270 | batch_size = src_token_ids_batch.shape[0] 271 | 272 | # src_mask shape = (B, 1, 1, S) check out attention function in transformer_model.py where masks are applied 273 | # src_mask only masks pad tokens as we want to ignore their representations (no information in there...) 274 | src_mask = (src_token_ids_batch != pad_token_id).view(batch_size, 1, 1, -1) 275 | num_src_tokens = torch.sum(src_mask.long()) 276 | 277 | return src_mask, num_src_tokens 278 | 279 | 280 | def get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id): 281 | batch_size = trg_token_ids_batch.shape[0] 282 | device = trg_token_ids_batch.device 283 | 284 | # Same as src_mask but we additionally want to mask tokens from looking forward into the future tokens 285 | # Note: wherever the mask value is true we want to attend to that token, otherwise we mask (ignore) it.
286 | sequence_length = trg_token_ids_batch.shape[1] # trg_token_ids shape = (B, T) where T is the max trg token-sequence length 287 | trg_padding_mask = (trg_token_ids_batch != pad_token_id).view(batch_size, 1, 1, -1) # shape = (B, 1, 1, T) 288 | trg_no_look_forward_mask = torch.triu(torch.ones((1, 1, sequence_length, sequence_length), device=device) == 1).transpose(2, 3) 289 | 290 | # logic AND operation (both padding mask and no-look-forward must be true to attend to a certain target token) 291 | trg_mask = trg_padding_mask & trg_no_look_forward_mask # final shape = (B, 1, T, T) 292 | num_trg_tokens = torch.sum(trg_padding_mask.long()) 293 | 294 | return trg_mask, num_trg_tokens 295 | 296 | 297 | def get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch, pad_token_id, device): 298 | src_mask, num_src_tokens = get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id) 299 | trg_mask, num_trg_tokens = get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id) 300 | 301 | return src_mask, trg_mask, num_src_tokens, num_trg_tokens 302 | 303 | 304 | def get_src_and_trg_batches(token_ids_batch): 305 | src_token_ids_batch, trg_token_ids_batch = token_ids_batch.src, token_ids_batch.trg 306 | 307 | # Target input should be shifted by 1 compared to the target output tokens 308 | # Example: if we had a sentence like: [BOS,what,is,up,EOS] then to train the NMT model what we do is we pass 309 | # [BOS,what,is,up] to the input and set [what,is,up,EOS] as the expected output. 310 | trg_token_ids_batch_input = trg_token_ids_batch[:, :-1] 311 | 312 | # We reshape from (B, S) into (BxS, 1) as that's the shape expected by LabelSmoothing which will produce 313 | # the shape (BxS, V) where V is the target vocab size which is the same shape as the one that comes out 314 | # from the transformer so we can directly pass them into the KL divergence loss 315 | trg_token_ids_batch_gt = trg_token_ids_batch[:, 1:].reshape(-1, 1) 316 | 317 | return src_token_ids_batch, trg_token_ids_batch_input, trg_token_ids_batch_gt 318 | 319 | 320 | # 321 | # Everything below is for testing purposes only - feel free to ignore 322 | # 323 | 324 | 325 | def sample_text_from_loader(src_field_processor, trg_field_processor, token_ids_loader, num_samples=2, sample_src=True, sample_trg=True, show_padded=False): 326 | assert sample_src or sample_trg, f'Either src or trg or both must be enabled.'
327 | 328 | for b_idx, token_ids_batch in enumerate(token_ids_loader): 329 | if b_idx == num_samples: # Number of sentence samples to print 330 | break 331 | 332 | print('*' * 5) 333 | if sample_src: 334 | print("Source text:", end="\t") 335 | for token_id in token_ids_batch.src[0]: # print only the first example from the batch 336 | src_token = src_field_processor.vocab.itos[token_id] 337 | 338 | if src_token == PAD_TOKEN and not show_padded: 339 | continue 340 | 341 | print(src_token, end=" ") 342 | print() 343 | 344 | if sample_trg: 345 | print("Target text:", end="\t") 346 | for token_id in token_ids_batch.trg[0]: 347 | trg_token = trg_field_processor.vocab.itos[token_id] 348 | 349 | if trg_token == PAD_TOKEN and not show_padded: 350 | continue 351 | 352 | print(trg_token, end=" ") 353 | print() 354 | 355 | 356 | if __name__ == "__main__": 357 | # To run this delete the dot from from .constants import - not the most elegant solution but it works 358 | # without me having to add sys.path stuff, if you have a more elegant solution please open an issue <3 359 | batch_size = 8 360 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 361 | dataset_name = DatasetType.IWSLT.name 362 | language_direction = LanguageDirection.de_en.name 363 | train_token_ids_loader, val_token_ids_loader, test_token_ids_loader, src_field_processor, trg_field_processor = get_data_loaders(DATA_DIR_PATH, language_direction, dataset_name, batch_size, device) 364 | 365 | # Verify that the mask logic is correct 366 | pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN] 367 | for batch in train_token_ids_loader: 368 | # Visually inspect that masks make sense 369 | src_padding_mask, trg_mask, num_src_tokens, num_trg_tokens = get_masks_and_count_tokens(batch.src, batch.trg, pad_token_id, device) 370 | break 371 | 372 | # Check vocab size 373 | print(f'Source vocabulary size={len(src_field_processor.vocab)}') 374 | print(f'Target vocabulary size={len(trg_field_processor.vocab)}') 375 | 376 | # Show text from token loader 377 | sample_text_from_loader(src_field_processor, trg_field_processor, train_token_ids_loader) 378 | 379 | -------------------------------------------------------------------------------- /scripts/full_sentence/training_ALR.py: -------------------------------------------------------------------------------- 1 | from pickle import UnpicklingError 2 | import os 3 | import argparse 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | from torch.optim import Adam 11 | from torch.nn.utils.rnn import pad_sequence 12 | from torch.nn.functional import pad 13 | 14 | # Local imports 15 | from pathlib import Path 16 | import sys 17 | path_root = Path(__file__).parents[2] 18 | sys.path.append(str(path_root)) 19 | import models.definitions.ALR_FF as FF_models 20 | from utils.constants import SCRATCH, MAX_LEN, CHECKPOINTS_SCRATCH, ALR_CHECKPOINT_FORMAT 21 | from utils.data_utils import LanguageDirection 22 | DATA_PATH=os.path.join(SCRATCH,"pytorch-original-transformer", "mha_outputs") 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # checking whether you have a GPU, I hope so! 
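# Worked example of the MAPE metric defined below (illustrative values only, not taken from the repo):
# relative_error = |output - target| / max(|target|, 1e-32), averaged over all elements, so division by zero is avoided.
# For target = [2.0, 4.0] and output = [2.5, 3.0] the per-element relative errors are [0.25, 0.25],
# hence MAPE = 0.25, i.e. a 25% mean absolute percentage error.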
24 | def MAPE(target, output): 25 | #Mean Absolute Percentage Error 26 | with torch.no_grad(): 27 | relative_error = torch.abs(output - target) / torch.max(torch.abs(target), torch.ones(output.shape, device = device)*1e-32) 28 | return torch.mean(relative_error) 29 | 30 | def prepare_data(data_path,language_direction, chosen_layer = 0, batch_size = 5, t = "train", att_replacement = 'encoder'): 31 | if t not in ["train", "test", "val"]: 32 | raise ValueError("ERROR: t must be train, test, or val.") 33 | if t == "val": 34 | print("#"*100) 35 | print("ATTENTION VALIDATION USED IN TRAINING, ONLY OK FOR DEBUGGING") 36 | print("#"*100) 37 | if (att_replacement == 'encoder'): 38 | in_path = os.path.join(data_path,"encoder", f"128emb_20ep_IWSLT_{language_direction}_layer{chosen_layer}_v_inputs_{t}") 39 | out_path = os.path.join(data_path,"encoder", f"128emb_20ep_IWSLT_{language_direction}_layer{chosen_layer}_outputs_{t}") 40 | mask_path = os.path.join(data_path,"encoder", f"128emb_20ep_IWSLT_{language_direction}_masks_{t}") 41 | dataset = AttentionEncoderDataset(in_path, out_path, mask_path, MAX_LEN) 42 | return DataLoader(dataset, collate_fn=collate_batch, batch_size= batch_size) 43 | elif(att_replacement == 'decoder'): 44 | in_path = os.path.join(data_path,"decoder_self", f"128emb_20ep_IWSLT_{language_direction}_layer{chosen_layer}_v_inputs_{t}") 45 | out_path = os.path.join(data_path,"decoder_self", f"128emb_20ep_IWSLT_{language_direction}_layer{chosen_layer}_outputs_{t}") 46 | mask_path = os.path.join(data_path,"decoder_self", f"128emb_20ep_IWSLT_{language_direction}_masks_{t}") 47 | dataset = AttentionDecoderDataset(in_path, out_path, mask_path, MAX_LEN) 48 | return DataLoader(dataset, collate_fn=collate_batch_decoder, batch_size = batch_size ) 49 | elif(att_replacement == 'decoder_ca'): 50 | in_enc_path = os.path.join(data_path,"decoder_cross", f"128emb_20ep_IWSLT_{language_direction}_layer{chosen_layer}_v_inputs_{t}") 51 | in_dec_path = os.path.join(data_path,"decoder_cross", f"128emb_20ep_IWSLT_{language_direction}_layer{chosen_layer}_q_inputs_{t}") 52 | out_path = os.path.join(data_path,"decoder_cross", f"128emb_20ep_IWSLT_{language_direction}_layer{chosen_layer}_outputs_{t}") 53 | src_mask_path = os.path.join(data_path,"decoder_cross", f"128emb_20ep_IWSLT_{language_direction}_masks_{t}_src") 54 | trg_mask_path = os.path.join(data_path,"decoder_cross", f"128emb_20ep_IWSLT_{language_direction}_masks_{t}") 55 | dataset = AttentionDecoderCADataset(in_enc_path, in_dec_path, out_path, src_mask_path, trg_mask_path, MAX_LEN) 56 | return DataLoader(dataset, collate_fn=collate_batch_decoder_ca, batch_size = batch_size ) 57 | else: 58 | raise ValueError("ERROR: att_replacement must be encoder, decoder or decoder_ca.") 59 | 60 | def training_replacement_FF(params): 61 | FF_net = getattr(FF_models, params["substitute_class"]) 62 | print(f"Training model: {FF_net}") 63 | model=FF_net() 64 | if not params["multi_device"]: 65 | model.to(device) 66 | # print(model) 67 | #model.init_weights() 68 | model.train(True) 69 | print("FF model created") 70 | lr_optimizer = Adam(model.parameters(), lr=0.0001,betas=(0.9, 0.98), eps=1e-9) 71 | print("Preparing data") 72 | data_loader=prepare_data(params['dataset_path'], params['language_direction'], chosen_layer = params['num_of_curr_trained_layer'], batch_size = params["batch_size"], att_replacement = params["att_replacement"]) 73 | mse_loss=nn.MSELoss() 74 | for epoch in range(params['num_of_epochs']): 75 | print("Epoch: ",epoch) 76 | epoch_loss=0 77 | 
num_embeddings=0 78 | mapes = [] 79 | start = time.time() 80 | for (data,label, mask) in data_loader: 81 | lr_optimizer.zero_grad() 82 | pred=model(data,mask) 83 | with torch.no_grad(): 84 | num_embeddings+=torch.sum(torch.flatten(mask)).item() 85 | loss_normalizer=torch.sum(torch.flatten(mask)).item()/(mask.shape[0]*mask.shape[1]) 86 | loss=mse_loss(label,pred)/loss_normalizer 87 | loss.backward() 88 | loss /= loss_normalizer 89 | lr_optimizer.step() 90 | with torch.no_grad(): 91 | epoch_loss+=loss.item()*torch.sum(torch.flatten(mask)).item() 92 | mapes.append(MAPE(label, pred)) 93 | if epoch % 20 == 0: 94 | ckpt_model_name = ALR_CHECKPOINT_FORMAT.format(epoch+1, params['num_of_curr_trained_layer']) 95 | torch.save(model.state_dict(), os.path.join(params["checkpoints_folder"], ckpt_model_name)) 96 | print(f"Loss per embedding element:{epoch_loss/num_embeddings}, MAPE: {MAPE(label, pred)}, time: {time.time() - start}") 97 | 98 | class AttentionEncoderDataset(torch.utils.data.Dataset): 99 | def __init__(self, input_path, output_path, mask_path, n, t = "max"): 100 | print(f"Starting to load datasets from {input_path} and {output_path} and {mask_path}") 101 | start = time.time() 102 | 103 | self.n = n 104 | if t != "max" and t != "exact": 105 | raise ValueError("ERROR: t has to be either 'max' or 'exact'.") 106 | self.t = t 107 | self.input = [] 108 | self.output = [] 109 | if t == "max": 110 | self.mask = [] 111 | mask_cache = f"{mask_path}_fixed_{n}_{t}.cache" 112 | 113 | in_cache = f"{input_path}_fixed_{n}_{t}.cache" 114 | out_cache = f"{output_path}_fixed_{n}_{t}.cache" 115 | 116 | if os.path.exists(in_cache) and os.path.exists(out_cache) and (t == "exact" or os.path.exists(mask_cache)): 117 | self.input = torch.load(in_cache) 118 | self.output = torch.load(out_cache) 119 | if t == "max": 120 | self.mask = torch.load(mask_cache) 121 | print(f"Finished loading mask dataset from cache {mask_cache}") 122 | print(f"Finished loading datasets from cache {in_cache} and {out_cache}") 123 | print(f"Loaded {len(self.output)} samples in {time.time() - start}s") 124 | return 125 | 126 | inf = open(input_path, "rb") 127 | outf = open(output_path, "rb") 128 | maskf = open(mask_path, "rb") 129 | try: 130 | while(True): 131 | # i represents one batch of sentences -> dim: batch size x padded sentence length x embedding size 132 | i = torch.from_numpy(np.load(inf)) 133 | m = torch.from_numpy(np.load(maskf)) 134 | m = torch.squeeze(m, dim=1) 135 | m = torch.squeeze(m, dim=1) 136 | o = torch.from_numpy(np.load(outf)) 137 | l = torch.sum(m, dim = 1) 138 | for j in range(i.shape[0]): 139 | if t == "max": 140 | if l[j] <= n: 141 | self.input.append(i[j, :l[j]]) 142 | self.output.append(o[j,:,:l[j]]) 143 | self.mask.append(m[j, :l[j]]) 144 | else: 145 | if l[j] == n: 146 | self.input.append(i[j, :l[j]]) 147 | self.output.append(o[j, :l[j]]) 148 | except (UnpicklingError, ValueError): 149 | print(f"Finished loading datasets from {input_path} and {output_path}") 150 | print(f"Loaded {len(self.output)} samples in {time.time() - start}s") 151 | finally: 152 | inf.close() 153 | outf.close() 154 | maskf.close() 155 | # self.input = torch.cat(self.input, dim=0) 156 | # self.output = torch.cat(self.output, dim=0) 157 | print(self.input[0].shape) 158 | torch.save(self.input, in_cache) 159 | torch.save(self.output, out_cache) 160 | if t == "max": 161 | # self.mask = torch.cat(self.mask, dim=0) 162 | torch.save(self.mask, mask_cache) 163 | 164 | def __len__(self): 165 | return len(self.input) 166 | 167 | def 
__getitem__(self, idx): 168 | # if we have exactly the same length, there is no need for padding/masking 169 | if self.t == "exact": 170 | return (self.input[idx], self.output[idx]) 171 | return (self.input[idx], self.output[idx], self.mask[idx]) 172 | 173 | def emb_size(self): 174 | return self.input.shape[1] 175 | 176 | class AttentionDecoderCADataset(torch.utils.data.Dataset): 177 | def __init__(self, in_enc_path, in_dec_path, out_path, src_mask_path, trg_mask_path, n, t = "max"): 178 | print(f"Starting to load datasets from {in_enc_path}, {in_dec_path}, {out_path}, {src_mask_path} and {trg_mask_path}") 179 | start = time.time() 180 | 181 | self.n = n 182 | if t != "max" and t != "exact": 183 | raise ValueError("ERROR: t has to be either 'max' or 'exact'.") 184 | self.t = t 185 | self.input_enc = [] 186 | self.input_dec = [] 187 | self.output = [] 188 | if t == "max": 189 | self.src_mask = [] 190 | self.trg_mask = [] 191 | src_mask_cache = f"{src_mask_path}_fixed_{n}_{t}.cache" 192 | trg_mask_cache = f"{trg_mask_path}_fixed_{n}_{t}.cache" 193 | 194 | in_enc_cache = f"{in_enc_path}_fixed_{n}_{t}.cache" 195 | in_dec_cache = f"{in_dec_path}_fixed_{n}_{t}.cache" 196 | out_cache = f"{out_path}_fixed_{n}_{t}.cache" 197 | 198 | if os.path.exists(in_enc_cache) and os.path.exists(in_dec_cache) and os.path.exists(out_cache) and (t == "exact" or (os.path.exists(src_mask_cache) and os.path.exists(trg_mask_cache))): 199 | self.input_enc = torch.load(in_enc_cache) 200 | self.input_dec = torch.load(in_dec_cache) 201 | self.output = torch.load(out_cache) 202 | if t == "max": 203 | self.src_mask = torch.load(src_mask_cache) 204 | self.trg_mask = torch.load(trg_mask_cache) 205 | print(f"Finished loading mask dataset from cache {src_mask_cache} and {trg_mask_cache}") 206 | print(f"Finished loading datasets from cache {in_enc_cache}, {in_dec_cache} and {out_cache}") 207 | print(f"Loaded {len(self.output)} samples in {time.time() - start}s") 208 | return 209 | 210 | inf_enc = open(in_enc_path, "rb") 211 | inf_dec = open(in_dec_path, "rb") 212 | outf = open(out_path, "rb") 213 | maskf_enc = open(src_mask_path, "rb") 214 | maskf_dec = open(trg_mask_path, "rb") 215 | 216 | i_enc_list = [] 217 | i_dec_list = [] 218 | o_list = [] 219 | m_enc_list = [] 220 | m_dec_list = [] 221 | l1_list = [] 222 | l2_list = [] 223 | 224 | try: 225 | while(True): 226 | # i represents one batch of sentences -> dim: batch size x padded sentence length x embedding size 227 | i_enc = torch.from_numpy(np.load(inf_enc)) 228 | i_dec = torch.from_numpy(np.load(inf_dec)) 229 | o = torch.from_numpy(np.load(outf)) 230 | 231 | m = torch.from_numpy(np.load(maskf_enc)) 232 | m = torch.squeeze(m, dim=1) 233 | m_enc = torch.squeeze(m, dim=1) 234 | 235 | m = torch.from_numpy(np.load(maskf_dec)) 236 | m = m[:,:,-1] 237 | m_dec = torch.squeeze(m, dim=1) 238 | 239 | l1 = torch.sum(m_enc, dim = 1) 240 | l2 = torch.sum(m_dec, dim = 1) 241 | 242 | i_enc_list.extend(list(i_enc)) 243 | i_dec_list.extend(list(i_dec)) 244 | o_list.extend(list(o)) 245 | m_enc_list.extend(list(m_enc)) 246 | m_dec_list.extend(list(m_dec)) 247 | l1_list.extend(list(l1)) 248 | l2_list.extend(list(l2)) 249 | 250 | 251 | except (UnpicklingError, ValueError): 252 | print(f"Finished loading datasets from {in_enc_path}, {in_dec_path} and {out_path}") 253 | print(f"Loaded {len(self.output)} samples in {time.time() - start}s") 254 | finally: 255 | inf_enc.close() 256 | inf_dec.close() 257 | outf.close() 258 | maskf_enc.close() 259 | maskf_dec.close() 260 | 261 | for j in 
range(len(i_enc_list)): 262 | if t == "max": 263 | if l1_list[j] <= n and l2_list[j] <= n: 264 | self.input_enc.append(i_enc_list[j][:l1_list[j]]) 265 | self.src_mask.append(m_enc_list[j][:l1_list[j]]) 266 | 267 | self.input_dec.append(i_dec_list[j][:l2_list[j]]) 268 | self.output.append(o_list[j][:,:l2_list[j]]) 269 | self.trg_mask.append(m_dec_list[j][:l2_list[j]]) 270 | 271 | print(f"Encoder input shape: {self.input_enc[0].shape}") 272 | print(f"Decoder input shape: {self.input_dec[0].shape}") 273 | torch.save(self.input_enc, in_enc_cache) 274 | torch.save(self.input_dec, in_dec_cache) 275 | torch.save(self.output, out_cache) 276 | if t == "max": 277 | torch.save(self.src_mask, src_mask_cache) 278 | torch.save(self.trg_mask, trg_mask_cache) 279 | 280 | def __len__(self): 281 | return len(self.input_enc) 282 | 283 | def __getitem__(self, idx): 284 | # if we have exactly the same length, there is no need for padding/masking 285 | if self.t == "exact": 286 | return (self.input_enc[idx],self.input_dec[idx], self.output[idx]) 287 | return (self.input_enc[idx],self.input_dec[idx], self.output[idx], self.src_mask[idx], self.trg_mask[idx]) 288 | 289 | def emb_size(self): 290 | return self.input.shape[1] 291 | 292 | class AttentionDecoderDataset(torch.utils.data.Dataset): 293 | def __init__(self, input_path, output_path, mask_path, n, t = "max"): 294 | print(f"Starting to load datasets from {input_path} and {output_path} and {mask_path}") 295 | start = time.time() 296 | 297 | self.n = n 298 | if t != "max" and t != "exact": 299 | raise ValueError("ERROR: t has to be either 'max' or 'exact'.") 300 | self.t = t 301 | self.input = [] 302 | self.output = [] 303 | if t == "max": 304 | self.mask = [] 305 | mask_cache = f"{mask_path}_fixed_{n}_{t}.cache" 306 | 307 | in_cache = f"{input_path}_fixed_{n}_{t}.cache" 308 | out_cache = f"{output_path}_fixed_{n}_{t}.cache" 309 | 310 | if os.path.exists(in_cache) and os.path.exists(out_cache) and (t == "exact" or os.path.exists(mask_cache)): 311 | self.input = torch.load(in_cache) 312 | self.output = torch.load(out_cache) 313 | if t == "max": 314 | self.mask = torch.load(mask_cache) 315 | print(f"Finished loading mask dataset from cache {mask_cache}") 316 | print(f"Finished loading datasets from cache {in_cache} and {out_cache}") 317 | print(f"Loaded {len(self.output)} samples in {time.time() - start}s") 318 | return 319 | 320 | inf = open(input_path, "rb") 321 | outf = open(output_path, "rb") 322 | maskf = open(mask_path, "rb") 323 | try: 324 | while(True): 325 | # i represents one batch of sentences -> dim: batch size x padded sentence length x embedding size 326 | i = torch.from_numpy(np.load(inf)) 327 | m = torch.from_numpy(np.load(maskf)) 328 | m = torch.squeeze(m, dim = 1) 329 | o = torch.from_numpy(np.load(outf)) 330 | l = torch.max(torch.sum(m, dim = -1), dim = -1).values 331 | for j in range(i.shape[0]): 332 | if t == "max": 333 | if l[j] <= n: 334 | self.input.append(i[j, :l[j]]) 335 | self.output.append(o[j,:,:l[j]]) 336 | self.mask.append(m[j, :l[j], :l[j]]) 337 | else: 338 | if l[j] == n: 339 | self.input.append(i[j, :n]) 340 | self.output.append(o[j, :n]) 341 | except (UnpicklingError, ValueError): 342 | print(f"Finished loading datasets from {input_path} and {output_path}") 343 | print(f"Loaded {len(self.output)} samples in {time.time() - start}s") 344 | finally: 345 | inf.close() 346 | outf.close() 347 | maskf.close() 348 | # self.input = torch.cat(self.input, dim=0) 349 | # self.output = torch.cat(self.output, dim=0) 350 | 
torch.save(self.input, in_cache) 351 | torch.save(self.output, out_cache) 352 | if t == "max": 353 | # self.mask = torch.cat(self.mask, dim=0) 354 | torch.save(self.mask, mask_cache) 355 | 356 | def __len__(self): 357 | return len(self.input) 358 | 359 | def __getitem__(self, idx): 360 | # if we have exactly the same length, there is no need for padding/masking 361 | if self.t == "exact": 362 | return (self.input[idx], self.output[idx]) 363 | return (self.input[idx], self.output[idx], self.mask[idx]) 364 | 365 | def emb_size(self): 366 | return self.input.shape[1] 367 | 368 | def collate_batch_decoder(batch): 369 | NH = batch[0][1].shape[0] 370 | HD = batch[0][1].shape[2] 371 | batch_size = len(batch) 372 | inputs = pad_sequence([x[0] for x in batch], batch_first=True, padding_value=0) 373 | outputs = pad_sequence([x[1].transpose(0,1).reshape(-1, NH * HD) for x in batch], batch_first=True, padding_value=0) # this reshaping must be transfered to the adapter as well 374 | trg_padding_mask = pad_sequence([x[2][-1] for x in batch], batch_first=True, padding_value=0) 375 | inputs = torch.cat([inputs, torch.zeros(pad_shape(inputs))], dim = 1).to(device) 376 | outputs = torch.cat([outputs, torch.zeros(pad_shape(outputs))], dim = 1).to(device) 377 | trg_padding_mask = torch.cat([trg_padding_mask, torch.zeros(pad_shape(trg_padding_mask, masks = True), dtype=torch.bool)], dim = 1).view(batch_size, 1, -1).to(device) 378 | 379 | # Pad to fixed length 380 | trg_no_look_forward_mask = torch.triu(torch.ones((1, MAX_LEN, MAX_LEN), device=device) == 1).transpose(1, 2) 381 | 382 | # logic AND operation (both padding mask and no-look-forward must be true to attend to a certain target token) 383 | trg_mask = trg_padding_mask & trg_no_look_forward_mask # final shape = (B, T, T) 384 | return inputs, outputs, trg_mask 385 | 386 | def pad_shape(batch, masks = False): 387 | shape = batch.shape 388 | if masks: 389 | return shape[0],MAX_LEN-shape[1] 390 | return shape[0], MAX_LEN-shape[1], shape[2] 391 | 392 | def collate_batch(batch): 393 | # print("COLLATE") 394 | # print(batch[0][0].shape) 395 | # print(batch[0][1].shape) 396 | # print(batch[0][2].shape) 397 | 398 | # Pad all elements to the same length 399 | NH = batch[0][1].shape[0] 400 | HD = batch[0][1].shape[2] 401 | inputs = pad_sequence([x[0] for x in batch], batch_first=True, padding_value=0) 402 | outputs = pad_sequence([x[1].transpose(0,1).reshape(-1, NH * HD) for x in batch], batch_first=True, padding_value=0) # this reshaping must be transfered to the adapter as well 403 | masks = pad_sequence([x[2] for x in batch], batch_first=True, padding_value=0) 404 | # print(inputs.shape) 405 | # print(outputs.shape) 406 | # print(masks.shape) 407 | 408 | # Pad to fixed length 409 | inputs = torch.cat([inputs, torch.zeros(pad_shape(inputs))], dim = 1).to(device) 410 | outputs = torch.cat([outputs, torch.zeros(pad_shape(outputs))], dim = 1).to(device) 411 | masks = torch.cat([masks, torch.zeros(pad_shape(masks, masks = True), dtype=torch.bool)], dim = 1).to(device) 412 | 413 | # Reshape concatenating the embeddings for each sentence 414 | masks = torch.repeat_interleave(masks, inputs.shape[-1] ,dim=1) 415 | inputs = torch.reshape(inputs, (inputs.shape[0],inputs.shape[1]*inputs.shape[2])) 416 | outputs = torch.reshape(outputs, (outputs.shape[0],outputs.shape[1]*outputs.shape[2])) 417 | return inputs, outputs, masks 418 | 419 | def collate_batch_decoder_ca(batch): 420 | # Pad all elements to the same length 421 | NH = batch[0][2].shape[0] 422 | HD = 
batch[0][2].shape[2] 423 | inputs_enc = pad_sequence([x[0] for x in batch], batch_first=True, padding_value=0) 424 | inputs_dec = pad_sequence([x[1] for x in batch], batch_first=True, padding_value=0) 425 | outputs = pad_sequence([x[2].transpose(0,1).reshape(-1, NH * HD) for x in batch], batch_first=True, padding_value=0) # this reshaping must be transferred to the adapter as well 426 | src_masks = pad_sequence([x[3] for x in batch], batch_first=True, padding_value=0) 427 | trg_masks = pad_sequence([x[4] for x in batch], batch_first=True, padding_value=0) 428 | # print(inputs.shape) 429 | # print(outputs.shape) 430 | # print(masks.shape) 431 | 432 | # Pad to fixed length 433 | inputs_enc = torch.cat([inputs_enc, torch.zeros(pad_shape(inputs_enc))], dim = 1).to(device) 434 | inputs_dec = torch.cat([inputs_dec, torch.zeros(pad_shape(inputs_dec))], dim = 1).to(device) 435 | outputs = torch.cat([outputs, torch.zeros(pad_shape(outputs))], dim = 1).to(device) 436 | src_masks = torch.cat([src_masks, torch.zeros(pad_shape(src_masks, masks = True), dtype=torch.bool)], dim = 1).to(device) 437 | trg_masks = torch.cat([trg_masks, torch.zeros(pad_shape(trg_masks, masks = True), dtype=torch.bool)], dim = 1).to(device) 438 | # Reshape concatenating the embeddings for each sentence 439 | src_masks = torch.repeat_interleave(src_masks, inputs_enc.shape[-1] ,dim=1) 440 | trg_masks = torch.repeat_interleave(trg_masks, inputs_dec.shape[-1] ,dim=1) 441 | inputs_enc = torch.reshape(inputs_enc, (inputs_enc.shape[0],inputs_enc.shape[1]*inputs_enc.shape[2])) 442 | inputs_dec = torch.reshape(inputs_dec, (inputs_dec.shape[0],inputs_dec.shape[1]*inputs_dec.shape[2])) 443 | inputs = torch.cat([inputs_enc, inputs_dec], dim = 1) 444 | outputs = torch.reshape(outputs, (outputs.shape[0],outputs.shape[1]*outputs.shape[2])) 445 | return inputs, outputs, trg_masks 446 | 447 | 448 | if __name__ == "__main__": 449 | parser = argparse.ArgumentParser() 450 | parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=21) 451 | parser.add_argument("--dataset_path", type=str, help='download dataset to this path', default=DATA_PATH) 452 | parser.add_argument("--model_dimension", type=int, help='embedding size', default=128) 453 | parser.add_argument("--batch_size", type=int, help='batch_size', default=2000) 454 | parser.add_argument("--multi_device", action = "store_true") 455 | 456 | # Params to set 457 | parser.add_argument("--num_of_curr_trained_layer", type=int, help='num_of_curr_trained_layer', default=0) 458 | parser.add_argument("--substitute_class", type = str, help="name of the FF substitute class to train, defined in models/definitions/ALR_FF.py", required=True) 459 | parser.add_argument("--att_replacement", help = "Which attention to replace", choices = ["encoder", "decoder", "decoder_ca"], default = "encoder") 460 | parser.add_argument("--language_direction", choices=[el.name for el in LanguageDirection], help='which direction to translate', default=LanguageDirection.de_en.name) 461 | args = parser.parse_args() 462 | # Wrapping training configuration into a dictionary 463 | training_config = dict() 464 | for arg in vars(args): 465 | training_config[arg] = getattr(args, arg) 466 | print("Training arguments parsed") 467 | training_config["checkpoints_folder"] = os.path.join(CHECKPOINTS_SCRATCH,"ALR", training_config["substitute_class"], f"layer{training_config['num_of_curr_trained_layer']}") 468 | os.makedirs(training_config["checkpoints_folder"], exist_ok = True) 469 | print(training_config["checkpoints_folder"])
470 | print(training_config) 471 | training_replacement_FF(training_config) 472 | --------------------------------------------------------------------------------
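A minimal usage sketch for the ALR training script above, run from the repository root; it assumes the extracted MHA outputs are already available under DATA_PATH, and FFNetwork_placeholder is a stand-in name that must be replaced by a class actually defined in models/definitions/ALR_FF.py:

python scripts/full_sentence/training_ALR.py --substitute_class FFNetwork_placeholder --att_replacement encoder --language_direction de_en --num_of_curr_trained_layer 0

Checkpoints are then written every 20 epochs under CHECKPOINTS_SCRATCH/ALR/<substitute_class>/layer<num_of_curr_trained_layer>, as set up in the __main__ block above.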