├── README.md
├── __init__.py
├── aggregate_hyperparameter_correlation.py
├── compute_robust_measures.py
├── compute_ww.py
├── create_experiment.py
├── download_data.sh
├── environment.yml
├── eval_bleu_loss.py
├── hyperparameter_correlation.py
├── metrics.py
├── models
│   └── definitions
│       └── transformer_model.py
├── robust_measures.py
├── scripts
│   ├── generate_script.ipynb
│   ├── hyperparameter_correlation.sh
│   ├── run_hyperparameter_correlation.sh
│   ├── run_plot_scatterplot.sh
│   ├── slurm_compute_ww.sh
│   ├── slurm_eval_bleu.sh
│   ├── slurm_robust_measures.sh
│   └── slurm_train_models.sh
├── time_wise_correlation.py
├── training_script.py
├── utils
│   ├── __init__.py
│   ├── constants.py
│   ├── data_utils.py
│   ├── decoding_utils.py
│   ├── optimizers_and_distributions.py
│   ├── resource_downloader.py
│   ├── utils.py
│   ├── utils_CKA.py
│   ├── utils_NMT.py
│   ├── utils_analyze_plots.py
│   ├── utils_huggingface.py
│   ├── utils_ww_results.py
│   └── visualization_utils.py
└── visualization
    ├── Best_ETPL_Lambda.png
    ├── Model_quality_vs_generalization_gap.png
    ├── TPL_vs_PL_mediocre.png
    ├── TPL_vs_PL_mediocre_1.pdf
    ├── Visualize_example_WW_layers.ipynb
    ├── calculate_rank_correlation_with_colored_groups.ipynb
    ├── reproduce_scatterplot.ipynb
    └── results
        ├── TPL_vs_PL_bad_evals.npy
        ├── TPL_vs_PL_good_evals.npy
        └── TPL_vs_PL_mediocre_evals.npy

/README.md:
--------------------------------------------------------------------------------
1 | # NLP metrics
2 | This repository contains the code to reproduce the results from the paper :link: [Evaluating natural language processing models with generalization metrics that do not need access to any training or testing data.](https://arxiv.org/pdf/2202.02842.pdf) Our main results are that metrics from the :link: [HT-SR theory](https://github.com/CalculatedContent/WeightWatcher) can predict the generalization of NLP models. Also, unlike existing generalization metrics that focus on the "generalization gap", metrics from the HT-SR theory can predict the quality of NLP models, e.g., measured by the test-time BLEU scores when the NLP task is neural machine translation.
3 | 
4 | We mainly study Transformers in this paper. For Transformer training, we follow :link: [Vaswani et al.](https://arxiv.org/abs/1706.03762). We develop our implementation based on an :link: [online repository](https://github.com/gordicaleksa/pytorch-original-transformer). This code reproduces the results from Vaswani et al. with more easily configurable Transformer architectures. In addition to the HT-SR theory, we also evaluate generalization metrics from :link: [Dziugaite et al. 2020](https://proceedings.neurips.cc/paper/2020/file/86d7c8a08b4aaa1bc7c599473f5dddda-Paper.pdf) and :link: [Jiang et al. 2019](https://arxiv.org/abs/1912.02178).
5 | 
6 | ## Set up the environment
7 | 
8 | Step 1. Create a conda environment.
9 | ```
10 | conda env create
11 | ```
12 | Activate the environment.
13 | ```
14 | conda activate NLP_metrics
15 | ```
16 | 
17 | Step 2. Download data and pretrained results.
18 | ```
19 | ./download_data.sh
20 | ```
21 | 
22 | ## Generate the experiment files. Change the checkpoint directory if necessary.
23 | ```
24 | python create_experiment.py --CKPT_DIR <path_to_save_checkpoints>
25 | ```
26 | For example, on my machine, the checkpoint directory is `/data/yyaoqing/Generalization_metrics_for_NLP/checkpoint/`.
27 | 
28 | ## Reproduce the figures shown in the paper
29 | 
30 | ### Result 1. Examples of PL fittings.
31 | 
32 | You can check the examples of PL and E-TPL fittings. Take a look at `visualization/Visualize_example_WW_layers.ipynb`.
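If you want a quick look at how these fits are produced without opening the notebook, here is a minimal sketch that mirrors the `weightwatcher` calls in `compute_ww.py`. The vocabulary sizes below are placeholders (the script reads the real ones from the data loaders), the model here is freshly initialized rather than a trained `net_epoch_*.ckpt`, and the `fit=` argument follows `compute_ww.py`, whose comment notes it is specific to the WeightWatcher version used in this project.

```
# Minimal sketch (not one of the repo scripts): fit layer-wise ESDs with weightwatcher,
# following the calls in compute_ww.py. Vocab sizes below are placeholders.
import weightwatcher as ww
from models.definitions.transformer_model import Transformer

model = Transformer(model_dimension=512, src_vocab_size=32000, trg_vocab_size=32000,
                    number_of_heads=8, number_of_layers=6, dropout_probability=0.1)

watcher = ww.WeightWatcher(model=model)
details = watcher.analyze(fit="PL")      # compute_ww.py also uses "E_TPL" and "EXP"
summary = watcher.get_summary(details)   # layer-averaged summary, as logged to wandb
print(details["alpha"].mean())           # layer-averaged PL exponent (PL_alpha)
```

`details` is a per-layer DataFrame; `hyperparameter_correlation.py` averages columns such as `alpha` over layers to obtain one metric value per model and epoch.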
33 | 
34 | *(figure: example PL and E-TPL fittings)*
35 | 
36 | ### Result 2. Scatter plots.
37 | 
38 | Then, you can reproduce the scatter plots that compare the generalization metrics with the BLEU scores. Check `visualization/reproduce_scatterplot.ipynb`.
39 | 
40 | ![Block](visualization/Best_ETPL_Lambda.png)
41 | 
42 | ### Result 3. Box plots.
43 | 
44 | You can also reproduce the box plots that rank the generalization metrics considered in the paper.
45 | 
46 | ![Block](visualization/Model_quality_vs_generalization_gap.png)
47 | 
48 | First, use the following commands to generate the time-wise correlations. The argument `--bleu_type` can be used to choose the correlation with the test BLEU scores or the generalization gap.
49 | ```
50 | python time_wise_correlation.py --bleu_type test
51 | python time_wise_correlation.py --bleu_type gap
52 | ```
53 | 
54 | Second, generate the correlation results when a single hyperparameter is varied.
55 | ```
56 | python aggregate_hyperparameter_correlation.py
57 | ```
58 | 
59 | Now, you should have all the results. Check `visualization/calculate_rank_correlation_with_colored_groups.ipynb` to see the box plots.
60 | 
61 | ## Reproduce all the training results.
62 | 
63 | Fully reproducing our results requires :link: [slurm](https://slurm.schedmd.com/) and about 6 TB of storage.
64 | 
65 | Step 1. Generate slurm configuration files. Check `scripts/generate_script.ipynb` to generate the training and evaluation slurm configurations.
66 | 
67 | Step 2. Submit the slurm files. Remember to change the directories in the slurm files and create a slurm log folder.
68 | ```
69 | mkdir slurm_logs
70 | ```
71 | 
72 | For training, do the following.
73 | ```
74 | sbatch ./scripts/slurm_train_models.sh
75 | ```
76 | For evaluation, use the following scripts.
77 | ```
78 | sbatch ./scripts/slurm_eval_bleu.sh
79 | sbatch ./scripts/slurm_compute_ww.sh
80 | sbatch ./scripts/slurm_robust_measures.sh
81 | ```
82 | Notice that we evaluate PL, E-TPL, and EXP fittings. To select the distribution, change L23-33 in the file `slurm_compute_ww.sh`.
83 | 
84 | Step 3. After generating all the evaluation files, you will have JSON and pickle files similar to those in `checkpoint.zip`. Then, you can draw the scatter plots and calculate the rank correlations using the following commands.
85 | ```
86 | ./scripts/run_plot_scatterplot.sh
87 | ./scripts/run_hyperparameter_correlation.sh
88 | ```
89 | After that, you will get all the plots and rank correlation results, similar to those in `plots.zip` and `results.zip`.
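Before running the plotting scripts, you can sanity-check the evaluation outputs of a single run by hand. The sketch below opens the files that `hyperparameter_correlation.py` expects inside one checkpoint directory; the directory name is just one example from the hyperparameter grid, and the same layout applies to the precomputed results in `checkpoint.zip`.

```
# Minimal sketch: inspect the per-epoch evaluation files of one run.
# The directory name is an example from the grid written by create_experiment.py.
import json, os, pickle

ckpt_dir = "checkpoint/IWSLT_sample200000_depth6_width512_lr0.25_dropout0.1"

# Per hyperparameter_correlation.py: results.pkl = E-TPL fits,
# results_original_alpha.pkl = PL fits, results_exponential.pkl = EXP fits.
# Structure: {epoch: {"details": DataFrame, "summary": dict}}
with open(os.path.join(ckpt_dir, "results.pkl"), "rb") as f:
    ww_results = pickle.load(f)
last_epoch = max(ww_results)
print(ww_results[last_epoch]["details"]["alpha"].mean())

# compute_robust_measures.py output: {epoch: {measure_name: value}}
with open(os.path.join(ckpt_dir, "robust_measures.pkl"), "rb") as f:
    robust = pickle.load(f)
print(sorted(robust[max(robust)].keys())[:5])

# eval_bleu_loss.py output: one JSON object per epoch.
with open(os.path.join(ckpt_dir, "bleu_loss.jsonl")) as f:
    for epoch, line in enumerate(f, start=1):
        d = json.loads(line)
        print(epoch, d[f"epoch{epoch}_id_bleu_score"])
```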
90 | 91 | ## Citation 92 | 93 | We appreciate it if you would please cite the following paper if you found the repository useful for your work: 94 | 95 | ``` 96 | @TECHREPORT{yang2022evaluating, 97 | author = {Yang, Yaoqing and Theisen, Ryan and Hodgkinson, Liam and Gonzalez, Joseph E and Ramchandran, Kannan and Martin, Charles H and Mahoney, Michael W}, 98 | title = {Evaluating natural language processing models with generalization metrics that do not need access to any training or testing data}, 99 | number = {Preprint: arXiv:2202.02842}, 100 | year = {2022}, 101 | } 102 | ``` 103 | 104 | License 105 | ---- 106 | 107 | MIT 108 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nsfzyzz/Generalization_metrics_for_NLP/e7a991f8baa15a1651a53016cff532d9d99256fc/__init__.py -------------------------------------------------------------------------------- /aggregate_hyperparameter_correlation.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pickle 4 | import numpy as np 5 | 6 | datasets=['WMT14', 'IWSLT', 'WMT14'] 7 | size_params = ['depth', 'depth', 'width'] 8 | adjust_measures_suffixs = ['normalized_by_samples', 'not_normalized_by_samples'] 9 | bleu_types = ['id_bleu', 'id_bleu_gap'] 10 | 11 | for bleu_type in bleu_types: 12 | for dataset, size_param in zip(datasets, size_params): 13 | metric_folders = glob.glob(f"results/{dataset}_Simpson/*/*") 14 | 15 | for adjust_measures_suffix in adjust_measures_suffixs: 16 | if size_param == 'depth': 17 | groups = ['sample', 'lr', 'depth'] 18 | elif size_param == 'width': 19 | groups = ['sample', 'lr', 'width'] 20 | for group in groups: 21 | 22 | correlation_file = f'plot_results_test_Simpson_{dataset}_size_param_{size_param}_individual_param_{group}_{bleu_type}_{adjust_measures_suffix}.pkl' 23 | print(f"Generating {correlation_file}.") 24 | 25 | results = {} 26 | 27 | for path in metric_folders: 28 | metric = os.path.basename(path) 29 | corr_result_file = f"corr_{bleu_type}_{dataset}_size_param_{size_param}_{adjust_measures_suffix}.pkl" 30 | results_this_metric = pickle.load(open(os.path.join(path, corr_result_file), 'rb')) 31 | 32 | if metric in ['alpha_weighted', 'log_alpha_norm'] and 'TPL' in path: 33 | continue 34 | 35 | results[metric] = results_this_metric[group] 36 | 37 | if bleu_type == 'id_bleu': 38 | #print('negate the results!') 39 | # Negate the correlation because we want the metrics to be negatively correlated 40 | results[metric] = [-x for x in results[metric] if not np.isnan(x)] 41 | #print(results[metric]) 42 | pickle.dump(results, open(f'results/{correlation_file}', 'wb')) -------------------------------------------------------------------------------- /compute_robust_measures.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | import pickle 5 | from models.definitions.transformer_model import Transformer 6 | from utils.data_utils import get_data_loaders 7 | from utils.constants import * 8 | import wandb 9 | from robust_measures import get_all_measures 10 | from utils.utils_CKA import * 11 | 12 | 13 | class fake_dataloader: 14 | def __init__(self, dataset): 15 | self.dataset = dataset 16 | 17 | 18 | def main(args): 19 | 20 | if not args.calculate_margin and not args.calculate_pac_bayes and not args.test_bleu: 21 | 
device = torch.device("cpu") 22 | else: 23 | device = torch.device("cuda") 24 | 25 | subsampling = args.num_samples!=0 26 | 27 | # Get Transformer model 28 | print("Load transformer model.") 29 | 30 | train_token_ids_loader, _, src_field_processor, trg_field_processor = get_data_loaders( 31 | './data', 32 | 'G2E', 33 | args.dataset, 34 | args.batch_size, 35 | device, 36 | subsampling=subsampling, 37 | num_samples=args.num_samples) 38 | 39 | if args.dataset=='IWSLT': 40 | dataset_len = 200000 41 | elif args.dataset == 'WMT': 42 | dataset_len = 4500000 43 | if args.num_samples!=0: 44 | dataset_len = args.num_samples 45 | fake_NMT_loader = fake_dataloader(dataset=[0]*dataset_len) 46 | 47 | pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN] # pad token id is the same for target as well 48 | src_vocab_size = len(src_field_processor.vocab) 49 | trg_vocab_size = len(trg_field_processor.vocab) 50 | 51 | # Load initialized model 52 | baseline_transformer_init = Transformer( 53 | model_dimension=args.width, 54 | src_vocab_size=src_vocab_size, 55 | trg_vocab_size=trg_vocab_size, 56 | number_of_heads=args.num_heads, 57 | number_of_layers=args.num_layers, 58 | dropout_probability=BASELINE_MODEL_DROPOUT_PROB 59 | ).to(device) 60 | 61 | ckpt_epoch = os.path.join(args.ckpt, f"net_epoch_{args.starting_epoch}.ckpt") 62 | ckpt = torch.load(ckpt_epoch, map_location='cpu') 63 | baseline_transformer_init.load_state_dict(ckpt["state_dict"]) 64 | baseline_transformer_init.eval() 65 | 66 | wandb.init(name = args.ckpt + '_eval_measure') 67 | 68 | final_evals = {} 69 | for epoch in range(1, 1+args.num_epochs): 70 | 71 | all_complexities = {} 72 | print(f"Loading the checkpoint for epoch {epoch}.") 73 | baseline_transformer = Transformer( 74 | model_dimension=args.width, 75 | src_vocab_size=src_vocab_size, 76 | trg_vocab_size=trg_vocab_size, 77 | number_of_heads=args.num_heads, 78 | number_of_layers=args.num_layers, 79 | dropout_probability=BASELINE_MODEL_DROPOUT_PROB 80 | ).to(device) 81 | 82 | baseline_transformer_path_norm = Transformer( 83 | model_dimension=args.width, 84 | src_vocab_size=src_vocab_size, 85 | trg_vocab_size=trg_vocab_size, 86 | number_of_heads=args.num_heads, 87 | number_of_layers=args.num_layers, 88 | dropout_probability=BASELINE_MODEL_DROPOUT_PROB, 89 | customize_layer_norm=True 90 | ) 91 | 92 | ckpt_epoch = os.path.join(args.ckpt, f"net_epoch_{epoch}.ckpt") 93 | ckpt = torch.load(ckpt_epoch, map_location='cpu') 94 | baseline_transformer.load_state_dict(ckpt["state_dict"]) 95 | baseline_transformer.eval() 96 | 97 | baseline_transformer_path_norm.load_state_dict(ckpt["state_dict"]) 98 | baseline_transformer_path_norm.eval() 99 | 100 | if args.calculate_margin: 101 | measure_loader = train_token_ids_loader 102 | else: 103 | measure_loader = fake_NMT_loader 104 | 105 | print("Start analysis on different types of measures.") 106 | 107 | all_complexities = get_all_measures(baseline_transformer, 108 | baseline_transformer_init, 109 | measure_loader, 110 | None, 111 | seed=2021, 112 | no_pac_bayes=not args.calculate_pac_bayes, 113 | no_margin=not args.calculate_margin, 114 | no_basics=False, 115 | no_path_norm=False, 116 | no_CKA=False, 117 | path_norm_transformer=baseline_transformer_path_norm, 118 | pad_token_id=pad_token_id, 119 | trg_vocab_size=trg_vocab_size, 120 | pacbayes_depth=8) 121 | final_evals[epoch] = all_complexities 122 | 123 | wandb.log(all_complexities) 124 | pickle.dump(final_evals, open( os.path.join(args.ckpt, args.result_suffix), "wb" ) ) 125 | 126 | print("Experiment 
finished. Save and exit.") 127 | 128 | if __name__ == "__main__": 129 | 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument("ckpt", type=str, help="path of checkpoint") 132 | parser.add_argument("--result_suffix", type=str, default='robust_measures.pkl', help="name of result") 133 | parser.add_argument('--starting-epoch', type=int, default=1) 134 | parser.add_argument('--num-epochs', type=int, default=20) 135 | parser.add_argument("--width", type=int, help="embedding dimension", default=64) 136 | parser.add_argument("--dataset", type=str, help="dataset", choices=['IWSLT', 'WMT'], default='IWSLT') 137 | parser.add_argument("--batch_size", type=int, help="batch size to create dataset", default=1500) 138 | parser.add_argument("--num-samples", type=int, help="number of samples", default=0) 139 | parser.add_argument("--calculate_margin", action='store_true') 140 | parser.add_argument("--calculate_pac_bayes", action='store_true') 141 | parser.add_argument("--num-layers", type=int, help="number of Transformer layers", default=6) 142 | parser.add_argument("--num-heads", type=int, help="number of Transformer heads", default=BASELINE_MODEL_NUMBER_OF_HEADS) 143 | 144 | args = parser.parse_args() 145 | 146 | print("Arguments for the experiment.") 147 | for arg in vars(args): 148 | print(arg, getattr(args, arg)) 149 | 150 | main(args) 151 | 152 | -------------------------------------------------------------------------------- /compute_ww.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import weightwatcher as ww 3 | import argparse 4 | import os 5 | import pickle 6 | from models.definitions.transformer_model import Transformer 7 | from utils.data_utils import get_data_loaders 8 | from utils.constants import * 9 | import wandb 10 | 11 | def main(args): 12 | 13 | device = torch.device("cpu") 14 | subsampling = args.num_samples!=0 15 | 16 | # Get Transformer model 17 | print("Load transformer model.") 18 | 19 | _, _, src_field_processor, trg_field_processor = get_data_loaders( 20 | './data', 21 | 'G2E', 22 | args.dataset, 23 | args.batch_size, 24 | device, 25 | subsampling=subsampling, 26 | num_samples=args.num_samples) 27 | 28 | #pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN] # pad token id is the same for target as well 29 | src_vocab_size = len(src_field_processor.vocab) 30 | trg_vocab_size = len(trg_field_processor.vocab) 31 | 32 | baseline_transformer = Transformer( 33 | model_dimension=args.width, 34 | src_vocab_size=src_vocab_size, 35 | trg_vocab_size=trg_vocab_size, 36 | number_of_heads=args.num_heads, 37 | number_of_layers=args.num_layers, 38 | dropout_probability=BASELINE_MODEL_DROPOUT_PROB 39 | ) 40 | 41 | # Compute metrics for all epochs 42 | ww_metrics = {} # key: epoch, value: results dict 43 | wandb.init(name = args.ckpt + '_ww') 44 | 45 | if args.distribution == 'truncated_power_law': 46 | distribution = 'E_TPL' 47 | elif args.distribution == 'power_law': 48 | distribution = 'PL' 49 | elif args.distribution == 'exponential': 50 | distribution = 'EXP' 51 | 52 | for epoch in range(args.starting_epoch, args.num_epochs+1): 53 | print(f"\nEPOCH {epoch}") 54 | ckpt = torch.load(os.path.join(args.ckpt,f"net_epoch_{epoch}.ckpt"), map_location='cpu') 55 | baseline_transformer.load_state_dict(ckpt["state_dict"]) 56 | 57 | print("Start weight watcher analysis.") 58 | 59 | watcher = ww.WeightWatcher(model=baseline_transformer) 60 | details = watcher.analyze( 61 | mp_fit=args.mp_fit, 62 | randomize=args.randomize, 63 | 
plot=args.save_plot, 64 | savefig=args.result, 65 | fit=distribution, # distribution is only for WeightWatcher2 66 | ) 67 | summary = watcher.get_summary(details) 68 | results = {'details':details, 'summary':summary} 69 | ww_metrics[epoch] = results 70 | 71 | wandb.log(summary) 72 | 73 | # Write all results into one file 74 | #with open(os.path.join(args.result, args.result_suffix), 'wb') as f: 75 | # pickle.dump(ww_metrics, f) 76 | 77 | print("Experiment finished. Save and exit.") 78 | 79 | if __name__ == "__main__": 80 | 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument("ckpt", type=str, help="path of checkpoint") 83 | parser.add_argument("result", type=str, help="path to save result") 84 | parser.add_argument("--result-suffix", type=str, default='results.pkl') 85 | parser.add_argument("--width", type=int, help="embedding dimension", default=64) 86 | parser.add_argument("--dataset", type=str, help="dataset", choices=['IWSLT', 'WMT'], default='WMT') 87 | parser.add_argument("--batch_size", type=int, help="batch size to create dataset", default=1500) 88 | parser.add_argument("--num-samples", type=int, help="number of samples", default=0) 89 | parser.add_argument("--save-plot", action='store_true', help="save plot of the weightwatcher results") 90 | parser.add_argument("--mp-fit", action='store_true', help="fitting the model using MP Fit.") 91 | parser.add_argument("--randomize", action='store_true', help="use randomized matrix to check correlation trap.") 92 | parser.add_argument("--distribution", choices=["truncated_power_law", "power_law", "lognormal", "exponential"]) 93 | parser.add_argument("--num-layers", type=int, help="number of Transformer layers", default=6) 94 | parser.add_argument("--num-epochs", type=int, help="number of epochs", default=20) 95 | parser.add_argument("--starting-epoch", type=int, help="The starting epoch number", default=1) 96 | parser.add_argument("--num-heads", type=int, help="number of Transformer heads", default=BASELINE_MODEL_NUMBER_OF_HEADS) 97 | 98 | #parser.add_argument("--negative-lambda", action='store_true', default=False) 99 | 100 | args = parser.parse_args() 101 | print(ww.__file__) 102 | 103 | print("Arguments for the experiment.") 104 | for arg in vars(args): 105 | print(arg, getattr(args, arg)) 106 | 107 | main(args) 108 | 109 | -------------------------------------------------------------------------------- /create_experiment.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is used to generate the directories of all experiments. 
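It writes two modules, experiments_hyperparameters.py and experiments_time_wise.py, each defining an
EXPERIMENTS dictionary that maps a grid name (e.g. "IWSLT_depth" or "WMT") to the list of checkpoint
directories belonging to that grid.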
3 | ''' 4 | 5 | import argparse 6 | 7 | WMT_sample_list = [160000, 320000, 640000, 1280000, 2560000] 8 | IWSLT_sample_list = [40000, 80000, 120000, 160000, 200000] 9 | depth_list = [4, 5, 6, 7, 8] 10 | lr_list = ["0.0625", "0.125", "0.25", "0.375", "0.5", "0.625", "0.75", "1.0"] 11 | width_list = [256, 384, 512, 768, 1024] 12 | head_list = [4, 6, 8, 12, 16] 13 | 14 | 15 | def generate_WMT_depth_experiments(fwrite): 16 | ## Generate sample x learning rate x depth grid on WMT 17 | for sample in WMT_sample_list: 18 | for depth in depth_list: 19 | for lr in lr_list: 20 | fwrite.write(" os.path.join(CKPT_DIR, " + f"\"WMT14_sample{sample}_depth{depth}_width512_lr{lr}_dropout0.1\"" + "),\n") 21 | 22 | 23 | def generate_IWSLT_depth_experiments(fwrite): 24 | ## Generate sample x learning rate x depth grid on IWSLT 25 | for sample in IWSLT_sample_list: 26 | for depth in depth_list: 27 | for lr in lr_list: 28 | fwrite.write(" os.path.join(CKPT_DIR, " + f"\"IWSLT_sample{sample}_depth{depth}_width512_lr{lr}_dropout0.1\"" + "),\n") 29 | 30 | 31 | def generate_WMT_width_experiments(fwrite): 32 | ## Generate sample x learning rate x width grid on WMT 33 | for sample in WMT_sample_list: 34 | for width, head in zip(width_list, head_list): 35 | for lr in lr_list: 36 | fwrite.write(" os.path.join(CKPT_DIR, " + f"\"WMT14_sample{sample}_depth6_width{width}_lr{lr}_dropout0.1\"" + "),\n") 37 | 38 | 39 | if __name__ == "__main__": 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--CKPT_DIR", type=str, default = '/data/yyaoqing/Generalization_metrics_for_NLP/checkpoint', 43 | help="path to save all the checkpoints") 44 | args = parser.parse_args() 45 | 46 | ## First, generate experiments for varying different hyperparameters 47 | starting_text = """""" 48 | starting_text += "import os\n\n" 49 | starting_text += f"CKPT_DIR = \"{args.CKPT_DIR}\"\n\n" 50 | starting_text += "EXPERIMENTS = {\n" 51 | 52 | with open('experiments_hyperparameters.py', 'w') as fwrite: 53 | fwrite.write(starting_text) 54 | fwrite.write(''' "IWSLT_depth": [\n''') 55 | generate_IWSLT_depth_experiments(fwrite) 56 | fwrite.write('''], 57 | 58 | "WMT14_depth": [\n''') 59 | generate_WMT_depth_experiments(fwrite) 60 | fwrite.write('''], 61 | 62 | "WMT14_width": [\n''') 63 | generate_WMT_width_experiments(fwrite) 64 | fwrite.write(''']}''') 65 | 66 | ## Then, generate experiments for time-wise correlations 67 | starting_text = """""" 68 | starting_text += "import os\n\n" 69 | starting_text += f"CKPT_DIR = \"{args.CKPT_DIR}\"\n\n" 70 | starting_text += "EXPERIMENTS = {\n" 71 | 72 | with open('experiments_time_wise.py', 'w') as fwrite: 73 | fwrite.write(starting_text) 74 | fwrite.write(''' "IWSLT": [\n''') 75 | generate_IWSLT_depth_experiments(fwrite) 76 | fwrite.write('''], 77 | 78 | "WMT": [\n''') 79 | generate_WMT_width_experiments(fwrite) 80 | generate_WMT_depth_experiments(fwrite) 81 | fwrite.write(''']}''') 82 | 83 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | ### Download the generalization metric results for the checkpoints ### 2 | curl -O https://zenodo.org/record/7136098/files/checkpoint.zip 3 | unzip checkpoint.zip -d checkpoint/ 4 | 5 | ### Download the rank correlation results for the checkpoints ### 6 | curl -O https://zenodo.org/record/7136714/files/results.zip 7 | unzip results.zip 8 | 9 | ### Download the scatter plots ### 10 | curl -O 
https://zenodo.org/record/7136704/files/plots.zip 11 | unzip plots.zip -d plots/ 12 | 13 | ### Download the processed data for Transformer training ### 14 | curl -O https://zenodo.org/record/7134119/files/data.zip 15 | unzip data.zip 16 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: NLP_metrics 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=10.2.89 8 | - matplotlib=3.1.3 9 | - pip=20.0.2 10 | - python=3.8.3 11 | - pytorch=1.5.0 12 | - torchtext=0.6.0 13 | - pip: 14 | - gitpython==3.1.2 15 | - gpustat==1.0.0 16 | - jupyter==1.0.0 17 | - numpy==1.20.3 18 | - pandas==1.5.0 19 | - powerlaw==1.5 20 | - tensorboard==2.10.1 21 | - nltk==3.5 22 | - seaborn==0.11.0 23 | - spacy==2.3.2 24 | - wandb==0.13.3 25 | - weightwatcher==0.5.6 26 | - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz#egg=en_core_web_sm 27 | - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-2.3.0/de_core_news_sm-2.3.0.tar.gz#egg=de_core_news_sm -------------------------------------------------------------------------------- /eval_bleu_loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Computes training and validation BLEU scores and losses. 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | import argparse 8 | import os, json, re 9 | 10 | from models.definitions.transformer_model import Transformer 11 | import utils.utils as utils 12 | from utils.constants import * 13 | from utils.data_utils import get_data_loaders, get_masks_and_count_tokens, get_src_and_trg_batches, DatasetType, LanguageDirection 14 | from utils.optimizers_and_distributions import LabelSmoothingDistribution 15 | 16 | def eval_loss(model, dataloader, max_batches): 17 | kl_div_loss = nn.KLDivLoss(reduction='batchmean') # gives better BLEU score than "mean" 18 | label_smoothing = LabelSmoothingDistribution(BASELINE_MODEL_LABEL_SMOOTHING_VALUE, pad_token_id, trg_vocab_size, DEVICE) 19 | 20 | total_loss, n_batches = 0, 0 21 | 22 | with torch.no_grad(): 23 | for batch_idx, token_ids_batch in enumerate(dataloader): 24 | src_token_ids_batch, trg_token_ids_batch_input, trg_token_ids_batch_gt = get_src_and_trg_batches(token_ids_batch) 25 | src_mask, trg_mask, num_src_tokens, num_trg_tokens = get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch_input, pad_token_id, DEVICE) 26 | 27 | # log because the KL loss expects log probabilities (just an implementation detail) 28 | predicted_log_distributions = model(src_token_ids_batch, trg_token_ids_batch_input, src_mask, trg_mask) 29 | smooth_target_distributions = label_smoothing(trg_token_ids_batch_gt) # these are regular probabilities 30 | 31 | loss = kl_div_loss(predicted_log_distributions, smooth_target_distributions) 32 | loss.item() 33 | total_loss += loss.item() 34 | n_batches += 1 35 | 36 | if batch_idx >= max_batches: 37 | break 38 | 39 | return total_loss / n_batches # average loss per batch 40 | 41 | def eval_bleu(model, dataloader, trg_field_processor, max_batches): 42 | bleu_score = utils.calculate_bleu_score( 43 | model, 44 | dataloader, 45 | trg_field_processor, 46 | max_batch=max_batches, 47 | ) 48 | return bleu_score 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--checkpoint_dir", type=str, default="") 53 | 
parser.add_argument("--max_batches", type=int, default=200) 54 | parser.add_argument("--seed", type=int, default=24) 55 | parser.add_argument("--num_epochs", type=int, default=20) 56 | parser.add_argument("--starting_epoch", type=int, default=1) 57 | parser.add_argument("--dataset", type=str, default='IWSLT') 58 | parser.add_argument("--num-heads", type=int, help="number of Transformer layers", default=BASELINE_MODEL_NUMBER_OF_HEADS) 59 | parser.add_argument("--embedding-dimension", type=int, help="the dimension to save a checkpoint", default=BASELINE_MODEL_DIMENSION) 60 | parser.add_argument("--efficient-eval", action='store_true', help="only evaluate the train BLEU to calculate the generalization gap") 61 | 62 | args = parser.parse_args() 63 | 64 | print(f"DATASET: {args.dataset}") 65 | 66 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 67 | 68 | ### Extract info from checkpoint ### 69 | NUM_SAMPLES = re.search("sample(\d+)", args.checkpoint_dir) 70 | NUM_SAMPLES = int(NUM_SAMPLES.group(1)) 71 | print(f"NUM_SAMPLES: {NUM_SAMPLES}") 72 | 73 | DEPTH = re.search("depth(\d+)", args.checkpoint_dir) 74 | DEPTH = int(DEPTH.group(1)) 75 | print(f"DEPTH: {DEPTH}") 76 | 77 | if args.dataset == 'WMT': 78 | id_data, ood_data = DatasetType.WMT14.name, DatasetType.IWSLT.name 79 | elif args.dataset == 'IWSLT': 80 | ood_data, id_data = DatasetType.WMT14.name, DatasetType.IWSLT.name 81 | else: 82 | raise NameError('Dataset not implemented yet.') 83 | 84 | ood_train_token_ids_loader, ood_val_token_ids_loader, _, _ = get_data_loaders( 85 | DATA_DIR_PATH, 86 | LanguageDirection.G2E.name, 87 | ood_data, 88 | 1500, 89 | DEVICE, 90 | subsampling=True, 91 | num_samples=NUM_SAMPLES, 92 | ood=True, 93 | ) 94 | 95 | id_train_token_ids_loader, id_val_token_ids_loader, src_field_processor, trg_field_processor = get_data_loaders( 96 | DATA_DIR_PATH, 97 | LanguageDirection.G2E.name, 98 | id_data, 99 | 1500, 100 | DEVICE, 101 | subsampling=True, 102 | num_samples=NUM_SAMPLES, 103 | ood=False, 104 | ) 105 | 106 | pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN] # pad token id is the same for target as well 107 | src_vocab_size = len(src_field_processor.vocab) 108 | trg_vocab_size = len(trg_field_processor.vocab) 109 | 110 | model = Transformer( 111 | model_dimension=args.embedding_dimension, 112 | src_vocab_size=src_vocab_size, 113 | trg_vocab_size=trg_vocab_size, 114 | number_of_heads=args.num_heads, 115 | number_of_layers=DEPTH, 116 | dropout_probability=0.0, 117 | ) 118 | ### 119 | 120 | output = "" 121 | 122 | # Compute metrics for epochs 1-20 123 | for EPOCH in range(args.starting_epoch, args.num_epochs+1): 124 | ckpt_file = os.path.join(args.checkpoint_dir, f"net_epoch_{EPOCH}.ckpt") 125 | model.load_state_dict( 126 | torch.load(ckpt_file, map_location=DEVICE)["state_dict"] 127 | ) 128 | model.to(DEVICE) 129 | model.eval() 130 | 131 | metrics = {} 132 | print(f"Getting metrics for epoch {EPOCH}...") 133 | # Compute training BLEU/loss 134 | metrics[f'epoch{EPOCH}_id_train_bleu_score'] = eval_bleu(model, id_train_token_ids_loader, trg_field_processor, args.max_batches) 135 | if not args.efficient_eval: 136 | metrics[f'epoch{EPOCH}_id_train_loss'] = eval_loss(model, id_train_token_ids_loader, args.max_batches) 137 | 138 | # Compute validation BLEU/loss 139 | metrics[f'epoch{EPOCH}_id_val_loss'] = eval_loss(model, id_val_token_ids_loader, args.max_batches) 140 | metrics[f'epoch{EPOCH}_id_bleu_score'] = eval_bleu(model, id_val_token_ids_loader, trg_field_processor, args.max_batches) 141 
| metrics[f'epoch{EPOCH}_ood_val_loss'] = eval_loss(model, ood_val_token_ids_loader, args.max_batches) 142 | metrics[f'epoch{EPOCH}_ood_bleu_score'] = eval_bleu(model, ood_val_token_ids_loader, trg_field_processor, args.max_batches) 143 | print(metrics) 144 | output += (json.dumps(metrics) + "\n") 145 | 146 | # Write entire output to file at once 147 | if args.efficient_eval: 148 | bleu_file_name = "bleu_loss_only_train.jsonl" 149 | else: 150 | bleu_file_name = "bleu_loss.jsonl" 151 | 152 | with open(os.path.join(args.checkpoint_dir, bleu_file_name), "w+") as file: 153 | file.write(output) -------------------------------------------------------------------------------- /hyperparameter_correlation.py: -------------------------------------------------------------------------------- 1 | from experiments_hyperparameters import EXPERIMENTS 2 | from metrics import METRIC_FILES 3 | import argparse, pickle, os, json, re 4 | import pandas as pd 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | import mpmath 9 | import numpy as np 10 | import pickle 11 | from scipy import odr 12 | 13 | def f(B, x): 14 | return B[0]*x + B[1] 15 | 16 | def plot_odr(x,y,ax,color=''): 17 | 18 | data = odr.Data(x, y) 19 | linear = odr.Model(f) 20 | 21 | # Using linear regression to find the starting point 22 | lr_reg = odr.ODR(data, linear, beta0=[1., 2.]) 23 | lr_reg.set_job(fit_type=2) 24 | lr_out = lr_reg.run() 25 | 26 | ordinal_distance_reg = odr.ODR(data, linear, beta0=lr_out.beta) 27 | ordinal_distance_reg.set_job(fit_type=0) 28 | out = ordinal_distance_reg.run() 29 | 30 | xx = np.linspace(min(x),max(x),100) 31 | #out.pprint() 32 | yy = out.beta[0]*xx + out.beta[1] 33 | # delete large y values 34 | valid_indices = (yy<=max(y)) & (yy>=min(y)) 35 | yy = yy[valid_indices] 36 | xx = xx[valid_indices] 37 | if color: 38 | ax.plot(xx,yy,color=color,linewidth=2) 39 | else: 40 | ax.plot(xx,yy,linewidth=2) 41 | 42 | 43 | def logdet_tpl_scalar(lam, beta): 44 | numer = mpmath.meijerg([[],[beta,beta]],[[0,-1+beta,-1+beta],[]],lam) 45 | denom = mpmath.expint(beta,lam) 46 | return float(numer / denom) 47 | 48 | logdet_tpl = np.vectorize(logdet_tpl_scalar) 49 | 50 | def adjust_measure(metric, val, dataset_size): 51 | 52 | if metric.startswith('LOG_'): 53 | return 2*val + np.log(dataset_size) 54 | #return 0.5 * (value - np.log(m)) 55 | elif 'CKA' in metric or 'TRUE_MARGIN' in metric: 56 | return val 57 | else: 58 | #print(val) 59 | #print(dataset_size) 60 | return (val**2)*dataset_size 61 | #return np.sqrt(value / m) 62 | 63 | # TODO: Move into a utils file 64 | def get_metric_bleu_df(experiment, distribution, adjust_measures_back, metric): 65 | ''' 66 | Constructs a DataFrame of length num_epochs. 67 | The columns are [epoch, id_bleu, ood_bleu, metric1, metric2, ...] 
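    Metric values are read from results.pkl / results_original_alpha.pkl / results_exponential.pkl /
    robust_measures.pkl in the experiment directory, and the BLEU scores and losses are read from
    bleu_loss.jsonl, one entry per epoch.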
68 | ''' 69 | print(experiment) 70 | 71 | ### Get metrics ### 72 | metric_vals = [] 73 | metric_file = METRIC_FILES[metric] 74 | 75 | # Special cases: PL vs TPL alpha 76 | if metric in ['PL_alpha', 'E_TPL_beta']: 77 | if metric == 'PL_alpha': 78 | FILE = os.path.join(experiment, "results_original_alpha.pkl") 79 | elif metric == 'E_TPL_beta': 80 | FILE = os.path.join(experiment, "results.pkl") 81 | with open(FILE, 'rb') as file: 82 | d = pickle.load(file) 83 | epochs = d.keys() 84 | for epoch in epochs: 85 | metric_vals.append(d[epoch]['details']['alpha'].mean()) # averaging over layers 86 | 87 | elif metric == 'E_TPL_lambda': 88 | FILE = os.path.join(experiment, "results.pkl") 89 | with open(FILE, 'rb') as file: 90 | d = pickle.load(file) 91 | epochs = d.keys() 92 | for epoch in epochs: 93 | metric_vals.append(d[epoch]['details']['exponent'].mean()) # averaging over layers 94 | 95 | elif metric == 'EXP_lambda': 96 | FILE = os.path.join(experiment, "results_exponential.pkl") 97 | with open(FILE, 'rb') as file: 98 | d = pickle.load(file) 99 | epochs = d.keys() 100 | for epoch in epochs: 101 | metric_vals.append(d[epoch]['details']['exponent'].mean()) # averaging over layers 102 | 103 | elif metric_file == 'ww': 104 | # Get from results.pkl 105 | if distribution == "power_law": 106 | FILE = os.path.join(experiment, "results_original_alpha.pkl") 107 | elif distribution == "truncated_power_law": 108 | FILE = os.path.join(experiment, "results.pkl") 109 | else: 110 | raise ValueError('Unknown distribution.') 111 | with open(FILE, 'rb') as file: 112 | d = pickle.load(file) 113 | epochs = d.keys() 114 | for epoch in epochs: 115 | # Special case for KS_distance 116 | if metric in ['PL_KS_distance', 'E_TPL_KS_distance']: 117 | metric_vals.append(d[epoch]['details']['D'].mean()) # averaging over layers 118 | elif metric in d[epoch]['details']: 119 | metric_vals.append(d[epoch]['details'][metric].mean()) # averaging over layers 120 | else: 121 | print(f"{FILE} missing {metric}") 122 | metric_vals.append(np.nan) 123 | 124 | elif metric_file == 'robust': 125 | # Get from robust_measures.pkl 126 | FILE = os.path.join(experiment, "robust_measures.pkl") 127 | with open(FILE, 'rb') as file: 128 | d = pickle.load(file) 129 | epochs = d.keys() 130 | for epoch in epochs: 131 | if metric in d[epoch]: 132 | _val = d[epoch][metric] 133 | if adjust_measures_back: 134 | # Reverse the effect of dataset_size 135 | dataset_size = int(re.search("sample(\d+)", experiment).group(1)) 136 | _val_adjusted = adjust_measure(metric, _val, dataset_size) 137 | metric_vals.append(_val_adjusted) 138 | else: 139 | metric_vals.append(_val) 140 | else: 141 | print(f"{FILE} missing {metric}") 142 | metric_vals.append(np.nan) 143 | 144 | metrics = {} # Key: metric name, Value: list of metric values (length num_epochs) 145 | metrics[metric] = metric_vals 146 | 147 | ### Get BLEU scores ### 148 | id_bleu_scores, ood_bleu_scores, id_bleu_gaps, id_bleu_train_scores, id_loss_gaps, id_train_losses, id_val_losses = [], [], [], [], [], [], [] 149 | 150 | EPOCH = 1 # Epochs are numbered 1-20 151 | FILE = os.path.join(experiment, "bleu_loss.jsonl") 152 | with open(FILE, "rb") as file: 153 | for line in file: 154 | d = json.loads(line) 155 | id_bleu_scores.append(d[f'epoch{EPOCH}_id_bleu_score'] * 100) 156 | ood_bleu_scores.append(d[f'epoch{EPOCH}_ood_bleu_score'] * 100) 157 | id_bleu_train_scores.append(d[f'epoch{EPOCH}_id_train_bleu_score'] * 100) 158 | id_bleu_gaps.append((d[f'epoch{EPOCH}_id_train_bleu_score'] - 
d[f'epoch{EPOCH}_id_bleu_score'])* 100) 159 | id_loss_gaps.append(d[f'epoch{EPOCH}_id_val_loss'] - d[f'epoch{EPOCH}_id_train_loss']) 160 | id_train_losses.append(d[f'epoch{EPOCH}_id_train_loss']) 161 | id_val_losses.append(d[f'epoch{EPOCH}_id_val_loss']) 162 | 163 | EPOCH += 1 164 | 165 | ### Construct the DataFrame ### 166 | data = { 167 | 'epoch': epochs, 'id_bleu': id_bleu_scores, 'ood_bleu': ood_bleu_scores, 168 | 'id_bleu_gap': id_bleu_gaps, 'id_bleu_train': id_bleu_train_scores, 'id_loss_gap': id_loss_gaps, 169 | 'id_loss_train': id_train_losses, 'id_loss_val': id_val_losses 170 | } 171 | data.update(metrics) 172 | 173 | try: 174 | df = pd.DataFrame(data=data) 175 | except ValueError: 176 | print('The dimension does not match! The experiment is') 177 | print(experiment) 178 | for key in data.keys(): 179 | print(key + " dimension is "+str(len(data[key]))) 180 | 181 | return df 182 | 183 | if __name__ == '__main__': 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument("--metric", type=str, default="") 186 | parser.add_argument("--bleu_type", type=str, choices=["id_bleu", "ood_bleu", "id_bleu_gap", "id_bleu_train", "id_loss_gap", "id_loss_train", "id_loss_val"]) 187 | parser.add_argument("--dataset", type=str, choices=["WMT14", "IWSLT"], default="WMT14") 188 | parser.add_argument("--group", type=str, default="sample", choices=["sample", "depth", "lr", "width"]) 189 | parser.add_argument("--fitting_method", type=str, default="LR", choices=["LR", "ODR"]) 190 | parser.add_argument("--distribution", type=str, default="power_law", choices=["power_law", "truncated_power_law", "exponential"]) 191 | parser.add_argument("--model_size_param", type=str, default="depth", choices=["width", "depth"]) 192 | parser.add_argument("--adjust_measures_back", dest='adjust_measures_back', action='store_true', help='adjust the measure back using the dataset size (default: off)') 193 | parser.add_argument("--calculate_or_plot", dest='calculate_or_plot', type=str, choices=["calculate", "plot", "both"]) 194 | 195 | args = parser.parse_args() 196 | assert args.metric in METRIC_FILES.keys() 197 | 198 | # Construct a DataFrame of length num_experiments 199 | # The columns are [id_bleu, ood_bleu, metric, sample, depth, lr, dropout] 200 | records = [] 201 | for experiment in EXPERIMENTS[f"{args.dataset}_{args.model_size_param}"]: 202 | metric_bleu_df = get_metric_bleu_df(experiment, args.distribution, args.adjust_measures_back, args.metric) 203 | # Get the last three epochs' BLEU/metric 204 | average_length = 6 205 | 206 | record = { 207 | 'id_bleu': sum([metric_bleu_df.iloc[-x]['id_bleu'] for x in range(1,1+average_length)])/average_length, 208 | 'id_bleu_train': sum([metric_bleu_df.iloc[-x]['id_bleu_train'] for x in range(1,1+average_length)])/average_length, 209 | 'id_bleu_gap': sum([metric_bleu_df.iloc[-x]['id_bleu_gap'] for x in range(1,1+average_length)])/average_length, 210 | 'id_loss_gap': sum([metric_bleu_df.iloc[-x]['id_loss_gap'] for x in range(1,1+average_length)])/average_length, 211 | 'id_loss_train': sum([metric_bleu_df.iloc[-x]['id_loss_train'] for x in range(1,1+average_length)])/average_length, 212 | 'id_loss_val': sum([metric_bleu_df.iloc[-x]['id_loss_val'] for x in range(1,1+average_length)])/average_length, 213 | 'ood_bleu': sum([metric_bleu_df.iloc[-x]['ood_bleu'] for x in range(1,1+average_length)])/average_length, 214 | f'{args.metric}': sum([metric_bleu_df.iloc[-x][f'{args.metric}'] for x in range(1,1+average_length)])/average_length, 215 | 'sample': int(re.search("sample(\d+)", 
experiment).group(1)), 216 | 'depth': int(re.search("depth(\d+)", experiment).group(1)), 217 | 'width': int(re.search("width(\d+)", experiment).group(1)), 218 | 'lr': float(re.search("lr([\d.]+)", experiment).group(1)), 219 | 'dropout': float(re.search("dropout([\d.]+)", experiment).group(1)), 220 | } 221 | records.append(record) 222 | 223 | df = pd.DataFrame.from_records(records) 224 | 225 | plot_metric_name = args.metric.lower() 226 | plot_bleu_type_name = args.bleu_type 227 | if plot_bleu_type_name == 'id_bleu': 228 | plot_bleu_type_name = 'BLEU score' 229 | 230 | plot_group_name = args.group 231 | if plot_group_name == 'lr': 232 | plot_group_name = 'Learning rate' 233 | if plot_group_name == 'sample': 234 | plot_group_name = 'Num samples' 235 | 236 | if args.calculate_or_plot in ["calculate", "both"]: 237 | 238 | ### Compute spearman's rank correlations ### 239 | if args.group == 'sample': 240 | SAVE_DIR_CORR = f"results/{args.dataset}_Simpson/{args.distribution}/{plot_metric_name}" 241 | if not os.path.exists(SAVE_DIR_CORR): 242 | os.makedirs(SAVE_DIR_CORR) 243 | 244 | if args.model_size_param == 'depth': 245 | rank_correlation_result = {'sample':[], 'depth':[], 'lr':[]} 246 | three_parameters_grids = [('sample', 'depth', 'lr'), ('depth', 'lr', 'sample'), ('lr', 'sample', 'depth')] 247 | elif args.model_size_param == 'width': 248 | rank_correlation_result = {'sample':[], 'width':[], 'lr':[]} 249 | three_parameters_grids = [('sample', 'width', 'lr'), ('width', 'lr', 'sample'), ('lr', 'sample', 'width')] 250 | 251 | for g0, g1, g2 in three_parameters_grids: 252 | for g1_value in df[g1].unique(): 253 | for g2_value in df[g2].unique(): 254 | one_slice = df.loc[ (df[g1] == g1_value) & (df[g2] == g2_value)][[args.bleu_type, args.metric]] 255 | corr = one_slice.corr(method='spearman').values[0][1] 256 | rank_correlation_result[g0].append(corr) 257 | 258 | if args.adjust_measures_back: 259 | adjust_measures_suffix = 'not_normalized_by_samples' 260 | else: 261 | adjust_measures_suffix = 'normalized_by_samples' 262 | 263 | pickle.dump(rank_correlation_result, open(os.path.join(SAVE_DIR_CORR, f'corr_{args.bleu_type}_{args.dataset}_size_param_{args.model_size_param}_{adjust_measures_suffix}.pkl'), 'wb')) 264 | 265 | if args.calculate_or_plot in ["plot", "both"]: 266 | 267 | ### Make scatterplots ### 268 | SAVE_DIR = f"plots/{args.dataset}_Simpson/{args.distribution}/{plot_metric_name}" 269 | if not os.path.exists(SAVE_DIR): 270 | os.makedirs(SAVE_DIR) 271 | 272 | # Regular scatterplot 273 | fig, ax = plt.subplots(figsize=(9,9)) 274 | lm = sns.lmplot( 275 | data=df, 276 | x=f'{args.metric}', 277 | y=f'{args.bleu_type}', 278 | hue=f'{args.group}', 279 | fit_reg=False, 280 | legend=False, 281 | ) 282 | ax = lm.axes[0, 0] 283 | ax.set_xlabel(plot_metric_name, fontsize=18) 284 | ax.set_ylabel(plot_bleu_type_name, fontsize=18) 285 | ax.set_title(f"{plot_metric_name} vs. 
{plot_bleu_type_name}", fontsize=18) 286 | legend = plt.legend(title=plot_group_name, bbox_to_anchor=(1.01, 1), loc='upper left', fontsize=14)#, labels=['Hell Yeh', 'Nah Bruh']) 287 | plt.setp(legend.get_title(),fontsize=14) 288 | plt.savefig( 289 | os.path.join(SAVE_DIR, f"{args.bleu_type}_{plot_metric_name}_{args.group}"), 290 | bbox_inches='tight', 291 | dpi=150, 292 | ) 293 | xmin,xmax,ymin,ymax = plt.axis() # save for making best within group plot 294 | 295 | # Simpson's scatterplot 296 | 297 | if args.fitting_method == 'ODR': 298 | fig, ax = plt.subplots(figsize=(6,6)) 299 | 300 | group_values = df[args.group].unique() 301 | group_values = sorted(group_values, key = float) 302 | for group_value in group_values: 303 | 304 | subgroup = df.loc[df[args.group] == group_value] 305 | y = subgroup[args.bleu_type].values 306 | x = subgroup[args.metric].values 307 | ax.scatter(x,y,s=35, label=group_value) 308 | plot_odr(x,y,ax) 309 | 310 | y = df[args.bleu_type].values 311 | x = df[args.metric].values 312 | plot_odr(x,y,ax,color='gray') 313 | 314 | xmin, xmax = df[args.metric].min(), df[args.metric].max() 315 | 316 | ax.set_xlabel(plot_metric_name, fontsize=18) 317 | ax.set_ylabel(plot_bleu_type_name, fontsize=18) 318 | #ax.set_ylim([-1,30]) 319 | ax.set_xlim([xmin-(xmax-xmin)*0.1,xmax+(xmax-xmin)*0.1]) 320 | ax.tick_params(axis='both', which='major', labelsize=14) 321 | ax.tick_params(axis='both', which='minor', labelsize=12) 322 | ax.spines['top'].set_visible(False) 323 | ax.spines['right'].set_visible(False) 324 | ax.set_title(f"{plot_metric_name} vs. {plot_bleu_type_name}", fontsize=18) 325 | legend = plt.legend(title=plot_group_name, bbox_to_anchor=(1.01, 1.0), loc='upper left', fontsize=14) 326 | plt.setp(legend.get_title(),fontsize=14) 327 | plt.savefig( 328 | os.path.join(SAVE_DIR, f"{args.bleu_type}_{plot_metric_name}_{args.group}_simpson_ODR.pdf"), 329 | bbox_inches='tight', 330 | #dpi=150, 331 | format='pdf', 332 | ) 333 | 334 | elif args.fitting_method == 'LR': 335 | fig, ax = plt.subplots(figsize=(9,9)) 336 | lm = sns.lmplot( 337 | data=df, 338 | x=f'{args.metric}', 339 | y=f'{args.bleu_type}', 340 | hue=f'{args.group}', 341 | fit_reg=True, 342 | ci=None, 343 | legend=False, 344 | ) 345 | ax = lm.axes[0, 0] 346 | sns.regplot( 347 | data=df, 348 | x=f'{args.metric}', 349 | y=f'{args.bleu_type}', 350 | scatter=False, 351 | fit_reg=True, 352 | ci=None, 353 | color='gray', 354 | ) 355 | ax.set_xlabel(plot_metric_name, fontsize=18) 356 | ax.set_ylabel(plot_bleu_type_name, fontsize=18) 357 | ax.tick_params(axis='both', which='major', labelsize=14) 358 | ax.tick_params(axis='both', which='minor', labelsize=12) 359 | ax.set_title(f"{plot_metric_name} vs. 
{plot_bleu_type_name}", fontsize=18) 360 | legend = plt.legend(title=plot_group_name, bbox_to_anchor=(1.01, 1), loc='upper left', fontsize=14) 361 | plt.setp(legend.get_title(),fontsize=14) 362 | plt.savefig( 363 | os.path.join(SAVE_DIR, f"{args.bleu_type}_{plot_metric_name}_{args.group}_simpson.pdf"), 364 | bbox_inches='tight', 365 | #dpi=150, 366 | format='pdf', 367 | ) 368 | 369 | # Only best performing in each group 370 | fig, ax = plt.subplots(figsize=(9,9)) 371 | lm = sns.lmplot( 372 | data=df.sort_values(by=f'{args.bleu_type}', ascending=False).groupby(f'{args.group}', as_index=False).first(), 373 | x=f'{args.metric}', 374 | y=f'{args.bleu_type}', 375 | hue=f'{args.group}', 376 | fit_reg=False, 377 | legend=False, 378 | ) 379 | ax = lm.axes[0, 0] 380 | ax.set_xlabel(plot_metric_name, fontsize=18) 381 | ax.set_ylabel(plot_bleu_type_name, fontsize=18) 382 | ax.set_xlim([xmin,xmax]) 383 | ax.set_ylim([ymin,ymax]) 384 | ax.set_title(f"{plot_metric_name} vs. {plot_bleu_type_name}\nbest performing model in each group", fontsize=18) 385 | legend = plt.legend(title=plot_group_name, bbox_to_anchor=(1.01, 1), loc='upper left', fontsize=14) 386 | plt.setp(legend.get_title(),fontsize=14) 387 | plt.savefig( 388 | os.path.join(SAVE_DIR, f"{args.bleu_type}_{plot_metric_name}_{args.group}_best"), 389 | bbox_inches='tight', 390 | dpi=150, 391 | ) 392 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | METRIC_FILES = { 2 | # robust 3 | 'L2': 'robust', 4 | 'L2_DIST': 'robust', 5 | 'PARAM_NORM': 'robust', 6 | 'FRO_DIST': 'robust', 7 | 'LOG_SUM_OF_FRO': 'robust', 8 | 'DIST_SPEC_INIT': 'robust', 9 | 'LOG_PROD_OF_FRO': 'robust', 10 | 'LOG_SUM_OF_SPEC': 'robust', 11 | 'LOG_PROD_OF_SPEC': 'robust', 12 | 'PATH_NORM': 'robust', 13 | 'INVERSE_MARGIN': 'robust', 14 | 'LOG_PROD_OF_SPEC_OVER_MARGIN': 'robust', 15 | 'LOG_SUM_OF_SPEC_OVER_MARGIN': 'robust', 16 | 'LOG_PROD_OF_FRO_OVER_MARGIN': 'robust', 17 | 'LOG_SUM_OF_FRO_OVER_MARGIN': 'robust', 18 | 'PATH_NORM_OVER_MARGIN': 'robust', 19 | 'PACBAYES_INIT': 'robust', 20 | 'PACBAYES_ORIG': 'robust', 21 | 'PACBAYES_FLATNESS': 'robust', 22 | 'PACBAYES_MAG_INIT': 'robust', 23 | 'PACBAYES_MAG_ORIG': 'robust', 24 | 'PACBAYES_MAG_FLATNESS': 'robust', 25 | 'W_CKA': 'robust', 26 | # 'FRO_OVER_SPEC': 'robust', # Repeated with Stable Rank 27 | 'LOG_SPEC_INIT_MAIN': 'robust', 28 | 'LOG_SPEC_ORIG_MAIN': 'robust', 29 | 30 | # ww 31 | 'log_norm': 'ww', 32 | 'log_spectral_norm': 'ww', 33 | 'mp_softrank': 'ww', 34 | 'stable_rank': 'ww', 35 | 'PL_alpha': 'ww', 36 | 'E_TPL_beta': 'ww', 37 | 'E_TPL_lambda': 'ww', 38 | 'EXP_lambda': 'ww', 39 | 'PL_KS_distance': 'ww', 40 | 'E_TPL_KS_distance': 'ww', 41 | 'tail_mean_vec_entropy': 'ww', 42 | 'bulk_mean_vec_entropy': 'ww', 43 | 'entropy': 'ww', 44 | 'rand_distance': 'ww', 45 | 'alpha_weighted': 'ww', 46 | 'log_alpha_norm': 'ww', 47 | #'logdet_tpl_per_layer': 'ww', # Testing this combined measure. 48 | #'exponent_adjusted': 'ww', # Testing this combined measure. 49 | 50 | # combined metrics calculated from existing ones 51 | #'logdet_tpl': 'combine' # Testing this combined measure. 52 | } -------------------------------------------------------------------------------- /models/definitions/transformer_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains the implementation of the original transformer paper "Attention is all you need". 
3 | 4 | Paper link: https://arxiv.org/pdf/1706.03762.pdf 5 | 6 | Certain modifications: 7 | 1. LayerNorm (before instead of after) 8 | 2. Dropout (Added additionally to attention weights and point-wise feed-forward net sublayer 9 | 10 | Suggested theory: https://jalammar.github.io/illustrated-transformer/ (amazing blog!) 11 | 12 | """ 13 | 14 | 15 | import math 16 | import copy 17 | 18 | 19 | import torch 20 | import torch.nn as nn 21 | 22 | 23 | from utils.constants import * 24 | 25 | 26 | class LayerNormalizationForPathNorm(nn.Module): 27 | 28 | def __init__(self, 29 | normal_shape, 30 | weight=True, 31 | bias=True, 32 | epsilon=1e-5): 33 | """Layer normalization layer for path normalization 34 | The normalization layer function is essentially a linear function 35 | We change the layernorm function for the purpose of calculating path norm 36 | See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf) 37 | :param normal_shape: The shape of the input tensor or the last dimension of the input tensor. 38 | :param weight: Add a scale parameter if it is True. 39 | :param bias: Add an offset parameter if it is True. 40 | :param epsilon: Epsilon for calculating variance. 41 | """ 42 | super(LayerNormalizationForPathNorm, self).__init__() 43 | if isinstance(normal_shape, int): 44 | normal_shape = (normal_shape,) 45 | else: 46 | normal_shape = (normal_shape[-1],) 47 | self.normal_shape = torch.Size(normal_shape) 48 | self.epsilon = epsilon 49 | if weight: 50 | self.weight = nn.Parameter(torch.Tensor(*normal_shape)) 51 | else: 52 | self.register_parameter('weight', None) 53 | if bias: 54 | self.bias = nn.Parameter(torch.Tensor(*normal_shape)) 55 | else: 56 | self.register_parameter('bias', None) 57 | self.reset_parameters() 58 | 59 | def reset_parameters(self): 60 | if self.weight is not None: 61 | self.weight.data.fill_(1) 62 | if self.bias is not None: 63 | self.bias.data.zero_() 64 | 65 | def forward(self, x): 66 | mean = x.mean(dim=-1, keepdim=True) 67 | var = ((x - mean) ** 2).mean(dim=-1, keepdim=True) 68 | #std = (var + self.epsilon).sqrt() 69 | # change this part to perform path-norm operations 70 | y = (x + mean**2) / (var + self.epsilon) 71 | # The following does not need to change 72 | if self.weight is not None: 73 | y *= self.weight 74 | if self.bias is not None: 75 | y += self.bias 76 | return y 77 | 78 | def extra_repr(self): 79 | return 'normal_shape={}, weight={}, bias={}, epsilon={}'.format( 80 | self.normal_shape, self.weight is not None, self.bias is not None, self.epsilon, 81 | ) 82 | 83 | 84 | class Transformer(nn.Module): 85 | 86 | def __init__(self, model_dimension, src_vocab_size, trg_vocab_size, number_of_heads, number_of_layers, dropout_probability, 87 | log_attention_weights=False, customize_layer_norm=False): 88 | super().__init__() 89 | 90 | # Embeds source/target token ids into embedding vectors 91 | self.src_embedding = Embedding(src_vocab_size, model_dimension) 92 | self.trg_embedding = Embedding(trg_vocab_size, model_dimension) 93 | 94 | # Adds positional information to source/target token's embedding vector 95 | # (otherwise we'd lose the positional information which is important in human languages) 96 | self.src_pos_embedding = PositionalEncoding(model_dimension, dropout_probability) 97 | self.trg_pos_embedding = PositionalEncoding(model_dimension, dropout_probability) 98 | 99 | # All of these will get deep-copied multiple times internally 100 | mha = MultiHeadedAttention(model_dimension, number_of_heads, dropout_probability, log_attention_weights) 101 | 
pwn = PositionwiseFeedForwardNet(model_dimension, dropout_probability) 102 | encoder_layer = EncoderLayer(model_dimension, dropout_probability, mha, pwn, customize_layer_norm) 103 | decoder_layer = DecoderLayer(model_dimension, dropout_probability, mha, pwn, customize_layer_norm) 104 | 105 | self.encoder = Encoder(encoder_layer, number_of_layers, customize_layer_norm) 106 | self.decoder = Decoder(decoder_layer, number_of_layers, customize_layer_norm) 107 | 108 | # Converts final target token representations into log probabilities vectors of the target vocab size 109 | self.decoder_generator = DecoderGenerator(model_dimension, trg_vocab_size, customize_layer_norm) 110 | self.customize_layer_norm = customize_layer_norm 111 | self.init_params() 112 | 113 | def init_params(self, default_initialization=False): 114 | # Not mentioned in the paper, but other implementations used xavier. 115 | # I tested both PyTorch's default initialization and this, and xavier has tremendous impact! I didn't expect 116 | # a model's perf, with normalization layers, to be so much dependent on the choice of weight initialization. 117 | if not default_initialization: 118 | for name, p in self.named_parameters(): 119 | if p.dim() > 1: 120 | nn.init.xavier_uniform_(p) 121 | 122 | def forward(self, src_token_ids_batch, trg_token_ids_batch, src_mask, trg_mask, no_softmax=False): 123 | src_representations_batch = self.encode(src_token_ids_batch, src_mask) 124 | trg_log_probs = self.decode(trg_token_ids_batch, src_representations_batch, trg_mask, src_mask, no_softmax=no_softmax) 125 | return trg_log_probs 126 | 127 | # Modularized into encode/decode functions for optimizing the decoding/translation process (see translation script) 128 | def encode(self, src_token_ids_batch, src_mask): 129 | src_embeddings_batch = self.src_embedding(src_token_ids_batch) # get embedding vectors for src token ids 130 | src_embeddings_batch = self.src_pos_embedding(src_embeddings_batch) # add positional embedding 131 | src_representations_batch = self.encoder(src_embeddings_batch, src_mask) # forward pass through the encoder 132 | 133 | return src_representations_batch 134 | 135 | def decode(self, trg_token_ids_batch, src_representations_batch, trg_mask, src_mask, no_softmax=False): 136 | trg_embeddings_batch = self.trg_embedding(trg_token_ids_batch) # get embedding vectors for trg token ids 137 | trg_embeddings_batch = self.trg_pos_embedding(trg_embeddings_batch) # add positional embedding 138 | # Shape (B, T, D), where B - batch size, T - longest target token-sequence length and D - model dimension 139 | trg_representations_batch = self.decoder(trg_embeddings_batch, src_representations_batch, trg_mask, src_mask) 140 | 141 | # After this line we'll have a shape (B, T, V), where V - target vocab size, decoder generator does a simple 142 | # linear projection followed by log softmax 143 | trg_log_probs = self.decoder_generator(trg_representations_batch, no_softmax=no_softmax) 144 | 145 | # Reshape into (B*T, V) as that's a suitable format for passing it into KL div loss 146 | trg_log_probs = trg_log_probs.reshape(-1, trg_log_probs.shape[-1]) 147 | 148 | return trg_log_probs # the reason I use log here is that PyTorch's nn.KLDivLoss expects log probabilities 149 | 150 | 151 | # 152 | # Encoder architecture 153 | # 154 | 155 | 156 | class Encoder(nn.Module): 157 | 158 | def __init__(self, encoder_layer, number_of_layers, customize_layer_norm): 159 | super().__init__() 160 | assert isinstance(encoder_layer, EncoderLayer), f'Expected EncoderLayer 
got {type(encoder_layer)}.' 161 | 162 | self.encoder_layers = get_clones(encoder_layer, number_of_layers) 163 | if not customize_layer_norm: 164 | self.norm = nn.LayerNorm(encoder_layer.model_dimension) 165 | else: 166 | self.norm = LayerNormalizationForPathNorm(encoder_layer.model_dimension) 167 | 168 | def forward(self, src_embeddings_batch, src_mask): 169 | # Just update the naming so as to reflect the semantics of what this var will become (the initial encoder layer 170 | # has embedding vectors as input but later layers have richer token representations) 171 | src_representations_batch = src_embeddings_batch 172 | 173 | # Forward pass through the encoder stack 174 | for encoder_layer in self.encoder_layers: 175 | # src_mask's role is to mask/ignore padded token representations in the multi-headed self-attention module 176 | src_representations_batch = encoder_layer(src_representations_batch, src_mask) 177 | 178 | # Not mentioned explicitly in the paper (a consequence of using LayerNorm before instead of after the sublayer 179 | # check out the SublayerLogic module) 180 | return self.norm(src_representations_batch) 181 | 182 | 183 | class EncoderLayer(nn.Module): 184 | 185 | def __init__(self, model_dimension, dropout_probability, multi_headed_attention, pointwise_net, customize_layer_norm): 186 | super().__init__() 187 | num_of_sublayers_encoder = 2 188 | self.sublayers = get_clones(SublayerLogic(model_dimension, dropout_probability, customize_layer_norm), num_of_sublayers_encoder) 189 | 190 | self.multi_headed_attention = multi_headed_attention 191 | self.pointwise_net = pointwise_net 192 | 193 | self.model_dimension = model_dimension 194 | 195 | def forward(self, src_representations_batch, src_mask): 196 | # Define anonymous (lambda) function which only takes src_representations_batch (srb) as input, 197 | # this way we have a uniform interface for the sublayer logic. 198 | encoder_self_attention = lambda srb: self.multi_headed_attention(query=srb, key=srb, value=srb, mask=src_mask) 199 | 200 | # Self-attention MHA sublayer followed by point-wise feed forward net sublayer 201 | src_representations_batch = self.sublayers[0](src_representations_batch, encoder_self_attention) 202 | src_representations_batch = self.sublayers[1](src_representations_batch, self.pointwise_net) 203 | 204 | return src_representations_batch 205 | 206 | 207 | # 208 | # Decoder architecture 209 | # 210 | 211 | 212 | class Decoder(nn.Module): 213 | 214 | def __init__(self, decoder_layer, number_of_layers, customize_layer_norm): 215 | super().__init__() 216 | assert isinstance(decoder_layer, DecoderLayer), f'Expected DecoderLayer got {type(decoder_layer)}.' 
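        # As in the Encoder, the decoder stack is built by deep-copying the prototype layer below;
        # customize_layer_norm swaps the final nn.LayerNorm for LayerNormalizationForPathNorm,
        # which is only needed when computing the path-norm measure (see compute_robust_measures.py).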
217 | 218 | self.decoder_layers = get_clones(decoder_layer, number_of_layers) 219 | if not customize_layer_norm: 220 | self.norm = nn.LayerNorm(decoder_layer.model_dimension) 221 | else: 222 | self.norm = LayerNormalizationForPathNorm(decoder_layer.model_dimension) 223 | 224 | 225 | def forward(self, trg_embeddings_batch, src_representations_batch, trg_mask, src_mask): 226 | # Just update the naming so as to reflect the semantics of what this var will become 227 | trg_representations_batch = trg_embeddings_batch 228 | 229 | # Forward pass through the decoder stack 230 | for decoder_layer in self.decoder_layers: 231 | # Target mask masks pad tokens as well as future tokens (current target token can't look forward) 232 | trg_representations_batch = decoder_layer(trg_representations_batch, src_representations_batch, trg_mask, src_mask) 233 | 234 | # Not mentioned explicitly in the paper (a consequence of using LayerNorm before instead of after the sublayer 235 | # check out the SublayerLogic module) 236 | return self.norm(trg_representations_batch) 237 | 238 | 239 | class DecoderLayer(nn.Module): 240 | 241 | def __init__(self, model_dimension, dropout_probability, multi_headed_attention, pointwise_net, customize_layer_norm): 242 | super().__init__() 243 | num_of_sublayers_decoder = 3 244 | self.sublayers = get_clones(SublayerLogic(model_dimension, dropout_probability, customize_layer_norm), num_of_sublayers_decoder) 245 | 246 | self.trg_multi_headed_attention = copy.deepcopy(multi_headed_attention) 247 | self.src_multi_headed_attention = copy.deepcopy(multi_headed_attention) 248 | self.pointwise_net = pointwise_net 249 | 250 | self.model_dimension = model_dimension 251 | 252 | def forward(self, trg_representations_batch, src_representations_batch, trg_mask, src_mask): 253 | # Define anonymous (lambda) function which only takes trg_representations_batch (trb - funny name I know) 254 | # as input - this way we have a uniform interface for the sublayer logic. 255 | # The inputs which are not passed into lambdas are "cached" here that's why the thing works. 256 | srb = src_representations_batch # simple/short alias 257 | decoder_trg_self_attention = lambda trb: self.trg_multi_headed_attention(query=trb, key=trb, value=trb, mask=trg_mask) 258 | decoder_src_attention = lambda trb: self.src_multi_headed_attention(query=trb, key=srb, value=srb, mask=src_mask) 259 | 260 | # Self-attention MHA sublayer followed by a source-attending MHA and point-wise feed forward net sublayer 261 | trg_representations_batch = self.sublayers[0](trg_representations_batch, decoder_trg_self_attention) 262 | trg_representations_batch = self.sublayers[1](trg_representations_batch, decoder_src_attention) 263 | trg_representations_batch = self.sublayers[2](trg_representations_batch, self.pointwise_net) 264 | 265 | return trg_representations_batch 266 | 267 | 268 | # 269 | # Helper modules (designed with modularity in mind) and organized top to bottom. 270 | # 271 | 272 | 273 | # Note: the original paper had LayerNorm AFTER the residual connection and addition operation 274 | # multiple experiments I found showed that it's more effective to do it BEFORE, how did they figure out which one is 275 | # better? Experiments! There is a similar thing in DCGAN and elsewhere. 
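# Illustrative sketch, added for clarity and not used anywhere in this repo: the post-LN ("Add & Norm")
# variant from the original paper differs from SublayerLogic below only in where the LayerNorm sits
# relative to the residual addition:
#   post-LN (original paper): out = norm(x + dropout(sublayer(x)))
#   pre-LN  (this repo):      out = x + dropout(sublayer(norm(x)))
# Because the pre-LN block never normalizes the residual stream itself, Encoder and Decoder above apply
# one final LayerNorm (their self.norm) after the last layer.
class _PostLNSublayerSketch(nn.Module):
    def __init__(self, model_dimension, dropout_probability):
        super().__init__()
        self.norm = nn.LayerNorm(model_dimension)
        self.dropout = nn.Dropout(p=dropout_probability)

    def forward(self, representations_batch, sublayer_module):
        # LayerNorm applied AFTER the residual addition, as in the original Transformer paper
        return self.norm(representations_batch + self.dropout(sublayer_module(representations_batch)))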
276 | class SublayerLogic(nn.Module): 277 | def __init__(self, model_dimension, dropout_probability, customize_layer_norm): 278 | super().__init__() 279 | if not customize_layer_norm: 280 | self.norm = nn.LayerNorm(model_dimension) 281 | else: 282 | self.norm = LayerNormalizationForPathNorm(model_dimension) 283 | self.dropout = nn.Dropout(p=dropout_probability) 284 | 285 | def forward(self, representations_batch, sublayer_module): 286 | # Residual connection between input and sublayer output, details: Page 7, Chapter 5.4 "Regularization", 287 | return representations_batch + self.dropout(sublayer_module(self.norm(representations_batch))) 288 | 289 | 290 | class DecoderGenerator(nn.Module): 291 | def __init__(self, model_dimension, vocab_size, customize_layer_norm): 292 | super().__init__() 293 | 294 | self.linear = nn.Linear(model_dimension, vocab_size) 295 | self.customize_layer_norm = customize_layer_norm 296 | 297 | # -1 stands for apply the log-softmax along the last dimension i.e. over the vocab dimension as the output from 298 | # the linear layer has shape (B, T, V), B - batch size, T - max target token-sequence, V - target vocab size 299 | # again using log softmax as PyTorch's nn.KLDivLoss expects log probabilities (just a technical detail) 300 | self.log_softmax = nn.LogSoftmax(dim=-1) 301 | 302 | def forward(self, trg_representations_batch, no_softmax=False): 303 | # Project from D (model dimension) into V (target vocab size) and apply the log softmax along V dimension 304 | if self.customize_layer_norm or no_softmax: 305 | return self.linear(trg_representations_batch) 306 | else: 307 | return self.log_softmax(self.linear(trg_representations_batch)) 308 | 309 | 310 | class PositionwiseFeedForwardNet(nn.Module): 311 | """ 312 | It's position-wise because this feed forward net will be independently applied to every token's representation. 313 | 314 | Representations batch is of the shape (batch size, max token sequence length, model dimension). 315 | This net will basically be applied independently to every token's representation (you can think of it as if 316 | there was a nested for-loop going over the batch size and max token sequence length dimensions 317 | and applied this net to token representations. PyTorch does this auto-magically behind the scenes. 318 | 319 | """ 320 | def __init__(self, model_dimension, dropout_probability, width_mult=4): 321 | super().__init__() 322 | 323 | self.linear1 = nn.Linear(model_dimension, width_mult * model_dimension) 324 | self.linear2 = nn.Linear(width_mult * model_dimension, model_dimension) 325 | 326 | # This dropout layer is not explicitly mentioned in the paper but it's common to use to avoid over-fitting 327 | self.dropout = nn.Dropout(p=dropout_probability) 328 | self.relu = nn.ReLU() 329 | 330 | def forward(self, representations_batch): 331 | return self.linear2(self.dropout(self.relu(self.linear1(representations_batch)))) 332 | 333 | 334 | class MultiHeadedAttention(nn.Module): 335 | """ 336 | This module already exists in PyTorch. The reason I implemented it here from scratch is that 337 | PyTorch implementation is super complicated as they made it as generic/robust as possible whereas 338 | on the other hand I only want to support a limited use-case. 339 | 340 | Also this is arguable the most important architectural component in the Transformer model. 341 | 342 | Additional note: 343 | This is conceptually super easy stuff. It's just that matrix implementation makes things a bit less intuitive. 
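    A concrete shape trace, added for clarity and assuming the baseline configuration used elsewhere in
    this repo (model_dimension=512, number_of_heads=8, hence head_dimension=64) and a batch of B source
    sentences of padded length S:

        query/key/value inputs:              (B, S, 512)
        after qkv_nets + view + transpose:   (B, 8, S, 64)
        scores = q @ k^T / sqrt(64):         (B, 8, S, S)
        attention_weights @ value:           (B, 8, S, 64)
        transpose + reshape:                 (B, S, 512)
        out_projection_net:                  (B, S, 512)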
344 | If you take your time and go through the code and figure out all of the dimensions + write stuff down on paper 345 | you'll understand everything. Also do check out this amazing blog for conceptual understanding: 346 | 347 | https://jalammar.github.io/illustrated-transformer/ 348 | 349 | Optimization notes: 350 | 351 | qkv_nets could be replaced by Parameter(torch.empty(3 * model_dimension, model_dimension)) and one more matrix 352 | for bias, which would make the implementation a bit more optimized. For the sake of easier understanding though, 353 | I'm doing it like this - using 3 "feed forward nets" (without activation/identity hence the quotation marks). 354 | Conceptually both implementations are the same. 355 | 356 | PyTorch's query/key/value are of different shape namely (max token sequence length, batch size, model dimension) 357 | whereas I'm using (batch size, max token sequence length, model dimension) because it's easier to understand 358 | and consistent with computer vision apps (batch dimension is always first followed by the number of channels (C) 359 | and image's spatial dimensions height (H) and width (W) -> (B, C, H, W). 360 | 361 | This has an important optimization implication, they can reshape their matrix into (B*NH, S/T, HD) 362 | (where B - batch size, S/T - max src/trg sequence length, NH - number of heads, HD - head dimension) 363 | in a single step and I can only get to (B, NH, S/T, HD) in single step 364 | (I could call contiguous() followed by view but that's expensive as it would incur additional matrix copy) 365 | 366 | """ 367 | 368 | def __init__(self, model_dimension, number_of_heads, dropout_probability, log_attention_weights): 369 | super().__init__() 370 | assert model_dimension % number_of_heads == 0, f'Model dimension must be divisible by the number of heads.' 371 | 372 | self.head_dimension = int(model_dimension / number_of_heads) 373 | self.number_of_heads = number_of_heads 374 | 375 | self.qkv_nets = get_clones(nn.Linear(model_dimension, model_dimension), 3) # identity activation hence "nets" 376 | self.out_projection_net = nn.Linear(model_dimension, model_dimension) 377 | 378 | self.attention_dropout = nn.Dropout(p=dropout_probability) # no pun intended, not explicitly mentioned in paper 379 | self.softmax = nn.Softmax(dim=-1) # -1 stands for apply the softmax along the last dimension 380 | 381 | self.log_attention_weights = log_attention_weights # should we log attention weights 382 | self.attention_weights = None # for visualization purposes, I cache the weights here (translation_script.py) 383 | 384 | def attention(self, query, key, value, mask): 385 | # Step 1: Scaled dot-product attention, Page 4, Chapter 3.2.1 "Scaled Dot-Product Attention" 386 | # Notation: B - batch size, S/T max src/trg token-sequence length, NH - number of heads, HD - head dimension 387 | # query/key/value shape = (B, NH, S/T, HD), scores shape = (B, NH, S, S), (B, NH, T, T) or (B, NH, T, S) 388 | # scores have different shapes as MHA is used in 3 contexts, self attention for src/trg and source attending MHA 389 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dimension) 390 | 391 | # Step 2: Optionally mask tokens whose representations we want to ignore by setting a big negative number 392 | # to locations corresponding to those tokens (force softmax to output 0 probability on those locations). 
# mask shape = (B, 1, 1, S) or (B, 1, T, T) will get broad-casted (copied) as needed to match scores shape 394 | if mask is not None: 395 | scores.masked_fill_(mask == torch.tensor(False), float("-inf")) 396 | 397 | # Step 3: Calculate the attention weights - how much should we attend to surrounding token representations 398 | attention_weights = self.softmax(scores) 399 | 400 | # Step 4: Not defined in the original paper; apply dropout to the attention weights as well 401 | attention_weights = self.attention_dropout(attention_weights) 402 | 403 | # Step 5: Based on the attention weights, calculate new token representations 404 | # attention_weights shape = (B, NH, S, S)/(B, NH, T, T) or (B, NH, T, S), value shape = (B, NH, S/T, HD) 405 | # Final shape (B, NH, S, HD) for source MHAs or (B, NH, T, HD) for target MHAs (again MHAs are used in 3 contexts) 406 | intermediate_token_representations = torch.matmul(attention_weights, value) 407 | 408 | return intermediate_token_representations, attention_weights # attention weights for visualization purposes 409 | 410 | def forward(self, query, key, value, mask): 411 | batch_size = query.shape[0] 412 | 413 | # Step 1: Input linear projection 414 | # Notation: B - batch size, NH - number of heads, S/T - max src/trg token-sequence length, HD - head dimension 415 | # Shape goes from (B, S/T, NH*HD) over (B, S/T, NH, HD) to (B, NH, S/T, HD) (NH*HD=D where D is model dimension) 416 | query, key, value = [net(x).view(batch_size, -1, self.number_of_heads, self.head_dimension).transpose(1, 2) 417 | for net, x in zip(self.qkv_nets, (query, key, value))] 418 | 419 | # Step 2: Apply attention - compare query with key and use that to combine values (see the function for details) 420 | intermediate_token_representations, attention_weights = self.attention(query, key, value, mask) 421 | 422 | # Potentially, for visualization purposes, log the attention weights - turn this off during training though! 423 | # I had memory problems when I left this on by default 424 | if self.log_attention_weights: 425 | self.attention_weights = attention_weights 426 | 427 | # Step 3: Reshape from (B, NH, S/T, HD) over (B, S/T, NH, HD) (via transpose) into (B, S/T, NHxHD) which is 428 | # the same shape as in the beginning of this forward function i.e.
input to MHA (multi-head attention) module 429 | reshaped = intermediate_token_representations.transpose(1, 2).reshape(batch_size, -1, self.number_of_heads * self.head_dimension) 430 | 431 | # Step 4: Output linear projection 432 | token_representations = self.out_projection_net(reshaped) 433 | 434 | return token_representations 435 | 436 | 437 | # 438 | # Input modules 439 | # 440 | 441 | 442 | class Embedding(nn.Module): 443 | 444 | def __init__(self, vocab_size, model_dimension): 445 | super().__init__() 446 | self.embeddings_table = nn.Embedding(vocab_size, model_dimension) 447 | self.model_dimension = model_dimension 448 | 449 | def forward(self, token_ids_batch): 450 | assert token_ids_batch.ndim == 2, f'Expected: (batch size, max token sequence length), got {token_ids_batch.shape}' 451 | 452 | # token_ids_batch has shape (B, S/T), where B - batch size, S/T max src/trg token-sequence length 453 | # Final shape will be (B, S/T, D) where D is the model dimension, every token id has associated vector 454 | embeddings = self.embeddings_table(token_ids_batch) 455 | 456 | # (stated in the paper) multiply the embedding weights by the square root of model dimension 457 | # Page 5, Chapter 3.4 "Embeddings and Softmax" 458 | return embeddings * math.sqrt(self.model_dimension) 459 | 460 | 461 | class PositionalEncoding(nn.Module): 462 | 463 | def __init__(self, model_dimension, dropout_probability, expected_max_sequence_length=5000): 464 | super().__init__() 465 | self.dropout = nn.Dropout(p=dropout_probability) 466 | 467 | # (stated in the paper) Use sine functions whose frequencies form a geometric progression as position encodings, 468 | # (learning encodings will also work so feel free to change it!). Page 6, Chapter 3.5 "Positional Encoding" 469 | position_id = torch.arange(0, expected_max_sequence_length).unsqueeze(1) 470 | frequencies = torch.pow(10000., -torch.arange(0, model_dimension, 2, dtype=torch.float) / model_dimension) 471 | 472 | # Checkout playground.py for visualization of how these look like (it's super simple don't get scared) 473 | positional_encodings_table = torch.zeros(expected_max_sequence_length, model_dimension) 474 | positional_encodings_table[:, 0::2] = torch.sin(position_id * frequencies) # sine on even positions 475 | positional_encodings_table[:, 1::2] = torch.cos(position_id * frequencies) # cosine on odd positions 476 | 477 | # Register buffer because we want to save the positional encodings table inside state_dict even though 478 | # these are not trainable (not model's parameters) so they otherwise would be excluded from the state_dict 479 | self.register_buffer('positional_encodings_table', positional_encodings_table) 480 | 481 | def forward(self, embeddings_batch): 482 | assert embeddings_batch.ndim == 3 and embeddings_batch.shape[-1] == self.positional_encodings_table.shape[1], \ 483 | f'Expected (batch size, max token sequence length, model dimension) got {embeddings_batch.shape}' 484 | 485 | # embedding_batch's shape = (B, S/T, D), where S/T max src/trg token-sequence length, D - model dimension 486 | # So here we get (S/T, D) shape which will get broad-casted to (B, S/T, D) when we try and add it to embeddings 487 | positional_encodings = self.positional_encodings_table[:embeddings_batch.shape[1]] 488 | 489 | # (stated in the paper) Applying dropout to the sum of positional encodings and token embeddings 490 | # Page 7, Chapter 5.4 "Regularization" 491 | return self.dropout(embeddings_batch + positional_encodings) 492 | 493 | 494 | # 495 | # Helper 
model functions 496 | # 497 | 498 | 499 | def get_clones(module, num_of_deep_copies): 500 | # Create deep copies so that we can tweak each module's weights independently 501 | return nn.ModuleList([copy.deepcopy(module) for _ in range(num_of_deep_copies)]) 502 | 503 | 504 | # Count how many trainable weights the model has <- just for having a feeling for how big the model is 505 | def count_parameters(model): 506 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 507 | 508 | 509 | def analyze_state_dict_shapes_and_names(model): 510 | # This part helped me figure out that I don't have positional encodings saved in the state dict 511 | print(model.state_dict().keys()) 512 | 513 | # This part helped me see that src MHA was missing in the decoder since both it and trg MHA were referencing 514 | # the same MHA object in memory - stupid mistake, happens all the time, embrace the suck! 515 | for name, param in model.named_parameters(): 516 | print(name, param.shape) 517 | if not param.requires_grad: 518 | raise Exception('Expected all of the params to be trainable - no param freezing used.') 519 | 520 | 521 | # Testing the correctness of the transformer model - feel free to ignore - I used it during model development 522 | if __name__ == "__main__": 523 | use_big_transformer = False 524 | 525 | # Dummy data 526 | src_vocab_size = 11 527 | trg_vocab_size = 11 528 | src_token_ids_batch = torch.randint(1, 10, size=(3, 2)) 529 | trg_token_ids_batch = torch.randint(1, 10, size=(3, 2)) 530 | 531 | transformer = Transformer( 532 | model_dimension=BIG_MODEL_DIMENSION if use_big_transformer else BASELINE_MODEL_DIMENSION, 533 | src_vocab_size=src_vocab_size, 534 | trg_vocab_size=trg_vocab_size, 535 | number_of_heads=BIG_MODEL_NUMBER_OF_HEADS if use_big_transformer else BASELINE_MODEL_NUMBER_OF_HEADS, 536 | number_of_layers=BIG_MODEL_NUMBER_OF_LAYERS if use_big_transformer else BASELINE_MODEL_NUMBER_OF_LAYERS, 537 | dropout_probability=BIG_MODEL_DROPOUT_PROB if use_big_transformer else BASELINE_MODEL_DROPOUT_PROB 538 | ) 539 | 540 | # These 2 functions helped me figure out the 2 bugs I had: 541 | # 1) I did not register positional encodings and thus they wouldn't be saved and later model-loading would fail 542 | # 2) I had a bug with MHA (attention) in decoder, where both src and trg were referencing the same MHA object in mem 543 | # It's a good practice to see whether the names, shapes and number of params make sense. 544 | # e.g. I knew that the big transformer had ~175 M params and I verified that here. 545 | analyze_state_dict_shapes_and_names(transformer) 546 | print(f'Size of the {"big" if use_big_transformer else "baseline"} transformer = {count_parameters(transformer)}') 547 | 548 | out = transformer(src_token_ids_batch, trg_token_ids_batch, src_mask=None, trg_mask=None) 549 | -------------------------------------------------------------------------------- /robust_measures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from "In search of robust measures of generalization" by Dziugaite et al. 
3 | https://github.com/nitarshan/robust-generalization-measures.git 4 | """ 5 | 6 | from contextlib import contextmanager 7 | from copy import deepcopy 8 | import math 9 | from typing import List, Tuple 10 | from enum import Enum 11 | import numpy as np 12 | import torch 13 | from torch import Tensor 14 | from torch.utils.data.dataloader import DataLoader 15 | from torch import Tensor 16 | import torch.nn as nn 17 | 18 | from utils.utils_CKA import * 19 | from utils.data_utils import get_masks_and_count_tokens, get_src_and_trg_batches 20 | from utils.optimizers_and_distributions import LabelSmoothingDistribution 21 | 22 | class CT(Enum): 23 | # Measures from Fantastic Generalization Measures (equation numbers) 24 | PARAMS = 20 25 | INVERSE_MARGIN = 22 26 | LOG_SPEC_INIT_MAIN = 29 27 | LOG_SPEC_ORIG_MAIN = 30 28 | LOG_PROD_OF_SPEC_OVER_MARGIN = 31 29 | LOG_PROD_OF_SPEC = 32 30 | FRO_OVER_SPEC = 33 31 | LOG_SUM_OF_SPEC_OVER_MARGIN = 34 32 | LOG_SUM_OF_SPEC = 35 33 | LOG_PROD_OF_FRO_OVER_MARGIN = 36 34 | LOG_PROD_OF_FRO = 37 35 | LOG_SUM_OF_FRO_OVER_MARGIN = 38 36 | LOG_SUM_OF_FRO = 39 37 | FRO_DIST = 40 38 | DIST_SPEC_INIT = 41 39 | PARAM_NORM = 42 40 | PATH_NORM_OVER_MARGIN = 43 41 | PATH_NORM = 44 42 | PACBAYES_INIT = 48 43 | PACBAYES_ORIG = 49 44 | PACBAYES_FLATNESS = 53 45 | PACBAYES_MAG_INIT = 56 46 | PACBAYES_MAG_ORIG = 57 47 | PACBAYES_MAG_FLATNESS = 61 48 | # Other Measures 49 | L2 = 100 50 | L2_DIST = 101 51 | # FFT Spectral Measures 52 | LOG_SPEC_INIT_MAIN_FFT = 129 53 | LOG_SPEC_ORIG_MAIN_FFT = 130 54 | LOG_PROD_OF_SPEC_OVER_MARGIN_FFT = 131 55 | LOG_PROD_OF_SPEC_FFT = 132 56 | FRO_OVER_SPEC_FFT = 133 57 | LOG_SUM_OF_SPEC_OVER_MARGIN_FFT = 134 58 | LOG_SUM_OF_SPEC_FFT = 135 59 | DIST_SPEC_INIT_FFT = 141 60 | 61 | @torch.no_grad() 62 | def eval_batch(batch, model, labels): 63 | 64 | input_ids = batch['input_ids'].cuda() 65 | attention_mask = batch['attention_mask'].cuda() 66 | outputs = model(input_ids, attention_mask=attention_mask, labels=labels) 67 | return outputs 68 | 69 | @torch.no_grad() 70 | def eval_acc(model, eval_loader): 71 | 72 | model.eval() 73 | 74 | num = 0 75 | correct = 0 76 | 77 | for batch in eval_loader: 78 | labels = batch['labels'].cuda() 79 | outputs = eval_batch(batch, model, labels) 80 | predictions = outputs.logits.argmax(dim=-1) 81 | num += len(labels) 82 | correct += (labels==predictions).sum().item() 83 | 84 | assert num>0 85 | acc = correct/num 86 | print(f"Evaluate accuracy = {acc}.") 87 | return acc 88 | 89 | @torch.no_grad() 90 | def eval_NMT_loss(model, dataloader, pad_token_id=None, trg_vocab_size=0, NMT_maximum_samples = 10000): 91 | ## This function is used to calculate the training loss of machine translation. 
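    ## Note (added for clarity): the loss below is nn.KLDivLoss evaluated against
    ## LabelSmoothingDistribution(0, ...). Assuming a smoothing value of 0 produces one-hot target
    ## distributions, the KL divergence reduces to the ordinary cross-entropy (negative log-likelihood)
    ## of the reference tokens, so the reported value can be read as a standard NMT training loss.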
92 | 93 | num_processed_samples = 0 94 | device = next(model.parameters()).device 95 | training_loss = 0 96 | loss_step = 0 97 | 98 | for _, token_ids_batch in enumerate(dataloader): 99 | 100 | src_token_ids_batch, trg_token_ids_batch_input, target = get_src_and_trg_batches(token_ids_batch) 101 | num_processed_samples += token_ids_batch.batch_size 102 | src_mask, trg_mask, _, _ = get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch_input, pad_token_id, device) 103 | logits = model(src_token_ids_batch, trg_token_ids_batch_input, src_mask, trg_mask) 104 | 105 | kl_div_loss = nn.KLDivLoss(reduction='batchmean') 106 | label_smoothing = LabelSmoothingDistribution(0, pad_token_id, trg_vocab_size, device) # Use label smoothing = 0 here 107 | smooth_target_distributions = label_smoothing(target) # these are regular probabilities 108 | loss = kl_div_loss(logits, smooth_target_distributions) 109 | training_loss += loss.item() 110 | loss_step += 1 111 | 112 | if num_processed_samples>=NMT_maximum_samples: 113 | break 114 | 115 | training_loss = training_loss/loss_step 116 | print(f"NMT training loss is {training_loss}") 117 | 118 | return training_loss 119 | 120 | # Adapted from https://github.com/bneyshabur/generalization-bounds/blob/master/measures.py 121 | @torch.no_grad() 122 | def _reparam(model): 123 | def in_place_reparam(model, prev_layer=None): 124 | for child in model.children(): 125 | prev_layer = in_place_reparam(child, prev_layer) 126 | if child._get_name() == 'Conv2d': 127 | prev_layer = child 128 | elif child._get_name() == 'BatchNorm2d': 129 | scale = child.weight / ((child.running_var + child.eps).sqrt()) 130 | prev_layer.bias.copy_( child.bias + ( scale * (prev_layer.bias - child.running_mean) ) ) 131 | perm = list(reversed(range(prev_layer.weight.dim()))) 132 | prev_layer.weight.copy_((prev_layer.weight.permute(perm) * scale ).permute(perm)) 133 | child.bias.fill_(0) 134 | child.weight.fill_(1) 135 | child.running_mean.fill_(0) 136 | child.running_var.fill_(1) 137 | return prev_layer 138 | model = deepcopy(model) 139 | in_place_reparam(model) 140 | return model 141 | 142 | 143 | @contextmanager 144 | def _perturbed_model( 145 | model, 146 | sigma, 147 | rng, 148 | magnitude_eps = None 149 | ): 150 | device = next(model.parameters()).device 151 | if magnitude_eps is not None: 152 | noise = [torch.normal(0,sigma**2 * torch.abs(p) ** 2 + magnitude_eps ** 2, generator=rng) for p in model.parameters()] 153 | else: 154 | noise = [torch.normal(0,sigma**2,p.shape, generator=rng).to(device) for p in model.parameters()] 155 | model = deepcopy(model) 156 | try: 157 | [p.add_(n) for p,n in zip(model.parameters(), noise)] 158 | yield model 159 | finally: 160 | [p.sub_(n) for p,n in zip(model.parameters(), noise)] 161 | del model 162 | 163 | 164 | # Adapted from https://drive.google.com/file/d/1_6oUG94d0C3x7x2Vd935a2QqY-OaAWAM/view 165 | def _pacbayes_sigma( 166 | model, 167 | dataloader, 168 | accuracy, 169 | seed, 170 | magnitude_eps = None, 171 | search_depth = 15, 172 | montecarlo_samples = 10, 173 | accuracy_displacement = 0.1, 174 | displacement_tolerance = 1e-2, 175 | task_type = 'normal', 176 | pad_token_id = None, 177 | trg_vocab_size = 0, 178 | search_upper_limit = 0.2 179 | ) -> float: 180 | 181 | if task_type == 'NMT' and magnitude_eps: 182 | # This is a tricky case. 
It seems that using search_upper_limit=0.2 is not large enough 183 | search_upper_limit = 2 184 | 185 | lower, upper = 0, search_upper_limit 186 | sigma = 0.1 187 | 188 | BIG_NUMBER = 10348628753 189 | device = next(model.parameters()).device 190 | rng = torch.Generator(device=device) if magnitude_eps is not None else torch.Generator() 191 | rng.manual_seed(BIG_NUMBER + seed) 192 | 193 | if not accuracy and task_type == 'NMT': 194 | # In this case, the training accuracy is hard to evaluate 195 | # So we use the training loss instead 196 | # It is training loss, but we still call it "accuracy" to follow the convention 197 | print("Evaluate training loss using the original model.") 198 | accuracy = eval_NMT_loss(model, dataloader, pad_token_id=pad_token_id, trg_vocab_size=trg_vocab_size) 199 | accuracy_displacement = 0.5 200 | displacement_tolerance = 0.05 201 | 202 | print(f"Start binary search for PAC-Bayes sigma.") 203 | for _ in range(search_depth): 204 | sigma = (lower + upper) / 2 205 | # If sigma > search_upper_limit - 0.01, most likely the search is stuck because upper limit is too small 206 | if sigma > search_upper_limit * 0.95: 207 | return search_upper_limit 208 | 209 | accuracy_samples = [] 210 | print(f"Getting samples for current sigma.") 211 | for _ in range(montecarlo_samples): 212 | print(f"current sigma is {sigma}") 213 | with _perturbed_model(model, sigma, rng, magnitude_eps) as p_model: 214 | # The following code is replaced with a method of evaluating accuracy 215 | #loss_estimate = 0 216 | #for data, target in dataloader: 217 | # logits = p_model(data) 218 | # pred = logits.data.max(1, keepdim=True)[1] # get the index of the max logits 219 | # batch_correct = pred.eq(target.data.view_as(pred)).type(torch.FloatTensor).cpu() 220 | # loss_estimate += batch_correct.sum() 221 | #loss_estimate /= len(dataloader.dataset) 222 | if task_type == 'NMT': 223 | loss_estimate = eval_NMT_loss(p_model, dataloader, pad_token_id=pad_token_id, trg_vocab_size=trg_vocab_size) 224 | else: 225 | loss_estimate = eval_acc(p_model, dataloader) 226 | accuracy_samples.append(loss_estimate) 227 | displacement = abs(np.mean(accuracy_samples) - accuracy) 228 | if abs(displacement - accuracy_displacement) < displacement_tolerance: 229 | break 230 | elif displacement > accuracy_displacement: 231 | # Too much perturbation 232 | upper = sigma 233 | else: 234 | # Not perturbed enough to reach target displacement 235 | lower = sigma 236 | return sigma 237 | 238 | def W_CKA(p,q, feature_space=True): 239 | 240 | eps=1e-15 241 | p = p.data.numpy() 242 | q = q.data.numpy() 243 | if np.sqrt(np.sum((p-q)**2)) < eps: 244 | return 1.0 245 | if feature_space: 246 | return feature_space_linear_cka(p, q) 247 | else: 248 | return cka_compute(gram_linear(p, q)) 249 | 250 | @torch.no_grad() 251 | def get_all_measures( 252 | model, 253 | init_model, 254 | dataloader, 255 | acc, 256 | seed, 257 | no_path_norm=True, 258 | no_exact_spectral_norm=True, 259 | no_pac_bayes=False, 260 | no_margin=False, 261 | no_basics=False, 262 | no_CKA=True, 263 | task_type='NMT', 264 | path_norm_transformer=None, 265 | pad_token_id=None, 266 | trg_vocab_size=0, 267 | pacbayes_depth=15 268 | ): 269 | measures = {} 270 | 271 | model = _reparam(model) 272 | init_model = _reparam(init_model) 273 | 274 | device = next(model.parameters()).device 275 | m = len(dataloader.dataset) 276 | 277 | def get_weights_only(model): 278 | blacklist = {'bias', 'bn'} 279 | return [p for name, p in model.named_parameters() if all(x not in name for x in 
blacklist)] 280 | 281 | weights = get_weights_only(model) 282 | init_weights = get_weights_only(init_model) 283 | weights_cpu = [p.to("cpu") for p in weights] 284 | init_weights_cpu = [p.to("cpu") for p in init_weights] 285 | dist_init_weights = [p-q for p,q in zip(weights, init_weights)] 286 | d = len(weights) 287 | 288 | def get_vec_params(weights: List[Tensor]) -> Tensor: 289 | return torch.cat([p.view(-1) for p in weights], dim=0) 290 | 291 | w_vec = get_vec_params(weights) 292 | dist_w_vec = get_vec_params(dist_init_weights) 293 | num_params = len(w_vec) 294 | 295 | if not no_CKA: 296 | measures["W_CKA"] = np.mean([W_CKA(p,q, feature_space=True) for p,q in zip(weights_cpu, init_weights_cpu) if len(p.shape)>1]) 297 | 298 | def get_reshaped_weights(weights: List[Tensor]) -> List[Tensor]: 299 | # If the weight is a tensor (e.g. a 4D Conv2d weight), it will be reshaped to a 2D matrix 300 | return [p.view(p.shape[0],-1) for p in weights] 301 | 302 | reshaped_weights = get_reshaped_weights(weights) 303 | dist_reshaped_weights = get_reshaped_weights(dist_init_weights) 304 | 305 | if not no_basics: 306 | print("Vector Norm Measures") 307 | measures["L2"] = w_vec.norm(p=2) 308 | measures["L2_DIST"] = dist_w_vec.norm(p=2) 309 | 310 | print("VC-Dimension Based Measures") 311 | measures["PARAMS"] = torch.tensor(num_params) # 20 312 | 313 | if not no_margin: 314 | print("Measures on the output of the network") 315 | def _calculate_margin( 316 | logits, 317 | target 318 | ): 319 | correct_logit = logits[torch.arange(logits.shape[0]), target].clone() 320 | logits[torch.arange(logits.shape[0]), target] = float('-inf') 321 | max_other_logit = logits.data.max(1).values # get the index of the max logits 322 | margin = correct_logit - max_other_logit 323 | return margin 324 | 325 | @torch.no_grad() 326 | def _margin( 327 | model, 328 | dataloader, 329 | task_type='normal', 330 | pad_token_id=None, 331 | NMT_maximum_samples = 10000, 332 | ) -> Tensor: 333 | margins = [] 334 | if task_type=='NMT': 335 | num_processed_samples = 0 336 | for batch_id, token_ids_batch in enumerate(dataloader): 337 | src_token_ids_batch, trg_token_ids_batch_input, target = get_src_and_trg_batches(token_ids_batch) 338 | num_processed_samples += token_ids_batch.batch_size 339 | src_mask, trg_mask, num_src_tokens, num_trg_tokens = get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch_input, pad_token_id, device) 340 | logits = model(src_token_ids_batch, trg_token_ids_batch_input, src_mask, trg_mask, no_softmax=True) # do not use softmax 341 | margins.append(_calculate_margin(logits.clone(),target.flatten())) 342 | 343 | if num_processed_samples >= NMT_maximum_samples: 344 | print(f"There are {num_processed_samples} sentences processed when calculating the margin.") 345 | break 346 | 347 | margin_distribution = torch.cat(margins) 348 | return margin_distribution.kthvalue(len(margin_distribution) // 10)[0] 349 | 350 | else: 351 | for batch in dataloader: 352 | target = batch['labels'].cuda() 353 | outputs = eval_batch(batch, model, target) 354 | logits = outputs.logits 355 | margins.append(_calculate_margin(logits,target)) 356 | 357 | return torch.cat(margins).kthvalue(m // 10)[0] 358 | 359 | true_margin = _margin(model, dataloader, task_type, pad_token_id) 360 | measures["TRUE_MARGIN"] = true_margin # Only used for checking if the true margin could become negative 361 | margin = true_margin.abs() 362 | measures["INVERSE_MARGIN"] = torch.tensor(1, device=device) / margin ** 2 # 22 363 | 364 | if not no_basics: 365 | 
print("(Norm & Margin)-Based Measures") 366 | fro_norms = torch.cat([p.norm('fro').unsqueeze(0) ** 2 for p in reshaped_weights]) 367 | print("Starting SVD calculations which may occupy large memory.") 368 | spec_norms = torch.cat([p.svd().S.max().unsqueeze(0) ** 2 for p in reshaped_weights]) 369 | print("End SVD calculations.") 370 | dist_fro_norms = torch.cat([p.norm('fro').unsqueeze(0) ** 2 for p in dist_reshaped_weights]) 371 | dist_spec_norms = torch.cat([p.svd().S.max().unsqueeze(0) ** 2 for p in dist_reshaped_weights]) 372 | 373 | print("Approximate Spectral Norm") 374 | # Note that these use an approximation from [Yoshida and Miyato, 2017] 375 | # https://arxiv.org/abs/1705.10941 (Section 3.2, Convolutions) 376 | measures["LOG_PROD_OF_SPEC"] = spec_norms.log().sum() # 32 377 | measures["FRO_OVER_SPEC"] = (fro_norms / spec_norms).sum() # 33 378 | measures["LOG_SUM_OF_SPEC"] = math.log(d) + (1/d) * measures["LOG_PROD_OF_SPEC"] # 35 379 | 380 | if not no_margin: 381 | measures["LOG_PROD_OF_SPEC_OVER_MARGIN"] = measures["LOG_PROD_OF_SPEC"] - 2 * margin.log() # 31 382 | measures["LOG_SPEC_INIT_MAIN"] = measures["LOG_PROD_OF_SPEC_OVER_MARGIN"] + (dist_fro_norms / spec_norms).sum().log() # 29 383 | measures["LOG_SPEC_ORIG_MAIN"] = measures["LOG_PROD_OF_SPEC_OVER_MARGIN"] + measures["FRO_OVER_SPEC"].log() # 30 384 | measures["LOG_SUM_OF_SPEC_OVER_MARGIN"] = math.log(d) + (1/d) * (measures["LOG_PROD_OF_SPEC"] - 2 * margin.log()) # 34 385 | 386 | if not no_basics: 387 | print("Frobenius Norm") 388 | measures["LOG_PROD_OF_FRO"] = fro_norms.log().sum() # 37 389 | measures["LOG_SUM_OF_FRO"] = math.log(d) + (1/d) * measures["LOG_PROD_OF_FRO"] # 39 390 | if not no_margin: 391 | measures["LOG_PROD_OF_FRO_OVER_MARGIN"] = measures["LOG_PROD_OF_FRO"] - 2 * margin.log() # 36 392 | measures["LOG_SUM_OF_FRO_OVER_MARGIN"] = math.log(d) + (1/d) * (measures["LOG_PROD_OF_FRO"] - 2 * margin.log()) # 38 393 | 394 | print("Distance to Initialization") 395 | measures["FRO_DIST"] = dist_fro_norms.sum() # 40 396 | measures["DIST_SPEC_INIT"] = dist_spec_norms.sum() # 41 397 | measures["PARAM_NORM"] = fro_norms.sum() # 42 398 | 399 | if not no_path_norm: 400 | print("Path-norm") 401 | # Adapted from https://github.com/bneyshabur/generalization-bounds/blob/master/measures.py#L98 402 | def _path_norm(model): 403 | model = deepcopy(model) 404 | model.eval() 405 | for param in model.parameters(): 406 | if param.requires_grad: 407 | param.data.pow_(2) 408 | # path norm requires all 1 input 409 | # we construct the all 1 input using length-1 sequence 410 | model.src_embedding.embeddings_table.weight.data = torch.ones_like(model.src_embedding.embeddings_table.weight.data) 411 | model.src_pos_embedding.positional_encodings_table.data = torch.zeros_like(model.src_pos_embedding.positional_encodings_table.data) 412 | model.trg_embedding.embeddings_table.weight.data = torch.ones_like(model.trg_embedding.embeddings_table.weight.data) 413 | model.trg_pos_embedding.positional_encodings_table.data = torch.zeros_like(model.trg_pos_embedding.positional_encodings_table.data) 414 | 415 | if task_type == 'NMT': 416 | src_token=torch.ones(1,1).long() 417 | trg_token=torch.ones(1,1).long() 418 | src_mask=torch.ones(1,1,1,1)>0 419 | trg_mask=torch.ones(1,1,1,1)>0 420 | x = model(src_token, trg_token, src_mask, trg_mask) 421 | else: 422 | raise ValueError 423 | del model 424 | return x.sum() 425 | 426 | measures["PATH_NORM"] = _path_norm(path_norm_transformer) # 44 427 | if not no_margin: 428 | measures["PATH_NORM_OVER_MARGIN"] = 
measures["PATH_NORM"] / margin ** 2 # 43 429 | 430 | if not no_exact_spectral_norm: 431 | print("Exact Spectral Norm") 432 | # Proposed in https://arxiv.org/abs/1805.10408 433 | # Adapted from https://github.com/brain-research/conv-sv/blob/master/conv2d_singular_values.py#L52 434 | def _spectral_norm_fft(kernel: Tensor, input_shape: Tuple[int, int]) -> Tensor: 435 | # PyTorch conv2d filters use Shape(out,in,kh,kw) 436 | # [Sedghi 2018] code expects filters of Shape(kh,kw,in,out) 437 | # Pytorch doesn't support complex FFT and SVD, so we do this in numpy 438 | np_kernel = np.einsum('oihw->hwio', kernel.data.cpu().numpy()) 439 | transforms = np.fft.fft2(np_kernel, input_shape, axes=[0, 1]) # Shape(ih,iw,in,out) 440 | singular_values = np.linalg.svd(transforms, compute_uv=False) # Shape(ih,iw,min(in,out)) 441 | spec_norm = singular_values.max() 442 | return torch.tensor(spec_norm, device=kernel.device) 443 | 444 | input_shape = (model.dataset_type.D[1], model.dataset_type.D[2]) 445 | fft_spec_norms = torch.cat([_spectral_norm_fft(p, input_shape).unsqueeze(0) ** 2 for p in weights]) 446 | fft_dist_spec_norms = torch.cat([_spectral_norm_fft(p, input_shape).unsqueeze(0) ** 2 for p in dist_init_weights]) 447 | 448 | measures[CT.LOG_PROD_OF_SPEC_FFT] = fft_spec_norms.log().sum() # 32 449 | measures[CT.LOG_PROD_OF_SPEC_OVER_MARGIN_FFT] = measures[CT.LOG_PROD_OF_SPEC_FFT] - 2 * margin.log() # 31 450 | measures[CT.FRO_OVER_SPEC_FFT] = (fro_norms / fft_spec_norms).sum() # 33 451 | measures[CT.LOG_SUM_OF_SPEC_OVER_MARGIN_FFT] = math.log(d) + (1/d) * (measures[CT.LOG_PROD_OF_SPEC_FFT] - 2 * margin.log()) # 34 452 | measures[CT.LOG_SUM_OF_SPEC_FFT] = math.log(d) + (1/d) * measures[CT.LOG_PROD_OF_SPEC_FFT] # 35 453 | measures[CT.DIST_SPEC_INIT_FFT] = fft_dist_spec_norms.sum() # 41 454 | measures[CT.LOG_SPEC_INIT_MAIN_FFT] = measures[CT.LOG_PROD_OF_SPEC_OVER_MARGIN_FFT] + (dist_fro_norms / fft_spec_norms).sum().log() # 29 455 | measures[CT.LOG_SPEC_ORIG_MAIN_FFT] = measures[CT.LOG_PROD_OF_SPEC_OVER_MARGIN_FFT] + measures[CT.FRO_OVER_SPEC_FFT].log() # 30 456 | 457 | if not no_pac_bayes: 458 | print("Flatness-based measures") 459 | sigma = _pacbayes_sigma(model, dataloader, acc, seed, search_depth=pacbayes_depth, task_type=task_type, pad_token_id=pad_token_id, trg_vocab_size=trg_vocab_size) 460 | def _pacbayes_bound(reference_vec: Tensor) -> Tensor: 461 | return (reference_vec.norm(p=2) ** 2) / (4 * sigma ** 2) + math.log(m / sigma) + 10 462 | measures["PACBAYES_INIT"] = _pacbayes_bound(dist_w_vec) # 48 463 | measures["PACBAYES_ORIG"] = _pacbayes_bound(w_vec) # 49 464 | measures["PACBAYES_FLATNESS"] = torch.tensor(1 / sigma ** 2) # 53 465 | 466 | print("Magnitude-aware Perturbation Bounds") 467 | mag_eps = 1e-3 468 | mag_sigma = _pacbayes_sigma(model, dataloader, acc, seed, mag_eps, search_depth=pacbayes_depth, task_type=task_type, pad_token_id=pad_token_id, trg_vocab_size=trg_vocab_size) 469 | omega = num_params 470 | def _pacbayes_mag_bound(reference_vec: Tensor) -> Tensor: 471 | numerator = mag_eps ** 2 + (mag_sigma ** 2 + 1) * (reference_vec.norm(p=2)**2) / omega 472 | denominator = mag_eps ** 2 + mag_sigma ** 2 * dist_w_vec ** 2 473 | return 1/4 * (numerator / denominator).log().sum() + math.log(m / mag_sigma) + 10 474 | measures["PACBAYES_MAG_INIT"] = _pacbayes_mag_bound(dist_w_vec) # 56 475 | measures["PACBAYES_MAG_ORIG"] = _pacbayes_mag_bound(w_vec) # 57 476 | measures["PACBAYES_MAG_FLATNESS"] = torch.tensor(1 / mag_sigma ** 2) # 61 477 | 478 | # Adjust for dataset size 479 | def 
adjust_measure(measure: CT, value: float) -> float: 480 | #if measure.name.startswith('LOG_'): 481 | if measure.startswith('LOG_'): 482 | return 0.5 * (value - np.log(m)) 483 | elif 'CKA' in measure or 'TRUE_MARGIN' in measure: 484 | return value 485 | else: 486 | return np.sqrt(value / m) 487 | return {k: adjust_measure(k, v.item()) for k, v in measures.items()} -------------------------------------------------------------------------------- /scripts/generate_script.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Use this file to generate the slurm array configurations \n", 8 | "## Also, note that the code will generate the checkpoint folders in the folder that you specify below" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "# Remember to change the checkpoint directory\n", 18 | "ckpt_root = '/work/yyaoqing/Good_vs_bad_data/checkpoint/NMT_epochs/Simpson/'\n", 19 | "\n", 20 | "# Select the \"sample x learning rate x depth\" grid or the \"sample x learning rate x width\" grid\n", 21 | "grid = 'depth' # choices = ['depth', 'width']\n", 22 | "\n", 23 | "# Select training or evaluating generalization metrics\n", 24 | "experiment = 'train' # choices = ['train', 'eval_metrics']" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "import os\n", 34 | "\n", 35 | "assert os.path.exists(ckpt_root)\n", 36 | "assert grid in ['depth', 'width']\n", 37 | "assert experiment in ['train', 'eval_metrics']\n", 38 | "\n", 39 | "lrs = [0.0625, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 1.0]\n", 40 | "widths = [256, 384, 512, 768, 1024]\n", 41 | "width_standard = 512\n", 42 | "head_standard = 8\n", 43 | "heads = [4, 6, 8, 12, 16]\n", 44 | "samples = [160000, 320000, 640000, 1280000, 2560000]\n", 45 | "depths = [4, 5, 6, 7, 8]\n", 46 | "depth_standard = 6\n", 47 | "dropout = 0.1" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "### Use the following code to generate the config files for training and evaluation" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 17, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "def change_configure_file(configs, sample, depth, width, lr, dropout, head):\n", 64 | " \n", 65 | " hyperparameter_string = f'WMT14_sample{sample}_depth{depth}_width{width}_lr{lr}_dropout{dropout}'\n", 66 | " ckpt_folder = os.path.join(ckpt_root, hyperparameter_string)\n", 67 | " if not os.path.exists(ckpt_folder):\n", 68 | " os.makedirs(ckpt_folder)\n", 69 | " \n", 70 | " for task, suffix in zip(tasks, suffixes):\n", 71 | " task_completion_file = os.path.join(ckpt_folder, suffix)\n", 72 | "\n", 73 | " if task == 'train':\n", 74 | " ## Check if the final training result exists.\n", 75 | " ## If not, put the task in the training list.\n", 76 | " if os.path.exists(task_completion_file):\n", 77 | " print(f\"Task {task} finished for sample={sample}, lr={lr}, depth={depth}, width={width}.\")\n", 78 | " else:\n", 79 | " configs[task].append(f\"{sample} {depth} {width} {lr} {dropout} {head}\")\n", 80 | " else:\n", 81 | " ## Check if the final training result exists and if the evaluation result does not exist.\n", 82 | " ## If so, put the task in the evaluation list.\n", 83 | " training_completion_file = 
os.path.join(ckpt_folder, 'net_epoch_20.ckpt')\n", 84 | " if os.path.exists(training_completion_file) and not os.path.exists(task_completion_file):\n", 85 | " configs[task].append(f\"{sample} {depth} {width} {lr} {dropout} {head}\")\n", 86 | "\n", 87 | " \n", 88 | "if experiment=='train':\n", 89 | " tasks = ['train']\n", 90 | " suffixes = ['net_epoch_20.ckpt']\n", 91 | "else:\n", 92 | " tasks = ['bleu', 'ww_tpl', 'ww_pl', 'ww_exponential', 'robust']\n", 93 | " suffixes = ['bleu_loss.jsonl', 'results.pkl', 'results_original_alpha.pkl', 'results_exponential.pkl', 'robust_measures.pkl']\n", 94 | "\n", 95 | "configs = {x:[] for x in tasks}\n", 96 | "\n", 97 | "## The following code only puts unfinished tasks to the configure file\n", 98 | "\n", 99 | "for sample in samples:\n", 100 | " for lr in lrs:\n", 101 | " if grid == 'depth':\n", 102 | " for depth in depths:\n", 103 | " width = width_standard\n", 104 | " head = head_standard\n", 105 | " change_configure_file(configs, sample, depth, width, lr, dropout, head)\n", 106 | " else:\n", 107 | " for width, head in zip(widths, heads):\n", 108 | " depth = depth_standard\n", 109 | " change_configure_file(configs, sample, depth, width, lr, dropout, head)\n", 110 | " \n", 111 | "## write the unfinished tasks into the final configuration file\n", 112 | "for task in tasks:\n", 113 | " with open(f'{task}_config.txt', 'w') as f:\n", 114 | " for line in configs[task]:\n", 115 | " f.write(line+'\\n')" 116 | ] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "NLP_metrics", 122 | "language": "python", 123 | "name": "nlp_metrics" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.8.3" 136 | } 137 | }, 138 | "nbformat": 4, 139 | "nbformat_minor": 4 140 | } 141 | -------------------------------------------------------------------------------- /scripts/hyperparameter_correlation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p rise # partition (queue) 3 | #SBATCH -N 1 # number of nodes requested 4 | #SBATCH -n 1 # number of tasks (i.e. processes) 5 | #SBATCH --cpus-per-task=16 # number of cores per task 6 | ##SBATCH --gres=gpu:1 # number of GPUs (should match -n) 7 | #SBATCH --nodelist=havoc # if you need specific nodes 8 | ##SBATCH --exclude=como,manchester,blaze,flaminio,freddie,r[1-6,8-16],havoc,steropes,atlas 9 | #SBATCH -t 0-01:00 # time requested (D-HH:MM) 10 | #SBATCH -D /data/yyaoqing/Generalization_metrics_for_NLP/ # working directory 11 | #SBATCH -o slurm_logs/slurm.%N.%j..out # STDOUT 12 | #SBATCH -e slurm_logs/slurm.%N.%j..err # STDERR 13 | pwd 14 | hostname 15 | date 16 | echo starting job... 
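# The positional arguments $1-$9 below are supplied by scripts/run_hyperparameter_correlation.sh and
# scripts/run_plot_scatterplot.sh, in this order:
#   metric  bleu_type  group  distribution  fitting_method  dataset  calculate_or_plot  model_size_param  [--adjust_measures_back]
# For example, one combination generated by those sweeps would be (shown here only as an illustration):
#   sbatch scripts/hyperparameter_correlation.sh PL_alpha id_bleu sample power_law ODR WMT14 calculate depth --adjust_measures_back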
17 | source ~/.bashrc 18 | conda activate ww 19 | export PYTHONUNBUFFERED=1 20 | 21 | export OMP_NUM_THREADS=1 22 | python hyperparameter_correlation.py \ 23 | --metric $1 \ 24 | --bleu_type $2 \ 25 | --group $3 \ 26 | --distribution $4 \ 27 | --fitting_method $5 \ 28 | --dataset $6 \ 29 | --calculate_or_plot $7 \ 30 | --model_size_param $8 $9 \ 31 | 32 | wait 33 | date 34 | -------------------------------------------------------------------------------- /scripts/run_hyperparameter_correlation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # There are three hyperparameter grids, two for WMT14 and one for IWSLT 4 | datasets=("WMT14" "IWSLT" "WMT14") 5 | size_params=("depth" "depth" "width") 6 | 7 | # For each grid, we calculate both the bleu scores and the generalization gap (train bleu - test bleu) 8 | bleu_types=("id_bleu" "id_bleu_gap") 9 | 10 | # For this experiment, we only calculate the rank correlations. The plots are generated from another bash file. 11 | calculation_parameter="calculate" 12 | 13 | # We test both with and without normalizing the metric by the number of samples 14 | for adjust_measure in '--adjust_measures_back' '' 15 | do 16 | for i in ${!datasets[@]}; do 17 | dataset=${datasets[$i]} 18 | size_param=${size_params[$i]} 19 | 20 | for fitting_method in 'ODR' 21 | do 22 | 23 | for bleu_type in ${bleu_types[@]} 24 | do 25 | for group in 'sample' 26 | do 27 | 28 | # This is for PL 29 | for metric in 'PL_alpha' 'PL_KS_distance' 'mp_softrank' 'log_norm' 'log_spectral_norm' 'PARAM_NORM' 'FRO_DIST' 'DIST_SPEC_INIT' 'PATH_NORM' 'stable_rank' 'alpha_weighted' 'log_alpha_norm' 'INVERSE_MARGIN' 'LOG_PROD_OF_SPEC_OVER_MARGIN' 'LOG_SUM_OF_SPEC_OVER_MARGIN' 'LOG_PROD_OF_FRO_OVER_MARGIN' 'LOG_SUM_OF_FRO_OVER_MARGIN' 'PATH_NORM_OVER_MARGIN' 'PACBAYES_INIT' 'PACBAYES_ORIG' 'PACBAYES_FLATNESS' 'PACBAYES_MAG_INIT' 'PACBAYES_MAG_ORIG' 'PACBAYES_MAG_FLATNESS' 30 | do 31 | sbatch scripts/hyperparameter_correlation.sh $metric $bleu_type $group power_law $fitting_method $dataset $calculation_parameter $size_param $adjust_measure 32 | done 33 | 34 | # This is for TPL 35 | for metric in 'E_TPL_lambda' 'E_TPL_KS_distance' 'E_TPL_beta' 'alpha_weighted' 'log_alpha_norm' 36 | do 37 | sbatch scripts/hyperparameter_correlation.sh $metric $bleu_type $group truncated_power_law $fitting_method $dataset $calculation_parameter $size_param $adjust_measure 38 | done 39 | 40 | 41 | # This is for EXP 42 | for metric in 'EXP_lambda' 43 | do 44 | sbatch scripts/hyperparameter_correlation.sh $metric $bleu_type $group exponential $fitting_method $dataset $calculation_parameter $size_param $adjust_measure 45 | done 46 | done 47 | done 48 | done 49 | done 50 | done -------------------------------------------------------------------------------- /scripts/run_plot_scatterplot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Two datasets 4 | datasets=("WMT14" "IWSLT") 5 | size_params=("depth" "depth") 6 | 7 | # For each data, we plot the scatter plots for both the bleu scores and the generalization gap (train bleu - test bleu) 8 | bleu_types=("id_bleu" "id_bleu_gap") 9 | 10 | # We do plots instead of calculating correlations 11 | calculation_parameter="plot" 12 | 13 | # Generate scatter plots for metrics that are not normalized by the number of training samples 14 | for adjust_measure in '--adjust_measures_back' 15 | do 16 | for i in ${!datasets[@]}; do 17 | dataset=${datasets[$i]} 18 | 
size_param=${size_params[$i]} 19 | 20 | # We test two fitting algorithms, namely linear regression and orthogonal distance regression 21 | for fitting_method in 'ODR' 'LR' 22 | do 23 | 24 | for bleu_type in ${bleu_types[@]} 25 | do 26 | for group in 'sample' 'depth' 'lr' 27 | do 28 | 29 | # This is for PL 30 | for metric in 'PL_alpha' 'PL_KS_distance' 'mp_softrank' 'log_norm' 'log_spectral_norm' 'PARAM_NORM' 'FRO_DIST' 'DIST_SPEC_INIT' 'PATH_NORM' 'stable_rank' 'alpha_weighted' 'log_alpha_norm' 'INVERSE_MARGIN' 'LOG_PROD_OF_SPEC_OVER_MARGIN' 'LOG_SUM_OF_SPEC_OVER_MARGIN' 'LOG_PROD_OF_FRO_OVER_MARGIN' 'LOG_SUM_OF_FRO_OVER_MARGIN' 'PATH_NORM_OVER_MARGIN' 'PACBAYES_INIT' 'PACBAYES_ORIG' 'PACBAYES_FLATNESS' 'PACBAYES_MAG_INIT' 'PACBAYES_MAG_ORIG' 'PACBAYES_MAG_FLATNESS' 31 | do 32 | sbatch scripts/hyperparameter_correlation.sh $metric $bleu_type $group power_law $fitting_method $dataset $calculation_parameter $size_param $adjust_measure 33 | done 34 | 35 | # This is for TPL 36 | for metric in 'E_TPL_lambda' 'E_TPL_KS_distance' 'E_TPL_beta' 'alpha_weighted' 'log_alpha_norm' 37 | do 38 | sbatch scripts/hyperparameter_correlation.sh $metric $bleu_type $group truncated_power_law $fitting_method $dataset $calculation_parameter $size_param $adjust_measure 39 | done 40 | 41 | 42 | # This is for EXP 43 | for metric in 'EXP_lambda' 44 | do 45 | sbatch scripts/hyperparameter_correlation.sh $metric $bleu_type $group exponential $fitting_method $dataset $calculation_parameter $size_param $adjust_measure 46 | done 47 | done 48 | done 49 | done 50 | done 51 | done -------------------------------------------------------------------------------- /scripts/slurm_compute_ww.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --array=1-200 3 | #SBATCH -p rise # partition (queue) 4 | #SBATCH -N 1 # number of nodes requested 5 | #SBATCH -n 1 # number of tasks (i.e. processes) 6 | #SBATCH --cpus-per-task=2 # number of cores per task 7 | #SBATCH --nodelist=havoc 8 | ##SBATCH --exclude=blaze,flaminio,freddie,r[1-6,8-16],havoc,steropes,atlas 9 | #SBATCH -t 2-00:00 # time requested (D-HH:MM) 10 | #SBATCH -D /data/yyaoqing/Generalization_metrics_for_NLP/ 11 | #SBATCH -o slurm_logs/slurm.%N.%j..out # STDOUT 12 | #SBATCH -e slurm_logs/slurm.%N.%j..err # STDERR 13 | pwd 14 | hostname 15 | date 16 | echo starting job... 
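# Only one of the three distribution blocks below should be left uncommented at a time; each fit writes
# its own result file, and these file names are the ones later read by time_wise_correlation.py:
#   power_law           -> results_original_alpha.pkl  (PL fits)
#   truncated_power_law -> results.pkl                 (E-TPL fits)
#   exponential         -> results_exponential.pkl     (EXP fits)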
17 | source ~/.bashrc 18 | conda activate NLP_metrics 19 | source normalize_powerlaw.sh 20 | export PYTHONUNBUFFERED=1 21 | export OMP_NUM_THREADS=1 22 | 23 | #distribution=truncated_power_law 24 | #result_file=results.pkl 25 | #config_file=ww_tpl_config 26 | 27 | #distribution=exponential 28 | #result_file=results_exponential.pkl 29 | #config_file=ww_exponential_config 30 | 31 | distribution=power_law 32 | result_file=results_original_alpha.pkl 33 | config_file=ww_pl_config 34 | 35 | cfg=$(sed -n "$SLURM_ARRAY_TASK_ID"p scripts/"$config_file".txt) 36 | sample=$(echo $cfg | cut -f 1 -d ' ') 37 | depth=$(echo $cfg | cut -f 2 -d ' ') 38 | width=$(echo $cfg | cut -f 3 -d ' ') 39 | lr=$(echo $cfg | cut -f 4 -d ' ') 40 | dropout=$(echo $cfg | cut -f 5 -d ' ') 41 | head=$(echo $cfg | cut -f 6 -d ' ') 42 | 43 | CKPTPATH=/data/yyaoqing/Generalization_metrics_for_NLP/checkpoint/WMT14_sample"$sample"_depth"$depth"_width"$width"_lr"$lr"_dropout"$dropout" 44 | echo $CKPTPATH 45 | #mkdir $CKPTPATH 46 | 47 | python compute_ww.py \ 48 | $CKPTPATH $CKPTPATH \ 49 | --result-suffix $result_file \ 50 | --width $width \ 51 | --dataset WMT \ 52 | --num-samples $sample \ 53 | --num-layers $depth \ 54 | --mp-fit --randomize \ 55 | --distribution $distribution \ 56 | --num-epochs 20 \ 57 | --starting-epoch 1 \ 58 | --num-heads $head & 59 | 60 | wait 61 | date 62 | -------------------------------------------------------------------------------- /scripts/slurm_eval_bleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --array=1-200 3 | #SBATCH -p rise # partition (queue) 4 | #SBATCH -N 1 # number of nodes requested 5 | #SBATCH -n 1 # number of tasks (i.e. processes) 6 | #SBATCH --cpus-per-task=8 # number of cores per task 7 | #SBATCH --gres=gpu:1 # number of GPUs (should match -n) 8 | ##SBATCH --nodelist=ace,manchester,bombe,como,pavia,luigi,zanino # if you need specific nodes 9 | #SBATCH --exclude=blaze,flaminio,freddie,r[1-6,8-16],havoc,steropes,atlas,zanino,luigi,como,pavia,ace,bombe 10 | #SBATCH -t 2-00:00 # time requested (D-HH:MM) 11 | #SBATCH -D /data/yyaoqing/Generalization_metrics_for_NLP/ 12 | #SBATCH -o slurm_logs/slurm.%N.%j..out # STDOUT 13 | #SBATCH -e slurm_logs/slurm.%N.%j..err # STDERR 14 | pwd 15 | hostname 16 | date 17 | echo starting job... 
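# Each line of scripts/bleu_config.txt describes one model in the order written by
# scripts/generate_script.ipynb: "sample depth width lr dropout head", e.g. a line like
#   160000 6 512 0.25 0.1 8
# (values drawn from the grids in that notebook, shown here only as an illustration). The sed/cut block
# below selects the line whose number equals $SLURM_ARRAY_TASK_ID and splits it into these six variables.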
18 | source ~/.bashrc 19 | conda activate NLP_metrics 20 | export PYTHONUNBUFFERED=1 21 | 22 | 23 | cfg=$(sed -n "$SLURM_ARRAY_TASK_ID"p scripts/bleu_config.txt) 24 | sample=$(echo $cfg | cut -f 1 -d ' ') 25 | depth=$(echo $cfg | cut -f 2 -d ' ') 26 | width=$(echo $cfg | cut -f 3 -d ' ') 27 | lr=$(echo $cfg | cut -f 4 -d ' ') 28 | dropout=$(echo $cfg | cut -f 5 -d ' ') 29 | head=$(echo $cfg | cut -f 6 -d ' ') 30 | 31 | CKPTPATH=/data/yyaoqing/Generalization_metrics_for_NLP/checkpoint/WMT14_sample"$sample"_depth"$depth"_width"$width"_lr"$lr"_dropout"$dropout" 32 | echo $CKPTPATH 33 | #mkdir $CKPTPATH 34 | 35 | srun -N 1 -n 1 python eval_bleu_loss.py \ 36 | --checkpoint_dir $CKPTPATH \ 37 | --max_batches 200 \ 38 | --dataset WMT \ 39 | --num_epochs 20 \ 40 | --starting_epoch 1 \ 41 | --num-heads $head \ 42 | --embedding-dimension $width 43 | 44 | wait 45 | date 46 | -------------------------------------------------------------------------------- /scripts/slurm_robust_measures.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --array=1-200 3 | #SBATCH -p rise # partition (queue) 4 | #SBATCH -N 1 # number of nodes requested 5 | #SBATCH -n 1 # number of tasks (i.e. processes) 6 | #SBATCH --cpus-per-task=8 # number of cores per task 7 | #SBATCH --gres=gpu:1 # number of GPUs (should match -n) 8 | ##SBATCH --nodelist=ace,manchester,bombe,como,pavia,luigi,zanino # if you need specific nodes 9 | #SBATCH --exclude=blaze,flaminio,freddie,r[1-6,8-16],havoc,steropes,atlas,zanino,luigi,como 10 | #SBATCH -t 2-00:00 # time requested (D-HH:MM) 11 | #SBATCH -D /data/yyaoqing/Generalization_metrics_for_NLP/ 12 | #SBATCH -o slurm_logs/slurm.%N.%j..out # STDOUT 13 | #SBATCH -e slurm_logs/slurm.%N.%j..err # STDERR 14 | pwd 15 | hostname 16 | date 17 | echo starting job... 18 | source ~/.bashrc 19 | conda activate NLP_metrics 20 | export PYTHONUNBUFFERED=1 21 | 22 | 23 | cfg=$(sed -n "$SLURM_ARRAY_TASK_ID"p scripts/robust_config.txt) 24 | sample=$(echo $cfg | cut -f 1 -d ' ') 25 | depth=$(echo $cfg | cut -f 2 -d ' ') 26 | width=$(echo $cfg | cut -f 3 -d ' ') 27 | lr=$(echo $cfg | cut -f 4 -d ' ') 28 | dropout=$(echo $cfg | cut -f 5 -d ' ') 29 | head=$(echo $cfg | cut -f 6 -d ' ') 30 | 31 | CKPTPATH=/data/yyaoqing/Generalization_metrics_for_NLP/checkpoint/WMT14_sample"$sample"_depth"$depth"_width"$width"_lr"$lr"_dropout"$dropout" 32 | echo $CKPTPATH 33 | #mkdir $CKPTPATH 34 | 35 | srun -N 1 -n 1 python test_measures_collections.py \ 36 | $CKPTPATH \ 37 | --result_suffix "robust_measures.pkl" \ 38 | --num-epochs 20 \ 39 | --width $width \ 40 | --dataset WMT \ 41 | --num-samples $sample \ 42 | --calculate_margin \ 43 | --calculate_pac_bayes \ 44 | --num-layers $depth \ 45 | --num-heads $head & 46 | 47 | wait 48 | date 49 | -------------------------------------------------------------------------------- /scripts/slurm_train_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --array=1-200 3 | #SBATCH -p rise # partition (queue) 4 | #SBATCH -N 1 # number of nodes requested 5 | #SBATCH -n 1 # number of tasks (i.e. 
processes) 6 | #SBATCH --cpus-per-task=8 # number of cores per task 7 | #SBATCH --gres=gpu:1 # number of GPUs (should match -n) 8 | ##SBATCH --nodelist=ace,manchester,bombe,como,pavia,luigi,zanino # if you need specific nodes 9 | #SBATCH --exclude=blaze,flaminio,freddie,r[1-6,8-16],havoc,steropes,atlas,zanino,como,luigi,pavia 10 | #SBATCH -t 7-00:00 # time requested (D-HH:MM) 11 | #SBATCH -D /data/yyaoqing/Generalization_metrics_for_NLP/ 12 | #SBATCH -o slurm_logs/slurm.%N.%j..out # STDOUT 13 | #SBATCH -e slurm_logs/slurm.%N.%j..err # STDERR 14 | pwd 15 | hostname 16 | date 17 | echo starting job... 18 | source ~/.bashrc 19 | conda activate NLP_metrics 20 | export PYTHONUNBUFFERED=1 21 | 22 | cfg=$(sed -n "$SLURM_ARRAY_TASK_ID"p scripts/train_config.txt) 23 | sample=$(echo $cfg | cut -f 1 -d ' ') 24 | depth=$(echo $cfg | cut -f 2 -d ' ') 25 | width=$(echo $cfg | cut -f 3 -d ' ') 26 | lr=$(echo $cfg | cut -f 4 -d ' ') 27 | dropout=$(echo $cfg | cut -f 5 -d ' ') 28 | head=$(echo $cfg | cut -f 6 -d ' ') 29 | 30 | CKPTPATH=/data/yyaoqing/Generalization_metrics_for_NLP/checkpoint/WMT14_sample"$sample"_depth"$depth"_width"$width"_lr"$lr"_dropout"$dropout" 31 | echo $CKPTPATH 32 | #mkdir $CKPTPATH 33 | 34 | # Note here we changed the embedding factor dimension!! 35 | 36 | srun -N 1 -n 1 python training_script.py \ 37 | --num_of_epochs 20 \ 38 | --dataset_name WMT14 \ 39 | --language_direction G2E \ 40 | --subsampling --num-samples $sample \ 41 | --embedding-dimension $width \ 42 | --num-heads $head \ 43 | --num-layers $depth \ 44 | --lr-inverse-dim \ 45 | --lr-factor $lr \ 46 | --max-gradient-steps 100000000 \ 47 | --dropout $dropout \ 48 | --checkpoint-path $CKPTPATH \ 49 | 1>$CKPTPATH/log_0.txt \ 50 | 2>$CKPTPATH/err_0.txt & 51 | 52 | wait 53 | date 54 | -------------------------------------------------------------------------------- /time_wise_correlation.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file calculates time-wise correlation between BLEU scores and metrics. 
3 | ''' 4 | import argparse, pickle, json, os 5 | import pandas as pd 6 | import numpy as np 7 | from scipy.stats import spearmanr 8 | from metrics import METRIC_FILES 9 | import pickle 10 | 11 | 12 | def get_corr_df(metrics_df): 13 | ''' 14 | Correlations for a single checkpoint 15 | ''' 16 | correlations = [] 17 | for metric, _ in METRIC_FILES.items(): 18 | # TODO: If doing phases, need to change this to calculate correlation within each phase 19 | # corr, _ = spearmanr(metrics_df['ood_bleu'], metrics_df[metric]) 20 | corr, _ = spearmanr(metrics_df['id_bleu'], metrics_df[metric]) 21 | if metric == 'rand_distance': 22 | corr = corr * -1.0 23 | correlations.append((metric, corr)) 24 | 25 | data = list(zip(*correlations)) # list of length 3: element 0 is metric names, element 1 is metric types, element 2 is correlations 26 | corr_df = pd.DataFrame(data={ 27 | 'metric': data[0], 28 | 'correlation': data[1] 29 | }) 30 | return corr_df 31 | 32 | 33 | def get_metrics_df(checkpoint, bleu_type = 'test'): 34 | ''' 35 | Create a dataframe of metrics and BLEU scores for a given checkpoint directory 36 | ''' 37 | print(checkpoint) 38 | # Get ww metrics 39 | ww_metrics = {} # Key: metric, Value: list of values for that metric 40 | 41 | # Epochs are numbered 1-20 42 | EPOCHS = 20 43 | epochs = range(1, EPOCHS+1) 44 | 45 | # Load results 46 | FILE_PL = os.path.join(checkpoint, f"results_original_alpha.pkl") 47 | with open(FILE_PL, "rb") as file: 48 | results_PL = pickle.load(file) 49 | FILE_TPL = os.path.join(checkpoint, f"results.pkl") 50 | with open(FILE_TPL, "rb") as file: 51 | results_TPL = pickle.load(file) 52 | FILE_EXP = os.path.join(checkpoint, f"results_exponential.pkl") 53 | with open(FILE_EXP, "rb") as file: 54 | results_EXP = pickle.load(file) 55 | FILE_ROBUST = os.path.join(checkpoint, f"robust_measures.pkl") 56 | with open(FILE_ROBUST, "rb") as file: 57 | results_robust = pickle.load(file) 58 | 59 | for metric, _ in METRIC_FILES.items(): 60 | metric_vals = [] 61 | 62 | if METRIC_FILES[metric] == 'ww': 63 | if metric in ['PL_alpha', 'rand_distance', 'mp_softrank', 'PL_KS_distance', 'alpha_weighted', 'log_alpha_norm', 'stable_rank']: 64 | for epoch in epochs: 65 | results_metrics = results_PL[epoch] 66 | if metric == 'PL_alpha': 67 | #if 'alpha' in results_metrics['details']: 68 | metric_vals.append(results_metrics['details']['alpha'].mean()) # averaging over layers 69 | elif metric == 'PL_KS_distance': 70 | metric_vals.append(results_metrics['details']['D'].mean()) 71 | elif metric in d[epoch]['details']: 72 | metric_vals.append(results_metrics['details'][metric].mean()) 73 | else: 74 | # Fill in missing metrics with null (not all checkpoints have all metrics calculated) 75 | metric_vals.append(np.nan) 76 | print(f"{FILE_PL}\n\tepoch {epoch} missing {metric}") 77 | elif metric == 'EXP_lambda': 78 | d = results_EXP 79 | for epoch in epochs: 80 | metric_vals.append(d[epoch]['details']['exponent'].mean()) 81 | else: 82 | d = results_TPL 83 | for epoch in epochs: 84 | if metric == 'E_TPL_KS_distance': 85 | metric_vals.append(d[epoch]['details']['D'].mean()) 86 | elif metric == 'E_TPL_beta': 87 | metric_vals.append(d[epoch]['details']['alpha'].mean()) 88 | elif metric == 'E_TPL_lambda': 89 | metric_vals.append(d[epoch]['details']['exponent'].mean()) 90 | elif metric in d[epoch]['details']: 91 | metric_vals.append(d[epoch]['details'][metric].mean()) 92 | else: 93 | metric_vals.append(np.nan) 94 | print(f"{FILE_TPL}\n\tepoch {epoch} missing {metric}") 95 | 96 | elif METRIC_FILES[metric] == 
'robust': 97 | margin_metrics = results_robust 98 | for epoch in epochs: 99 | if metric in margin_metrics[epoch]: 100 | metric_vals.append(margin_metrics[epoch][metric]) 101 | else: 102 | # Fill in missing metrics with null (not all checkpoints have all metrics calculated) 103 | metric_vals.append(np.nan) 104 | print(f"{FILE_ROBUST}\n\tepoch {epoch} missing {metric}") 105 | 106 | else: 107 | print(f"{metric} not found") 108 | 109 | ww_metrics[metric] = metric_vals 110 | 111 | # Get BLEU scores 112 | id_bleu_scores, ood_bleu_scores = [], [] 113 | FILE = os.path.join(checkpoint, "bleu_loss.jsonl") 114 | 115 | EPOCH = 1 # Epochs are numbered 1-20 116 | with (open(FILE, "rb")) as file: 117 | for line in file: 118 | d = json.loads(line) 119 | # Multiply BLEU by -1 because we are computing correlations between BLEU 120 | # and generalization metrics for which lower values are better 121 | if bleu_type == 'test': 122 | id_bleu_scores.append(d[f'epoch{EPOCH}_id_bleu_score'] * 100 * -1.0) 123 | ood_bleu_scores.append(d[f'epoch{EPOCH}_ood_bleu_score'] * 100 * -1.0) 124 | elif bleu_type == 'gap': 125 | #TODO: Results for OOD generalization gap 126 | id_bleu_scores.append((d[f'epoch{EPOCH}_id_train_bleu_score'] - d[f'epoch{EPOCH}_id_bleu_score'])* 100) 127 | ood_bleu_scores.append(d[f'epoch{EPOCH}_ood_bleu_score'] * 100 * -1.0) 128 | else: 129 | raise ValueError('Bleu type not implemented.') 130 | EPOCH += 1 131 | ### 132 | 133 | assert len(ww_metrics['log_spectral_norm']) == len(id_bleu_scores) == len(ood_bleu_scores) 134 | 135 | # Create a dataframe 136 | data={'epoch': list(range(1, EPOCHS+1)), 'id_bleu': id_bleu_scores, 'ood_bleu': ood_bleu_scores} 137 | data.update(ww_metrics) 138 | df = pd.DataFrame(data=data) 139 | ### 140 | 141 | return df 142 | 143 | 144 | if __name__ == "__main__": 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument("--id", type=str, default="WMT") 147 | parser.add_argument("--bleu_type", type=str, default='test', choices=['test', 'gap']) 148 | #TODO: update the WW results using WeightWatcher 0.5.6 149 | #parser.add_argument("--reproduce", action='store_true') 150 | 151 | args = parser.parse_args() 152 | ood = 'WMT' if args.id == 'IWSLT' else 'IWSLT' 153 | 154 | # Plot correlations across all experiments 155 | from experiments_time_wise import EXPERIMENTS 156 | exps = EXPERIMENTS[f"{args.id}"] #+ EXPERIMENTS[ood] 157 | all_metrics = [get_metrics_df(exp, args.bleu_type) for exp in exps] 158 | corr_dfs = [get_corr_df(metric_df) for metric_df in all_metrics] 159 | all_corrs = pd.concat(corr_dfs) 160 | 161 | rank_correlations_aggregated = {} 162 | 163 | # Converting all results into an aggregated array 164 | for key, val in zip(all_corrs['metric'].values, all_corrs['correlation'].values): 165 | if key not in rank_correlations_aggregated: 166 | rank_correlations_aggregated[key] = [val] 167 | else: 168 | rank_correlations_aggregated[key].append(val) 169 | 170 | # Remove nan's which are failed measurements 171 | for key in rank_correlations_aggregated.keys(): 172 | rank_correlations_aggregated[key] = [x for x in rank_correlations_aggregated[key] if not np.isnan(x)] 173 | 174 | with open(f'results/plot_results_{args.bleu_type}_Simpson_{args.id}.pkl', 'wb') as f: 175 | pickle.dump(rank_correlations_aggregated, f) 176 | 177 | #pickle.dump(all_corrs, open(f'results/Simpson_correlation_{args.id}_{args.bleu_type}.pkl', "wb")) -------------------------------------------------------------------------------- /training_script.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Adapted from "Pytorch Original Transformer" by Aleksa Gordić 3 | https://github.com/gordicaleksa/pytorch-original-transformer 4 | """ 5 | 6 | import argparse 7 | import time 8 | 9 | 10 | import torch 11 | from torch import nn 12 | from torch.optim import Adam 13 | 14 | from utils.optimizers_and_distributions import CustomLRAdamOptimizer, LabelSmoothingDistribution 15 | from models.definitions.transformer_model import Transformer 16 | from utils.data_utils import get_data_loaders, get_masks_and_count_tokens, get_src_and_trg_batches, DatasetType, LanguageDirection 17 | import utils.utils as utils 18 | from utils.constants import * 19 | 20 | 21 | # Global vars for logging purposes 22 | num_of_trg_tokens_processed = 0 23 | bleu_scores = [] 24 | global_train_step, global_val_step = [0, 0] 25 | #writer = SummaryWriter() # (tensorboard) writer will output to ./runs/ directory by default 26 | best_val_loss = None 27 | #import weightwatcher as ww 28 | import wandb 29 | 30 | # Simple decorator function so that I don't have to pass these arguments every time I call get_train_val_loop 31 | def get_train_val_loop(baseline_transformer, custom_lr_optimizer, kl_div_loss, label_smoothing, pad_token_id, time_start, training_config, save_ckpt=True): 32 | 33 | def train_val_loop(is_train, token_ids_loader, epoch): 34 | #global num_of_trg_tokens_processed, global_train_step, global_val_step, writer 35 | 36 | global num_of_trg_tokens_processed, global_train_step, global_val_step, best_val_loss 37 | 38 | if is_train: 39 | baseline_transformer.train() 40 | else: 41 | baseline_transformer.eval() 42 | 43 | device = next(baseline_transformer.parameters()).device 44 | 45 | # 46 | # Main loop - start of the CORE PART 47 | # 48 | 49 | validation_loss = 0 50 | val_step = 0 51 | 52 | 53 | for batch_idx, token_ids_batch in enumerate(token_ids_loader): 54 | src_token_ids_batch, trg_token_ids_batch_input, trg_token_ids_batch_gt = get_src_and_trg_batches(token_ids_batch) 55 | src_mask, trg_mask, num_src_tokens, num_trg_tokens = get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch_input, pad_token_id, device) 56 | 57 | # log because the KL loss expects log probabilities (just an implementation detail) 58 | predicted_log_distributions = baseline_transformer(src_token_ids_batch, trg_token_ids_batch_input, src_mask, trg_mask) 59 | smooth_target_distributions = label_smoothing(trg_token_ids_batch_gt) # these are regular probabilities 60 | 61 | if is_train: 62 | custom_lr_optimizer.zero_grad() # clean the trainable weights gradients in the computational graph 63 | 64 | loss = kl_div_loss(predicted_log_distributions, smooth_target_distributions) 65 | 66 | if is_train: 67 | loss.backward() # compute the gradients for every trainable weight in the computational graph 68 | custom_lr_optimizer.step() # apply the gradients to weights 69 | 70 | 71 | if batch_idx%training_config['sharpness_frequency']==0 and training_config['sharpness_transform'] and training_config['sharpness_perbatch']: 72 | Sharpness_transform(baseline_transformer, training_config) 73 | 74 | # End of CORE PART 75 | 76 | # 77 | # Logging and metrics 78 | # 79 | 80 | if is_train: 81 | global_train_step += 1 82 | num_of_trg_tokens_processed += num_trg_tokens 83 | 84 | #if training_config['enable_tensorboard']: 85 | # writer.add_scalar('training_loss', loss.item(), global_train_step) 86 | 87 | training_loss = loss.item() 88 | if training_config['console_log_freq'] is 
not None and batch_idx % training_config['console_log_freq'] == 0: 89 | print(f'Transformer training: time elapsed= {(time.time() - time_start):.2f} [s] ' 90 | f'| epoch={epoch + 1} | batch= {batch_idx + 1} ' 91 | f'| target tokens/batch= {num_of_trg_tokens_processed / training_config["console_log_freq"]} ' 92 | f'| global training step= {global_train_step} ' 93 | f'| training loss= {training_loss}') 94 | 95 | num_of_trg_tokens_processed = 0 96 | 97 | # Save model checkpoint 98 | if training_config['checkpoint_freq'] is not None and (epoch + 1) % training_config['checkpoint_freq'] == 0 and batch_idx == 0: 99 | 100 | ckpt_name = os.path.join(args.checkpoint_path, f'net_epoch_{(epoch+1)}{args.checkpoint_suffix}.ckpt') 101 | torch.save(utils.get_training_state(training_config, baseline_transformer), ckpt_name) 102 | else: 103 | val_step += 1 104 | validation_loss += loss.item() 105 | 106 | #if training_config['enable_tensorboard']: 107 | # writer.add_scalar('val_loss', loss.item(), global_val_step) 108 | 109 | if not is_train: 110 | 111 | validation_loss = validation_loss/val_step 112 | 113 | print('-'*30) 114 | print(f'Validation loss at epoch={epoch + 1} is {validation_loss}') 115 | print('-'*30) 116 | 117 | wandb.log({'Validation_loss': validation_loss}) 118 | 119 | if save_ckpt and (not best_val_loss or best_val_loss > validation_loss): 120 | best_val_loss = validation_loss 121 | print("The newly trained model has better validation loss. Save this model!") 122 | 123 | ckpt_name = os.path.join(args.checkpoint_path, f'net_exp_{args.exp_ind}{args.checkpoint_suffix}_best.ckpt') 124 | torch.save(utils.get_training_state(training_config, baseline_transformer), ckpt_name) 125 | 126 | 127 | return train_val_loop 128 | 129 | 130 | def Sharpness_transform(model, config=None): 131 | 132 | """ 133 | eps=1e-8 134 | model = model.cpu() 135 | watcher = ww.WeightWatcher(model=model) 136 | sharper_model = watcher.SVDSharpness(model=model, layers=[329]) 137 | 138 | # This part might need to be replaced with changing layer weights 139 | 140 | weight1 = sharper_model.decoder.decoder_layers[5].pointwise_net.linear2.weight.data.float().detach() 141 | weight2 = model.decoder.decoder_layers[5].pointwise_net.linear2.weight.data.detach() 142 | 143 | if (weight1-weight2).norm() > eps: 144 | print("Get spikes and appiled the Sharpness transform!") 145 | 146 | model.decoder.decoder_layers[5].pointwise_net.linear2.weight.data = weight1 147 | model = model.cuda() 148 | """ 149 | return 150 | 151 | 152 | def train_transformer(training_config): 153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # checking whether you have a GPU, I hope so! 
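    # Note: training_config is simply the argparse namespace converted into a dict
    # (plus 'num_warmup_steps'; see the __main__ block below), so the training_config[...]
    # and args.... lookups in this function refer to the same command-line values.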
154 | 155 | # Step 1: Prepare data loaders 156 | train_token_ids_loader, val_token_ids_loader, src_field_processor, trg_field_processor = get_data_loaders( 157 | training_config['dataset_path'], 158 | training_config['language_direction'], 159 | training_config['dataset_name'], 160 | training_config['batch_size'], 161 | device, 162 | subsampling=args.subsampling, 163 | num_samples=args.num_samples) 164 | 165 | pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN] # pad token id is the same for target as well 166 | src_vocab_size = len(src_field_processor.vocab) 167 | trg_vocab_size = len(trg_field_processor.vocab) 168 | 169 | # Step 2: Prepare the model (original transformer) and push to GPU 170 | baseline_transformer = Transformer( 171 | model_dimension=args.embedding_dimension, 172 | src_vocab_size=src_vocab_size, 173 | trg_vocab_size=trg_vocab_size, 174 | number_of_heads=args.num_heads, 175 | number_of_layers=args.num_layers, 176 | dropout_probability=args.dropout, 177 | ).to(device) 178 | #embedding_factor_dimension=args.embedding_factor_dimension 179 | 180 | # Step 3: Prepare other training related utilities 181 | kl_div_loss = nn.KLDivLoss(reduction='batchmean') # gives better BLEU score than "mean" 182 | 183 | # Makes smooth target distributions as opposed to conventional one-hot distributions 184 | # My feeling is that this is a really dummy and arbitrary heuristic but time will tell. 185 | label_smoothing = LabelSmoothingDistribution(BASELINE_MODEL_LABEL_SMOOTHING_VALUE, pad_token_id, trg_vocab_size, device) 186 | 187 | # Check out playground.py for an intuitive visualization of how the LR changes with time/training steps, easy stuff. 188 | custom_lr_optimizer = CustomLRAdamOptimizer( 189 | Adam(baseline_transformer.parameters(), betas=(0.9, 0.98), eps=1e-9), 190 | args.embedding_dimension, 191 | training_config['num_warmup_steps'], 192 | lr_inverse_dim=args.lr_inverse_dim, 193 | constant_lr =args.constant_lr, 194 | lr_factor = args.lr_factor 195 | ) 196 | 197 | wandb.init(name = args.checkpoint_path + '_train') 198 | 199 | # The decorator function makes things cleaner since there is a lot of redundancy between the train and val loops 200 | train_val_loop = get_train_val_loop(baseline_transformer, custom_lr_optimizer, kl_div_loss, label_smoothing, pad_token_id, time.time(), training_config) 201 | 202 | # Save the initial checkpoint and evaluate it 203 | if not os.path.exists(args.checkpoint_path): 204 | os.makedirs(args.checkpoint_path) 205 | ckpt_name = os.path.join(args.checkpoint_path, f'net_epoch_0{args.checkpoint_suffix}.ckpt') 206 | torch.save(utils.get_training_state(training_config, baseline_transformer), ckpt_name) 207 | 208 | with torch.no_grad(): 209 | train_val_loop(is_train=False, token_ids_loader=val_token_ids_loader, epoch=-1) 210 | 211 | bleu_score = utils.calculate_bleu_score(baseline_transformer, val_token_ids_loader, trg_field_processor) 212 | #if training_config['enable_tensorboard']: 213 | # writer.add_scalar('bleu_score', bleu_score, epoch) 214 | print('-'*30) 215 | print(f'BLEU score at epoch=0 is {bleu_score}') 216 | print('-'*30) 217 | wandb.log({'BLEU_score': bleu_score}) 218 | 219 | 220 | # Step 4: Start the training 221 | for epoch in range(training_config['num_of_epochs']): 222 | 223 | # Training loop 224 | train_val_loop(is_train=True, token_ids_loader=train_token_ids_loader, epoch=epoch) 225 | 226 | # Apply Sharpness transform 227 | #if epoch>10 and training_config['sharpness_transform']: 228 | if training_config['sharpness_transform']: 229 | 
Sharpness_transform(baseline_transformer, training_config) 230 | 231 | # Validation loop 232 | with torch.no_grad(): 233 | train_val_loop(is_train=False, token_ids_loader=val_token_ids_loader, epoch=epoch) 234 | 235 | bleu_score = utils.calculate_bleu_score(baseline_transformer, val_token_ids_loader, trg_field_processor) 236 | #if training_config['enable_tensorboard']: 237 | # writer.add_scalar('bleu_score', bleu_score, epoch) 238 | print('-'*30) 239 | print(f'BLEU score at epoch={epoch + 1} is {bleu_score}') 240 | print('-'*30) 241 | wandb.log({'BLEU_score': bleu_score}) 242 | 243 | if global_train_step > args.max_gradient_steps: 244 | print("Enough training! Saving model and exit.") 245 | 246 | ckpt_name = os.path.join(args.checkpoint_path, f'net_exp_{args.exp_ind}{args.checkpoint_suffix}.ckpt') 247 | torch.save(utils.get_training_state(training_config, baseline_transformer), ckpt_name) 248 | break 249 | 250 | else: 251 | print(f"global training step {global_train_step} is not reached. Continue.") 252 | 253 | # Save the latest transformer in the binaries directory 254 | # torch.save(utils.get_training_state(training_config, baseline_transformer), os.path.join(BINARIES_PATH, utils.get_available_binary_name())) 255 | 256 | 257 | if __name__ == "__main__": 258 | # 259 | # Fixed args - don't change these unless you have a good reason 260 | # 261 | num_warmup_steps = 4000 262 | 263 | # 264 | # Modifiable args - feel free to play with these (only small subset is exposed by design to avoid cluttering) 265 | # 266 | parser = argparse.ArgumentParser() 267 | # According to the paper I infered that the baseline was trained for ~19 epochs on the WMT-14 dataset and I got 268 | # nice returns up to epoch ~20 on IWSLT as well (nice round number) 269 | # From Yaoqing: This epoch number has been changed to 200 because we subsample about 10% of original data 270 | # If we only use 10% data, then we should scale up the number of epochs to get the same number of gradient steps 271 | parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=200) 272 | #parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=20) 273 | # You should adjust this for your particular machine (I have RTX 2080 with 8 GBs of VRAM so 1500 fits nicely!) 
274 | parser.add_argument("--batch_size", type=int, help="target number of tokens in a src/trg batch", default=1500) 275 | 276 | # Data related args 277 | parser.add_argument("--dataset_name", choices=[el.name for el in DatasetType], help='which dataset to use for training', default=DatasetType.IWSLT.name) 278 | parser.add_argument("--language_direction", choices=[el.name for el in LanguageDirection], help='which direction to translate', default=LanguageDirection.E2G.name) 279 | parser.add_argument("--dataset_path", type=str, help='download dataset to this path', default=DATA_DIR_PATH) 280 | 281 | # Logging/debugging/checkpoint related (helps a lot with experimentation) 282 | # parser.add_argument("--enable_tensorboard", type=bool, help="enable tensorboard logging", default=True) 283 | parser.add_argument("--console_log_freq", type=int, help="log to output console (batch) freq", default=10) 284 | parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (epoch) freq", default=1) 285 | 286 | # Some parameters to cover the previous training setup 287 | parser.add_argument("--subsampling", help="subsample data", action="store_true") 288 | parser.add_argument("--lr-inverse-dim", help="use the heuristic of inverse learning rate proportional to the model dimension", action="store_true") 289 | parser.add_argument("--constant-lr", help="use constant learning rate", action="store_true") 290 | 291 | parser.add_argument("--num-samples", type=int, help="number of samples (src-tar pairs)", default=1000) 292 | parser.add_argument("--checkpoint-path", type=str, help="checkpoint model saving path", default=CHECKPOINTS_PATH) 293 | parser.add_argument("--checkpoint-suffix", type=str, help="checkpoint model saving name", default="") 294 | parser.add_argument("--exp-ind", type=int, help="index of the training", default=0) 295 | parser.add_argument("--max-gradient-steps", type=int, help="number of gradient steps to train", default=80000) 296 | parser.add_argument("--embedding-dimension", type=int, help="the dimension to save a checkpoint", default=BASELINE_MODEL_DIMENSION) 297 | parser.add_argument("--lr-factor", type=float, help="factor to adjust the inverse dim lr", default=1.0) 298 | parser.add_argument("--dropout", type=float, help="dropout probability", default=0.0) 299 | parser.add_argument("--num-layers", type=int, help="number of Transformer layers", default=6) 300 | parser.add_argument("--num-heads", type=int, help="number of Transformer layers", default=BASELINE_MODEL_NUMBER_OF_HEADS) 301 | 302 | # Some parameters to change the Sharpness transform 303 | parser.add_argument("--sharpness-transform", help="apply sharpness transform?", action="store_true") 304 | parser.add_argument("--sharpness-perbatch", help="apply sharpness transform to each batch?", action="store_true") 305 | parser.add_argument("--sharpness-frequency", type=int, help="how many batches should we apply the transform", default=100) 306 | #parser.add_argument("--embedding-factor-dimension", type=float, help="Should we fix the embedding factor dimension?", default=None) 307 | 308 | args = parser.parse_args() 309 | 310 | # Wrapping training configuration into a dictionary 311 | training_config = dict() 312 | for arg in vars(args): 313 | training_config[arg] = getattr(args, arg) 314 | training_config['num_warmup_steps'] = num_warmup_steps 315 | 316 | print(args.checkpoint_path) 317 | 318 | # Train the original transformer model 319 | train_transformer(training_config) 320 | 
-------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nsfzyzz/Generalization_metrics_for_NLP/e7a991f8baa15a1651a53016cff532d9d99256fc/utils/__init__.py -------------------------------------------------------------------------------- /utils/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | #BASELINE_MODEL_NUMBER_OF_LAYERS = 6 5 | BASELINE_MODEL_DIMENSION = 512 6 | BASELINE_MODEL_NUMBER_OF_HEADS = 8 7 | #BASELINE_MODEL_DROPOUT_PROB = 0.1 8 | BASELINE_MODEL_DROPOUT_PROB = 0 9 | BASELINE_MODEL_LABEL_SMOOTHING_VALUE = 0.1 10 | 11 | 12 | #BIG_MODEL_NUMBER_OF_LAYERS = 6 13 | BIG_MODEL_DIMENSION = 1024 14 | BIG_MODEL_NUMBER_OF_HEADS = 16 15 | BIG_MODEL_DROPOUT_PROB = 0.3 16 | BIG_MODEL_LABEL_SMOOTHING_VALUE = 0.1 17 | 18 | 19 | CHECKPOINTS_PATH = os.path.join(os.path.dirname(__file__), os.pardir, 'models', 'checkpoints') 20 | BINARIES_PATH = os.path.join(os.path.dirname(__file__), os.pardir, 'models', 'binaries') 21 | DATA_DIR_PATH = os.path.join(os.path.dirname(__file__), os.pardir, 'data') 22 | os.makedirs(CHECKPOINTS_PATH, exist_ok=True) 23 | os.makedirs(BINARIES_PATH, exist_ok=True) 24 | os.makedirs(DATA_DIR_PATH, exist_ok=True) 25 | 26 | 27 | BOS_TOKEN = '' 28 | EOS_TOKEN = '' 29 | PAD_TOKEN = "" 30 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from "Pytorch Original Transformer" by Aleksa Gordić 3 | https://github.com/gordicaleksa/pytorch-original-transformer 4 | """ 5 | 6 | import time 7 | import os 8 | import enum 9 | 10 | 11 | import torch 12 | from torchtext.data import Dataset, BucketIterator, Field, Example 13 | from torchtext.data.utils import interleave_keys 14 | from torchtext import datasets 15 | import spacy 16 | 17 | 18 | from .constants import BOS_TOKEN, EOS_TOKEN, PAD_TOKEN, DATA_DIR_PATH 19 | 20 | 21 | class DatasetType(enum.Enum): 22 | IWSLT = 0, 23 | WMT14 = 1 24 | 25 | 26 | class LanguageDirection(enum.Enum): 27 | E2G = 0, 28 | G2E = 1 29 | 30 | 31 | # 32 | # Caching mechanism datasets and functions (you don't need this but it makes things a lot faster!) 33 | # 34 | 35 | 36 | class FastTranslationDataset(Dataset): 37 | """ 38 | After understanding the source code of torch text's IWSLT, TranslationDataset and Dataset I realized how I 39 | can make data preparation much faster (tokenization was taking a lot of time and there is no need to redo it 40 | every time) by using a simple caching mechanism. 41 | 42 | This dataset leverages that caching mechanism which reduced loading time from ~70s -> 2.5s (massive!) 43 | 44 | """ 45 | 46 | @staticmethod 47 | def sort_key(ex): 48 | # What this does is basically it takes a 16-bit binary representation of lengths and interleaves them. 
49 | # Example: lengths len(ex.src)=5 and len(ex.trg)=3 result in f(101, 011)=100111, 7 and 1 in f(111, 001)=101011 50 | # It's basically a heuristic that helps the BucketIterator sort bigger batches first 51 | return interleave_keys(len(ex.src), len(ex.trg)) 52 | 53 | def __init__(self, cache_path, fields, subsampling=False, num_samples = 1000, **kwargs): 54 | # save_cache interleaves src and trg examples so here we read the cache file having that format in mind 55 | cached_data = [line.split() for line in open(cache_path, encoding='utf-8')] 56 | 57 | cached_data_src = cached_data[0::2] # Even lines contain source examples 58 | cached_data_trg = cached_data[1::2] # Odd lines contain target examples 59 | 60 | if subsampling: 61 | print(f"Use only {num_samples} samples.") 62 | cached_data_src = cached_data_src[:num_samples] 63 | cached_data_trg = cached_data_trg[:num_samples] 64 | 65 | assert len(cached_data_src) == len(cached_data_trg), f'Source and target data should be of the same length.' 66 | 67 | examples = [] 68 | src_dataset_total_number_of_tokens = 0 69 | trg_dataset_total_number_of_tokens = 0 70 | for src_tokenized_data, trg_tokenized_data in zip(cached_data_src, cached_data_trg): 71 | ex = Example() 72 | 73 | setattr(ex, 'src', src_tokenized_data) 74 | setattr(ex, 'trg', trg_tokenized_data) 75 | 76 | examples.append(ex) 77 | 78 | # Update the number of tokens 79 | src_dataset_total_number_of_tokens += len(src_tokenized_data) 80 | trg_dataset_total_number_of_tokens += len(trg_tokenized_data) 81 | 82 | # Print relevant information about the dataset (parsing the cache file name) 83 | filename_parts = os.path.split(cache_path)[1].split('_') 84 | src_language, trg_language = ('English', 'German') if filename_parts[0] == 'en' else ('German', 'English') 85 | dataset_name = 'IWSLT' if filename_parts[2] == 'iwslt' else 'WMT-14' 86 | dataset_type = 'train' if filename_parts[3] == 'train' else 'val' 87 | print(f'{dataset_type} dataset ({dataset_name}) has {src_dataset_total_number_of_tokens} tokens in the source language ({src_language}) corpus.') 88 | print(f'{dataset_type} dataset ({dataset_name}) has {trg_dataset_total_number_of_tokens} tokens in the target language ({trg_language}) corpus.') 89 | 90 | # Call the parent class Dataset's constructor 91 | super().__init__(examples, fields, **kwargs) 92 | 93 | 94 | class DatasetWrapper(FastTranslationDataset): 95 | """ 96 | Just a wrapper around the FastTranslationDataset. 
97 | 98 | """ 99 | 100 | @classmethod 101 | def get_train_and_val_datasets(cls, train_cache_path, val_cache_path, fields, subsampling=False, num_samples=1000, **kwargs): 102 | 103 | train_dataset = cls(train_cache_path, fields, subsampling=subsampling, num_samples=num_samples, **kwargs) 104 | val_dataset = cls(val_cache_path, fields, **kwargs) 105 | 106 | return train_dataset, val_dataset 107 | 108 | 109 | def save_cache(cache_path, dataset): 110 | with open(cache_path, 'w', encoding='utf-8') as cache_file: 111 | # Interleave source and target tokenized examples, source is on even lines, target is on odd lines 112 | for ex in dataset.examples: 113 | cache_file.write(' '.join(ex.src) + '\n') 114 | cache_file.write(' '.join(ex.trg) + '\n') 115 | 116 | 117 | # 118 | # End of caching mechanism utilities 119 | # 120 | 121 | 122 | def get_datasets_and_vocabs(dataset_path, language_direction, use_iwslt=True, use_caching_mechanism=True, subsampling=False, num_samples=1000, ood=False): 123 | german_to_english = language_direction == LanguageDirection.G2E.name 124 | spacy_de = spacy.load('de_core_news_sm') 125 | spacy_en = spacy.load('en_core_web_sm') 126 | 127 | def tokenize_de(text): 128 | return [tok.text for tok in spacy_de.tokenizer(text)] 129 | 130 | def tokenize_en(text): 131 | return [tok.text for tok in spacy_en.tokenizer(text)] 132 | 133 | # batch first set to true as my transformer is expecting that format (that's consistent with the format 134 | # used in computer vision), namely (B, C, H, W) -> batch size, number of channels, height and width 135 | src_tokenizer = tokenize_de if german_to_english else tokenize_en 136 | trg_tokenizer = tokenize_en if german_to_english else tokenize_de 137 | src_field_processor = Field(tokenize=src_tokenizer, pad_token=PAD_TOKEN, batch_first=True) 138 | trg_field_processor = Field(tokenize=trg_tokenizer, init_token=BOS_TOKEN, eos_token=EOS_TOKEN, pad_token=PAD_TOKEN, batch_first=True) 139 | 140 | fields = [('src', src_field_processor), ('trg', trg_field_processor)] 141 | MAX_LEN = 100 # filter out examples that have more than MAX_LEN tokens 142 | filter_pred = lambda x: len(x.src) <= MAX_LEN and len(x.trg) <= MAX_LEN 143 | 144 | # Only call once the splits function it is super slow as it constantly has to redo the tokenization 145 | prefix = 'de_en' if german_to_english else 'en_de' 146 | prefix += '_iwslt' if use_iwslt else '_wmt14' 147 | train_cache_path = os.path.join(dataset_path, f'{prefix}_train_cache.csv') 148 | val_cache_path = os.path.join(dataset_path, f'{prefix}_val_cache.csv') 149 | test_cache_path = os.path.join(dataset_path, f'{prefix}_test_cache.csv') 150 | 151 | # This simple caching mechanism gave me ~30x speedup on my machine! From ~70s -> ~2.5s! 152 | ts = time.time() 153 | if not use_caching_mechanism or not (os.path.exists(train_cache_path) and os.path.exists(val_cache_path)): 154 | # dataset objects have a list of examples where example is simply an empty Python Object that has 155 | # .src and .trg attributes which contain a tokenized list of strings (created by tokenize_en and tokenize_de). 
156 | # It's that simple, we can consider our datasets as a table with 2 columns 'src' and 'trg' 157 | # each containing fields with tokenized strings from source and target languages 158 | src_ext = '.de' if german_to_english else '.en' 159 | trg_ext = '.en' if german_to_english else '.de' 160 | dataset_split_fn = datasets.IWSLT.splits if use_iwslt else datasets.WMT14.splits 161 | train_dataset, val_dataset, test_dataset = dataset_split_fn( 162 | exts=(src_ext, trg_ext), 163 | fields=fields, 164 | root=dataset_path, 165 | filter_pred=filter_pred 166 | ) 167 | 168 | save_cache(train_cache_path, train_dataset) 169 | save_cache(val_cache_path, val_dataset) 170 | save_cache(test_cache_path, test_dataset) 171 | else: 172 | # it's actually better to load from cache as we'll get rid of '\xa0', '\xa0 ' and '\x85' unicode characters 173 | # which we don't need and which SpaCy unfortunately includes as tokens. 174 | train_dataset, val_dataset = DatasetWrapper.get_train_and_val_datasets( 175 | train_cache_path, 176 | val_cache_path, 177 | fields, 178 | filter_pred=filter_pred, 179 | subsampling=subsampling, 180 | num_samples=num_samples 181 | ) 182 | 183 | print(f'Time it took to prepare the data: {time.time() - ts:3f} seconds.') 184 | 185 | MIN_FREQ = 2 186 | # __getattr__ implementation in the base Dataset class enables us to call .src on Dataset objects even though 187 | # we only have a list of examples in the Dataset object and the example itself had .src attribute. 188 | # Implementation will yield examples and call .src/.trg attributes on them (and those contain tokenized lists) 189 | if not ood: 190 | src_field_processor.build_vocab(train_dataset.src, min_freq=MIN_FREQ) 191 | trg_field_processor.build_vocab(train_dataset.trg, min_freq=MIN_FREQ) 192 | else: 193 | ### NEW CASE ADDED HERE ### 194 | # If loading an OOD dataset, build the vocabulary from the ID dataset. 195 | id_train_dataset, id_val_dataset, _, _ = get_datasets_and_vocabs( 196 | dataset_path=dataset_path, 197 | language_direction=language_direction, 198 | use_iwslt=not use_iwslt, 199 | use_caching_mechanism=use_caching_mechanism, 200 | subsampling=True, 201 | num_samples=num_samples, 202 | ood=False, 203 | ) 204 | src_field_processor.build_vocab(id_train_dataset.src, min_freq=MIN_FREQ) 205 | trg_field_processor.build_vocab(id_train_dataset.trg, min_freq=MIN_FREQ) 206 | 207 | return train_dataset, val_dataset, src_field_processor, trg_field_processor 208 | 209 | 210 | global longest_src_sentence, longest_trg_sentence 211 | 212 | 213 | def batch_size_fn(new_example, count, sofar): 214 | """ 215 | If we use this function in the BucketIterator the batch_size is no longer the number of examples/sentences 216 | in a batch but a number of tokens in a batch - which allows us to max out VRAM on a given GPU. 217 | 218 | Example: if we don't use this function and we set batch size to say 10 we will sometimes end up with 219 | a tensor of size (10, 100) because the longest sentence had a size of 100 tokens but other times we'll end 220 | up with a size of (10, 5) because the longest sentence had only 5 tokens! 221 | 222 | With this function what we do is we specify that source and target tensors can't go over a certain number 223 | of tokens like 1000. So usually either source or target tensors will contain around 1000 tokens and 224 | in worst case both will be really close to a 1000 tokens each. If that is still below max VRAM availabe on 225 | the system we're using the max potential of our GPU w.r.t. VRAM. 
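    Worked example (illustrative numbers): if the batch currently holds count=3 examples,
    the longest source sentence seen so far has 40 tokens and the longest target has 30
    (+2 for the start/end-of-sentence tokens), the function returns max(3*40, 3*32) = 120;
    roughly speaking, the BucketIterator keeps adding examples while this token estimate
    stays within batch_size.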
226 | 227 | Note: to understand this function you unfortunately would probably have to dig deeper into torch text's 228 | source code. 229 | 230 | """ 231 | global longest_src_sentence, longest_trg_sentence 232 | 233 | if count == 1: 234 | longest_src_sentence = 0 235 | longest_trg_sentence = 0 236 | 237 | longest_src_sentence = max(longest_src_sentence, len(new_example.src)) 238 | # 2 because of start/end of sentence tokens ( and ) 239 | longest_trg_sentence = max(longest_trg_sentence, len(new_example.trg) + 2) 240 | 241 | num_of_tokens_in_src_tensor = count * longest_src_sentence 242 | num_of_tokens_in_trg_tensor = count * longest_trg_sentence 243 | 244 | return max(num_of_tokens_in_src_tensor, num_of_tokens_in_trg_tensor) 245 | 246 | 247 | # https://github.com/pytorch/text/issues/536#issuecomment-719945594 <- there is a "bug" in BucketIterator i.e. it's 248 | # description is misleading as it won't group examples of similar length unless you set sort_within_batch to True! 249 | def get_data_loaders(dataset_path, language_direction, dataset_name, batch_size, device, subsampling=False, num_samples=1000, ood=False): 250 | train_dataset, val_dataset, src_field_processor, trg_field_processor = get_datasets_and_vocabs( 251 | dataset_path, 252 | language_direction, 253 | dataset_name == DatasetType.IWSLT.name, 254 | subsampling=subsampling, 255 | num_samples=num_samples, 256 | ood=ood, 257 | ) 258 | 259 | train_token_ids_loader, val_token_ids_loader = BucketIterator.splits( 260 | datasets=(train_dataset, val_dataset), 261 | batch_size=batch_size, 262 | device=device, 263 | sort_within_batch=True, # this part is really important otherwise we won't group similar length sentences 264 | batch_size_fn=batch_size_fn # this helps us max out GPU's VRAM 265 | ) 266 | 267 | return train_token_ids_loader, val_token_ids_loader, src_field_processor, trg_field_processor 268 | 269 | 270 | def get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id): 271 | batch_size = src_token_ids_batch.shape[0] 272 | 273 | # src_mask shape = (B, 1, 1, S) check out attention function in transformer_model.py where masks are applied 274 | # src_mask only masks pad tokens as we want to ignore their representations (no information in there...) 275 | src_mask = (src_token_ids_batch != pad_token_id).view(batch_size, 1, 1, -1) 276 | num_src_tokens = torch.sum(src_mask.long()) 277 | 278 | return src_mask, num_src_tokens 279 | 280 | 281 | def get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id): 282 | batch_size = trg_token_ids_batch.shape[0] 283 | device = trg_token_ids_batch.device 284 | 285 | # Same as src_mask but we additionally want to mask tokens from looking forward into the future tokens 286 | # Note: wherever the mask value is true we want to attend to that token, otherwise we mask (ignore) it. 
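    # For example, with sequence_length T=3 the no-look-forward mask built below is the
    # lower-triangular matrix
    #     [[1, 0, 0],
    #      [1, 1, 0],
    #      [1, 1, 1]]
    # so target position i can only attend to positions 0..i; ANDing it with the padding
    # mask then additionally hides pad tokens.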
287 | sequence_length = trg_token_ids_batch.shape[1] # trg_token_ids shape = (B, T) where T max trg token-sequence length 288 | trg_padding_mask = (trg_token_ids_batch != pad_token_id).view(batch_size, 1, 1, -1) # shape = (B, 1, 1, T) 289 | trg_no_look_forward_mask = torch.triu(torch.ones((1, 1, sequence_length, sequence_length), device=device) == 1).transpose(2, 3) 290 | 291 | # logic AND operation (both padding mask and no-look-forward must be true to attend to a certain target token) 292 | trg_mask = trg_padding_mask & trg_no_look_forward_mask # final shape = (B, 1, T, T) 293 | num_trg_tokens = torch.sum(trg_padding_mask.long()) 294 | 295 | return trg_mask, num_trg_tokens 296 | 297 | 298 | def get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch, pad_token_id, device): 299 | src_mask, num_src_tokens = get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id) 300 | trg_mask, num_trg_tokens = get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id) 301 | 302 | return src_mask, trg_mask, num_src_tokens, num_trg_tokens 303 | 304 | 305 | def get_src_and_trg_batches(token_ids_batch): 306 | src_token_ids_batch, trg_token_ids_batch = token_ids_batch.src, token_ids_batch.trg 307 | 308 | # Target input should be shifted by 1 compared to the target output tokens 309 | # Example: if we had a sentence like: [,what,is,up,] then to train the NMT model what we do is we pass 310 | # [,what,is,up] to the input as set [what,is,up,] as the expected output. 311 | trg_token_ids_batch_input = trg_token_ids_batch[:, :-1] 312 | 313 | # We reshape from (B, S) into (BxS, 1) as that's the the shape expected by LabelSmoothing which will produce 314 | # the shape (BxS, V) where V is the target vocab size which is the same shape as the one that comes out 315 | # from the transformer so we can directly pass them into the KL divergence loss 316 | trg_token_ids_batch_gt = trg_token_ids_batch[:, 1:].reshape(-1, 1) 317 | 318 | return src_token_ids_batch, trg_token_ids_batch_input, trg_token_ids_batch_gt 319 | 320 | 321 | # 322 | # Everything below is for testing purposes only - feel free to ignore 323 | # 324 | 325 | 326 | def sample_text_from_loader(src_field_processor, trg_field_processor, token_ids_loader, num_samples=2, sample_src=True, sample_trg=True, show_padded=False): 327 | assert sample_src or sample_trg, f'Either src or trg or both must be enabled.' 
328 | 329 | for b_idx, token_ids_batch in enumerate(token_ids_loader): 330 | if b_idx == num_samples: # Number of sentence samples to print 331 | break 332 | 333 | print('*' * 5) 334 | if sample_src: 335 | print("Source text:", end="\t") 336 | for token_id in token_ids_batch.src[0]: # print only the first example from the batch 337 | src_token = src_field_processor.vocab.itos[token_id] 338 | 339 | if src_token == PAD_TOKEN and not show_padded: 340 | continue 341 | 342 | print(src_token, end=" ") 343 | print() 344 | 345 | if sample_trg: 346 | print("Target text:", end="\t") 347 | for token_id in token_ids_batch.trg[0]: 348 | trg_token = trg_field_processor.vocab.itos[token_id] 349 | 350 | if trg_token == PAD_TOKEN and not show_padded: 351 | continue 352 | 353 | print(trg_token, end=" ") 354 | print() 355 | 356 | 357 | if __name__ == "__main__": 358 | # To run this delete the dot from from .constants import - not the most elegant solution but it works 359 | # without me having to add sys.path stuff, if you have a more elegant solution please open an issue <3 360 | batch_size = 8 361 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 362 | dataset_name = DatasetType.IWSLT.name 363 | language_direction = LanguageDirection.G2E.name 364 | train_token_ids_loader, val_token_ids_loader, src_field_processor, trg_field_processor = get_data_loaders(DATA_DIR_PATH, language_direction, dataset_name, batch_size, device) 365 | 366 | # Verify that the mask logic is correct 367 | pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN] 368 | for batch in train_token_ids_loader: 369 | # Visually inspect that masks make sense 370 | src_padding_mask, trg_mask, num_src_tokens, num_trg_tokens = get_masks_and_count_tokens(batch.src, batch.trg, pad_token_id, device) 371 | break 372 | 373 | # Check vocab size 374 | print(f'Source vocabulary size={len(src_field_processor.vocab)}') 375 | print(f'Target vocabulary size={len(trg_field_processor.vocab)}') 376 | 377 | # Show text from token loader 378 | sample_text_from_loader(src_field_processor, trg_field_processor, train_token_ids_loader) 379 | 380 | -------------------------------------------------------------------------------- /utils/decoding_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from "Pytorch Original Transformer" by Aleksa Gordić 3 | https://github.com/gordicaleksa/pytorch-original-transformer 4 | """ 5 | 6 | import enum 7 | 8 | 9 | import torch 10 | import numpy as np 11 | 12 | 13 | from .constants import * 14 | from utils.data_utils import get_masks_and_count_tokens_trg 15 | 16 | 17 | class DecodingMethod(enum.Enum): 18 | GREEDY = 0, 19 | BEAM = 1 20 | 21 | 22 | def greedy_decoding(baseline_transformer, src_representations_batch, src_mask, trg_field_processor, max_target_tokens=100): 23 | """ 24 | Supports batch (decode multiple source sentences) greedy decoding. 25 | 26 | Decoding could be further optimized to cache old token activations because they can't look ahead and so 27 | adding a newly predicted token won't change old token's activations. 28 | 29 | Example: we input and do a forward pass. We get intermediate activations for and at the output at position 30 | 0, after the doing linear layer we get e.g. token . Now we input , but 's activations will remain 31 | the same. Similarly say we now got at output position 1, in the next step we input ,, and so 's 32 | activations will remain the same as it only looks at/attends to itself and to and so forth. 
33 | 34 | """ 35 | 36 | device = next(baseline_transformer.parameters()).device 37 | pad_token_id = trg_field_processor.vocab.stoi[PAD_TOKEN] 38 | 39 | # Initial prompt is the beginning/start of the sentence token. Make it compatible shape with source batch => (B,1) 40 | target_sentences_tokens = [[BOS_TOKEN] for _ in range(src_representations_batch.shape[0])] 41 | trg_token_ids_batch = torch.tensor([[trg_field_processor.vocab.stoi[tokens[0]]] for tokens in target_sentences_tokens], device=device) 42 | 43 | # Set to true for a particular target sentence once it reaches the EOS (end-of-sentence) token 44 | is_decoded = [False] * src_representations_batch.shape[0] 45 | 46 | while True: 47 | trg_mask, _ = get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id) 48 | # Shape = (B*T, V) where T is the current token-sequence length and V target vocab size 49 | predicted_log_distributions = baseline_transformer.decode(trg_token_ids_batch, src_representations_batch, trg_mask, src_mask) 50 | 51 | # Extract only the indices of last token for every target sentence (we take every T-th token) 52 | num_of_trg_tokens = len(target_sentences_tokens[0]) 53 | predicted_log_distributions = predicted_log_distributions[num_of_trg_tokens-1::num_of_trg_tokens] 54 | 55 | # This is the "greedy" part of the greedy decoding: 56 | # We find indices of the highest probability target tokens and discard every other possibility 57 | most_probable_last_token_indices = torch.argmax(predicted_log_distributions, dim=-1).cpu().numpy() 58 | 59 | # Find target tokens associated with these indices 60 | predicted_words = [trg_field_processor.vocab.itos[index] for index in most_probable_last_token_indices] 61 | 62 | for idx, predicted_word in enumerate(predicted_words): 63 | target_sentences_tokens[idx].append(predicted_word) 64 | 65 | if predicted_word == EOS_TOKEN: # once we find EOS token for a particular sentence we flag it 66 | is_decoded[idx] = True 67 | 68 | if all(is_decoded) or num_of_trg_tokens == max_target_tokens: 69 | break 70 | 71 | # Prepare the input for the next iteration (merge old token ids with the new column of most probable token ids) 72 | trg_token_ids_batch = torch.cat((trg_token_ids_batch, torch.unsqueeze(torch.tensor(most_probable_last_token_indices, device=device), 1)), 1) 73 | 74 | # Post process the sentences - remove everything after the EOS token 75 | target_sentences_tokens_post = [] 76 | for target_sentence_tokens in target_sentences_tokens: 77 | try: 78 | target_index = target_sentence_tokens.index(EOS_TOKEN) + 1 79 | except: 80 | target_index = None 81 | 82 | target_sentence_tokens = target_sentence_tokens[:target_index] 83 | target_sentences_tokens_post.append(target_sentence_tokens) 84 | 85 | return target_sentences_tokens_post 86 | 87 | 88 | def get_beam_decoder(translation_config): 89 | """ 90 | Note: this implementation could probably be further optimized I just wanted a decent working version. 91 | 92 | Notes: 93 | 94 | https://arxiv.org/pdf/1609.08144.pdf introduces various heuristics into the beam search algorithm like coverage 95 | penalty, etc. Here I only designed a simple beam search algorithm with length penalty. As the probability of the 96 | sequence is constructed by multiplying the conditional probabilities (which are numbers smaller than 1) the beam 97 | search algorithm will prefer shorter sentences which we compensate for using the length penalty. 
98 | 99 | """ 100 | beam_size = translation_config['beam_size'] 101 | length_penalty_coefficient = translation_config['length_penalty_coefficient'] 102 | 103 | def beam_decoding(baseline_transformer, src_representations_batch, src_mask, trg_field_processor, max_target_tokens=100): 104 | raise Exception('Not yet implemented.') 105 | device = next(baseline_transformer.parameters()).device 106 | pad_token_id = trg_field_processor.vocab.stoi[PAD_TOKEN] 107 | 108 | # Initial prompt is the beginning/start of the sentence token. Make it compatible shape with source batch => (B,1) 109 | batch_size, S, model_dimension = src_representations_batch.shape 110 | target_multiple_hypotheses_tokens = [[BOS_TOKEN] for _ in range(batch_size)] 111 | trg_token_ids_batch = torch.tensor([[trg_field_processor.vocab.stoi[tokens[0]]] for tokens in target_multiple_hypotheses_tokens], device=device) 112 | 113 | # Repeat so that source sentence representations are repeated contiguously, say we have [s1, s2] we want 114 | # [s1, s1, s2, s2] and not [s1, s2, s1, s2] where s1 is single sentence representation with shape=(S, D) 115 | # where S - max source token-sequence length, D - model dimension 116 | src_representations_batch = src_representations_batch.repeat(1, beam_size, 1).view(beam_size*batch_size, -1, model_dimension) 117 | trg_token_ids_batch = trg_token_ids_batch.repeat(beam_size, 1) 118 | 119 | hypotheses_log_probs = torch.zeros((batch_size * beam_size, 1), device=device) 120 | had_eos = [[False] for _ in range(hypotheses_log_probs.shape[0])] 121 | 122 | while True: 123 | trg_mask, _ = get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id) 124 | # Shape = (B*BS*T, V) T - current token-sequence length, V - target vocab size, BS - beam size, B - batch 125 | predicted_log_distributions = baseline_transformer.decode(trg_token_ids_batch, src_representations_batch, trg_mask, src_mask) 126 | 127 | # Extract only the indices of last token for every target sentence (we take every T-th token) 128 | # Shape = (B*BS, V) 129 | num_of_trg_tokens = trg_token_ids_batch.shape[-1] 130 | predicted_log_distributions = predicted_log_distributions[num_of_trg_tokens - 1::num_of_trg_tokens] 131 | 132 | # This time extract beam_size number of highest probability tokens (compare to greedy's arg max) 133 | # Shape = (B*BS, BS) 134 | latest_token_log_probs, most_probable_token_indices = torch.topk(predicted_log_distributions, beam_size, dim=-1, sorted=True) 135 | 136 | # Don't update the hypothesis which had EOS already (pruning) 137 | latest_token_log_probs.masked_fill(torch.tensor(had_eos == True), float("-inf")) 138 | 139 | # Calculate probabilities for every beam hypothesis (since we have log prob we add instead of multiply) 140 | # Shape = (B*BS, BS) 141 | hypotheses_pool_log_probs = hypotheses_log_probs + latest_token_log_probs 142 | # Shape = (B, BS, BS) 143 | most_probable_token_indices = most_probable_token_indices.view(batch_size, beam_size, beam_size) 144 | hypotheses_pool_log_probs = hypotheses_pool_log_probs.view(batch_size, beam_size, beam_size) 145 | # Shape = (B, BS*BS) 146 | hypotheses_pool_log_probs = torch.flatten(hypotheses_pool_log_probs, start_dim=-1) 147 | 148 | # Figure out indices of beam_size most probably hypothesis for every target sentence in the batch 149 | # Shape = (B, BS) 150 | new_hypothesis_log_probs, next_hypothesis_indices = torch.topk(hypotheses_pool_log_probs, beam_size, dim=-1, sorted=True) 151 | 152 | # Create new target ids batch 153 | hypotheses_log_probs_tmp = 
torch.empty((batch_size * beam_size, 1)) 154 | 155 | T = trg_token_ids_batch.shape[-1] 156 | new_trg_token_ids_batch = torch.empty((batch_size * beam_size, T + 1)) 157 | 158 | next_hypothesis_indices = next_hypothesis_indices.cpu().numpy() 159 | # Prepare new hypotheses for the next iteration 160 | for b_idx, indices in enumerate(next_hypothesis_indices): 161 | for h_idx, token_index in indices: 162 | row, column = token_index / beam_size, token_index % beam_size 163 | hypothesis_index = b_idx * beam_size + h_idx 164 | 165 | new_token_id = most_probable_token_indices[b_idx, row, column] 166 | if had_eos[hypothesis_index]: 167 | new_trg_token_ids_batch[hypothesis_index, :-1] = trg_token_ids_batch[hypothesis_index, :] 168 | else: 169 | new_trg_token_ids_batch[hypothesis_index, :-1] = trg_token_ids_batch[b_idx * beam_size + row, :] 170 | new_trg_token_ids_batch[hypothesis_index, -1] = new_token_id 171 | 172 | if had_eos[hypothesis_index]: 173 | hypotheses_log_probs_tmp[hypothesis_index] = hypotheses_log_probs[hypothesis_index] 174 | else: 175 | hypotheses_log_probs_tmp[hypothesis_index] = new_hypothesis_log_probs[hypothesis_index] 176 | 177 | if new_token_id == trg_field_processor.vocab.stoi[EOS_TOKEN]: 178 | had_eos[hypothesis_index] = True 179 | 180 | # Update the current hypothesis probabilities 181 | hypotheses_log_probs = hypotheses_log_probs_tmp 182 | trg_token_ids_batch = new_trg_token_ids_batch 183 | 184 | if all(had_eos) or num_of_trg_tokens == max_target_tokens: 185 | break 186 | 187 | # 188 | # Selection and post-processing 189 | # 190 | 191 | target_multiple_hypotheses_tokens = [] 192 | trg_token_ids_batch_numpy = trg_token_ids_batch.cpu().numpy() 193 | for hypothesis_ids in trg_token_ids_batch_numpy: 194 | target_multiple_hypotheses_tokens.append([trg_field_processor.vocab.itos[token_id] for token_id in hypothesis_ids]) 195 | 196 | # Step 1: Select the most probable hypothesis out of beam_size hypotheses for each target sentence 197 | hypotheses_log_probs = hypotheses_log_probs.view(batch_size, beam_size) 198 | most_probable_hypotheses_indices = torch.argmax(hypotheses_log_probs, dim=-1).cpu().numpy() 199 | target_sentences_tokens = [] 200 | for b_idx, index in enumerate(most_probable_hypotheses_indices): 201 | target_sentences_tokens.append(target_multiple_hypotheses_tokens[b_idx * beam_size + index]) 202 | 203 | # Step 2: Post process the sentences - remove everything after the EOS token 204 | target_sentences_tokens_post = [] 205 | for target_sentence_tokens in target_sentences_tokens: 206 | try: 207 | target_index = target_sentence_tokens.index(EOS_TOKEN) + 1 208 | except: 209 | target_index = None 210 | 211 | target_sentence_tokens = target_sentence_tokens[:target_index] 212 | target_sentences_tokens_post.append(target_sentence_tokens) 213 | 214 | return target_sentences_tokens_post 215 | 216 | return beam_decoding 217 | 218 | -------------------------------------------------------------------------------- /utils/optimizers_and_distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CustomLRAdamOptimizer: 6 | """ 7 | Linear ramp learning rate for the warm-up number of steps and then start decaying 8 | according to the inverse square root law of the current training step number. 9 | 10 | Check out playground.py for visualization of the learning rate (visualize_custom_lr_adam). 
11 | """ 12 | 13 | def __init__(self, optimizer, model_dimension, num_of_warmup_steps, lr_inverse_dim=False, constant_lr=False, lr_factor=1.0): 14 | self.optimizer = optimizer 15 | 16 | if lr_inverse_dim: 17 | self.model_size = model_dimension 18 | else: 19 | # We changed this part to make sure that the temperature does not change with model size 20 | # Otherwise, we have a coupled effect of the temperature and model size 21 | self.model_size = 256 22 | 23 | self.num_of_warmup_steps = num_of_warmup_steps 24 | self.constant_lr = constant_lr 25 | print(f"Do we use constant learning rate? {constant_lr}") 26 | self.current_step_number = 0 27 | self.lr_factor = lr_factor 28 | 29 | def step(self): 30 | self.current_step_number += 1 31 | current_learning_rate = self.get_current_learning_rate() 32 | 33 | for p in self.optimizer.param_groups: 34 | p['lr'] = current_learning_rate 35 | 36 | self.optimizer.step() # apply gradients 37 | 38 | # Check out the formula at Page 7, Chapter 5.3 "Optimizer" and playground.py for visualization 39 | def get_current_learning_rate(self): 40 | # For readability purpose 41 | step = self.current_step_number 42 | warmup = self.num_of_warmup_steps 43 | 44 | if not self.constant_lr: 45 | # This is the default setting used in fairseq 46 | return self.lr_factor * self.model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5)) 47 | else: 48 | # This is the setting when we have constant learning rate (or constant temperature) 49 | return self.lr_factor * self.model_size ** (-0.5) * min(warmup ** (-0.5), step * warmup ** (-1.5)) 50 | 51 | def zero_grad(self): 52 | self.optimizer.zero_grad() 53 | 54 | 55 | class LabelSmoothingDistribution(nn.Module): 56 | """ 57 | Instead of one-hot target distribution set the target word's probability to "confidence_value" (usually 0.9) 58 | and distribute the rest of the "smoothing_value" mass (usually 0.1) over the rest of the vocab. 59 | 60 | Check out playground.py for visualization of how the smooth target distribution looks like compared to one-hot. 61 | """ 62 | 63 | def __init__(self, smoothing_value, pad_token_id, trg_vocab_size, device): 64 | assert 0.0 <= smoothing_value <= 1.0 65 | 66 | super(LabelSmoothingDistribution, self).__init__() 67 | 68 | self.confidence_value = 1.0 - smoothing_value 69 | self.smoothing_value = smoothing_value 70 | 71 | self.pad_token_id = pad_token_id 72 | self.trg_vocab_size = trg_vocab_size 73 | self.device = device 74 | 75 | def forward(self, trg_token_ids_batch): 76 | 77 | batch_size = trg_token_ids_batch.shape[0] 78 | smooth_target_distributions = torch.zeros((batch_size, self.trg_vocab_size), device=self.device) 79 | 80 | # -2 because we are not distributing the smoothing mass over the pad token index and over the ground truth index 81 | # those 2 values will be overwritten by the following 2 lines with confidence_value and 0 (for pad token index) 82 | smooth_target_distributions.fill_(self.smoothing_value / (self.trg_vocab_size - 2)) 83 | 84 | smooth_target_distributions.scatter_(1, trg_token_ids_batch, self.confidence_value) 85 | smooth_target_distributions[:, self.pad_token_id] = 0. 86 | 87 | # If we had a pad token as a target we set the distribution to all 0s instead of smooth labeled distribution 88 | smooth_target_distributions.masked_fill_(trg_token_ids_batch == self.pad_token_id, 0.) 
89 | 90 | return smooth_target_distributions 91 | 92 | 93 | class OneHotDistribution(nn.Module): 94 | """ 95 | Create a one hot distribution (feel free to ignore used only in playground.py) 96 | """ 97 | 98 | def __init__(self, pad_token_id, trg_vocab_size): 99 | 100 | super(OneHotDistribution, self).__init__() 101 | 102 | self.pad_token_id = pad_token_id 103 | self.trg_vocab_size = trg_vocab_size 104 | 105 | def forward(self, trg_token_ids_batch): 106 | 107 | batch_size = trg_token_ids_batch.shape[0] 108 | one_hot_distribution = torch.zeros((batch_size, self.trg_vocab_size)) 109 | one_hot_distribution.scatter_(1, trg_token_ids_batch, 1.) 110 | 111 | # If we had a pad token as a target we set the distribution to all 0s instead of one-hot distribution 112 | one_hot_distribution.masked_fill_(trg_token_ids_batch == self.pad_token_id, 0.) 113 | 114 | return one_hot_distribution 115 | -------------------------------------------------------------------------------- /utils/resource_downloader.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import os 3 | 4 | 5 | from torch.hub import download_url_to_file 6 | 7 | 8 | from .constants import BINARIES_PATH 9 | 10 | 11 | IWSLT_ENGLISH_TO_GERMAN_MODEL_URL = r'https://www.dropbox.com/s/a6pfo6t9m2dh1jq/iwslt_e2g.pth?dl=1' 12 | IWSLT_GERMAN_TO_ENGLISH_MODEL_URL = r'https://www.dropbox.com/s/dgcd4xhwig7ygqd/iwslt_g2e.pth?dl=1' 13 | 14 | 15 | # Not yet trained 16 | WMT14_ENGLISH_TO_GERMAN_MODEL_URL = None 17 | WMT14_GERMAN_TO_ENGLISH_MODEL_URL = None 18 | 19 | 20 | DOWNLOAD_DICT = { 21 | 'iwslt_e2g': IWSLT_ENGLISH_TO_GERMAN_MODEL_URL, 22 | 'iwslt_g2e': IWSLT_GERMAN_TO_ENGLISH_MODEL_URL, 23 | 'wmt14_e2g': WMT14_ENGLISH_TO_GERMAN_MODEL_URL, 24 | 'wmt14_g2e': WMT14_GERMAN_TO_ENGLISH_MODEL_URL 25 | } 26 | 27 | 28 | download_choices = list(DOWNLOAD_DICT.keys()) 29 | 30 | 31 | def download_models(translation_config): 32 | # Step 1: Form the key 33 | language_direction = translation_config['language_direction'].lower() 34 | dataset_name = translation_config['dataset_name'].lower() 35 | key = f'{dataset_name}_{language_direction}' 36 | 37 | # Step 2: Check whether this model already exists 38 | model_name = f'{key}.pth' 39 | model_path = os.path.join(BINARIES_PATH, model_name) 40 | if os.path.exists(model_path): 41 | print(f'No need to download, found model {model_path} that was trained on {dataset_name} for language direction {language_direction}.') 42 | return model_path 43 | 44 | # Step 3: Download the resource to local filesystem 45 | remote_resource_path = DOWNLOAD_DICT[key] 46 | if remote_resource_path is None: # handle models which I've not provided URLs for yet 47 | print(f'No model found that was trained on {dataset_name} for language direction {language_direction}.') 48 | exit(0) 49 | 50 | print(f'Downloading from {remote_resource_path}. 
This may take a while.') 51 | download_url_to_file(remote_resource_path, model_path) 52 | 53 | return model_path 54 | 55 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from "Pytorch Original Transformer" by Aleksa Gordić 3 | https://github.com/gordicaleksa/pytorch-original-transformer 4 | """ 5 | 6 | import re 7 | import os 8 | import time 9 | 10 | 11 | #import git 12 | import torch 13 | from nltk.translate.bleu_score import corpus_bleu 14 | 15 | 16 | from .constants import BINARIES_PATH, PAD_TOKEN 17 | from .decoding_utils import greedy_decoding 18 | from .data_utils import get_masks_and_count_tokens_src 19 | 20 | 21 | def get_available_binary_name(): 22 | prefix = 'transformer' 23 | 24 | def valid_binary_name(binary_name): 25 | # First time you see raw f-string? Don't worry the only trick is to double the brackets. 26 | pattern = re.compile(rf'{prefix}_[0-9]{{6}}\.pth') 27 | return re.fullmatch(pattern, binary_name) is not None 28 | 29 | # Just list the existing binaries so that we don't overwrite them but write to a new one 30 | valid_binary_names = list(filter(valid_binary_name, os.listdir(BINARIES_PATH))) 31 | if len(valid_binary_names) > 0: 32 | last_binary_name = sorted(valid_binary_names)[-1] 33 | new_suffix = int(last_binary_name.split('.')[0][-6:]) + 1 # increment by 1 34 | return f'{prefix}_{str(new_suffix).zfill(6)}.pth' 35 | else: 36 | return f'{prefix}_000000.pth' 37 | 38 | 39 | def get_training_state(training_config, model): 40 | training_state = { 41 | # "commit_hash": git.Repo(search_parent_directories=True).head.object.hexsha, 42 | "dataset_name": training_config['dataset_name'], 43 | "language_direction": training_config['language_direction'], 44 | 45 | "num_of_epochs": training_config['num_of_epochs'], 46 | "batch_size": training_config['batch_size'], 47 | 48 | "state_dict": model.state_dict() 49 | } 50 | 51 | return training_state 52 | 53 | 54 | def print_model_metadata(training_state): 55 | header = f'\n{"*"*5} Model training metadata: {"*"*5}' 56 | print(header) 57 | 58 | for key, value in training_state.items(): 59 | if key != 'state_dict': # don't print state_dict it's a bunch of numbers... 
60 | if key == 'language_direction': # convert into human readable format 61 | value = 'English to German' if value == 'E2G' else 'German to English' 62 | print(f'{key}: {value}') 63 | print(f'{"*" * len(header)}\n') 64 | 65 | 66 | # Calculate the BLEU-4 score 67 | def calculate_bleu_score(transformer, token_ids_loader, trg_field_processor, max_batch=-1): 68 | with torch.no_grad(): 69 | pad_token_id = trg_field_processor.vocab.stoi[PAD_TOKEN] 70 | 71 | gt_sentences_corpus = [] 72 | predicted_sentences_corpus = [] 73 | 74 | ts = time.time() 75 | for batch_idx, token_ids_batch in enumerate(token_ids_loader): 76 | src_token_ids_batch, trg_token_ids_batch = token_ids_batch.src, token_ids_batch.trg 77 | if batch_idx % 10 == 0: 78 | print(f'batch={batch_idx}, time elapsed = {time.time()-ts} seconds.') 79 | 80 | # Optimization - compute the source token representations only once 81 | src_mask, _ = get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id) 82 | src_representations_batch = transformer.encode(src_token_ids_batch, src_mask) 83 | 84 | predicted_sentences = greedy_decoding(transformer, src_representations_batch, src_mask, trg_field_processor) 85 | predicted_sentences_corpus.extend(predicted_sentences) # add them to the corpus of translations 86 | 87 | # Get the token and not id version of GT (ground-truth) sentences 88 | trg_token_ids_batch = trg_token_ids_batch.cpu().numpy() 89 | for target_sentence_ids in trg_token_ids_batch: 90 | target_sentence_tokens = [trg_field_processor.vocab.itos[id] for id in target_sentence_ids if id != pad_token_id] 91 | gt_sentences_corpus.append([target_sentence_tokens]) # add them to the corpus of GT translations 92 | 93 | if max_batch>0 and batch_idx>max_batch: 94 | break 95 | 96 | bleu_score = corpus_bleu(gt_sentences_corpus, predicted_sentences_corpus) 97 | print(f'BLEU-4 corpus score = {bleu_score}, corpus length = {len(gt_sentences_corpus)}, time elapsed = {time.time()-ts} seconds.') 98 | return bleu_score 99 | -------------------------------------------------------------------------------- /utils/utils_CKA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def gram_linear(x): 5 | """Compute Gram (kernel) matrix for a linear kernel. 6 | Args: 7 | x: A num_examples x num_features matrix of features. 8 | Returns: 9 | A num_examples x num_examples Gram matrix of examples. 10 | """ 11 | return x.dot(x.T) 12 | 13 | 14 | def gram_rbf(x, threshold=1.0): 15 | """Compute Gram (kernel) matrix for an RBF kernel. 16 | Args: 17 | x: A num_examples x num_features matrix of features. 18 | threshold: Fraction of median Euclidean distance to use as RBF kernel 19 | bandwidth. (This is the heuristic we use in the paper. There are other 20 | possible ways to set the bandwidth; we didn't try them.) 21 | Returns: 22 | A num_examples x num_examples Gram matrix of examples. 23 | """ 24 | dot_products = x.dot(x.T) 25 | sq_norms = np.diag(dot_products) 26 | sq_distances = -2 * dot_products + sq_norms[:, None] + sq_norms[None, :] 27 | sq_median_distance = np.median(sq_distances) 28 | return np.exp(-sq_distances / (2 * threshold ** 2 * sq_median_distance)) 29 | 30 | 31 | def center_gram(gram, unbiased=False): 32 | """Center a symmetric Gram matrix. 33 | This is equvialent to centering the (possibly infinite-dimensional) features 34 | induced by the kernel before computing the Gram matrix. 35 | Args: 36 | gram: A num_examples x num_examples symmetric matrix. 
37 | unbiased: Whether to adjust the Gram matrix in order to compute an unbiased 38 | estimate of HSIC. Note that this estimator may be negative. 39 | Returns: 40 | A symmetric matrix with centered columns and rows. 41 | """ 42 | if not np.allclose(gram, gram.T): 43 | raise ValueError('Input must be a symmetric matrix.') 44 | gram = gram.copy() 45 | 46 | if unbiased: 47 | # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M. 48 | # L. (2014). Partial distance correlation with methods for dissimilarities. 49 | # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically 50 | # stable than the alternative from Song et al. (2007). 51 | n = gram.shape[0] 52 | np.fill_diagonal(gram, 0) 53 | means = np.sum(gram, 0, dtype=np.float64) / (n - 2) 54 | means -= np.sum(means) / (2 * (n - 1)) 55 | gram -= means[:, None] 56 | gram -= means[None, :] 57 | np.fill_diagonal(gram, 0) 58 | else: 59 | means = np.mean(gram, 0, dtype=np.float64) 60 | means -= np.mean(means) / 2 61 | gram -= means[:, None] 62 | gram -= means[None, :] 63 | 64 | return gram 65 | 66 | 67 | def cka_compute(gram_x, gram_y, debiased=False): 68 | """Compute CKA. 69 | Args: 70 | gram_x: A num_examples x num_examples Gram matrix. 71 | gram_y: A num_examples x num_examples Gram matrix. 72 | debiased: Use unbiased estimator of HSIC. CKA may still be biased. 73 | Returns: 74 | The value of CKA between X and Y. 75 | """ 76 | gram_x = center_gram(gram_x, unbiased=debiased) 77 | gram_y = center_gram(gram_y, unbiased=debiased) 78 | 79 | # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased variant) or 80 | # n*(n-3) (unbiased variant), but this cancels for CKA. 81 | scaled_hsic = gram_x.ravel().dot(gram_y.ravel()) 82 | 83 | normalization_x = np.linalg.norm(gram_x) 84 | normalization_y = np.linalg.norm(gram_y) 85 | return scaled_hsic / (normalization_x * normalization_y) 86 | 87 | 88 | def _debiased_dot_product_similarity_helper( 89 | xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y, 90 | n): 91 | """Helper for computing debiased dot product similarity (i.e. linear HSIC).""" 92 | # This formula can be derived by manipulating the unbiased estimator from 93 | # Song et al. (2007). 94 | return ( 95 | xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y) 96 | + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2))) 97 | 98 | 99 | def feature_space_linear_cka(features_x, features_y, debiased=False): 100 | """Compute CKA with a linear kernel, in feature space. 101 | This is typically faster than computing the Gram matrix when there are fewer 102 | features than examples. 103 | Args: 104 | features_x: A num_examples x num_features matrix of features. 105 | features_y: A num_examples x num_features matrix of features. 106 | debiased: Use unbiased estimator of dot product similarity. CKA may still be 107 | biased. Note that this estimator may be negative. 108 | Returns: 109 | The value of CKA between X and Y. 110 | """ 111 | 112 | features_x = features_x - np.mean(features_x, 0, keepdims=True) 113 | features_y = features_y - np.mean(features_y, 0, keepdims=True) 114 | 115 | dot_product_similarity = np.linalg.norm(features_x.T.dot(features_y)) ** 2 116 | normalization_x = np.linalg.norm(features_x.T.dot(features_x)) 117 | normalization_y = np.linalg.norm(features_y.T.dot(features_y)) 118 | 119 | if debiased: 120 | n = features_x.shape[0] 121 | # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array. 
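        # For instance (illustrative check): for x = np.array([[1., 2.], [3., 4.]]),
        # np.einsum('ij,ij->i', x, x) gives [5., 25.], the same as np.sum(x ** 2, 1).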
122 | sum_squared_rows_x = np.einsum('ij,ij->i', features_x, features_x) 123 | sum_squared_rows_y = np.einsum('ij,ij->i', features_y, features_y) 124 | squared_norm_x = np.sum(sum_squared_rows_x) 125 | squared_norm_y = np.sum(sum_squared_rows_y) 126 | 127 | dot_product_similarity = _debiased_dot_product_similarity_helper( 128 | dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y, 129 | squared_norm_x, squared_norm_y, n) 130 | 131 | """ 132 | dx = _debiased_dot_product_similarity_helper( 133 | normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x, 134 | squared_norm_x, squared_norm_x, n) 135 | dy = _debiased_dot_product_similarity_helper( 136 | normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y, 137 | squared_norm_y, squared_norm_y, n) 138 | 139 | print("One estimate is {0}".format(dx)) 140 | 141 | if dx<0: 142 | print(dx) 143 | 1/0 144 | if dy<0: 145 | print(dy) 146 | 1/0 147 | """ 148 | 149 | normalization_x = np.sqrt(_debiased_dot_product_similarity_helper( 150 | normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x, 151 | squared_norm_x, squared_norm_x, n)) 152 | normalization_y = np.sqrt(_debiased_dot_product_similarity_helper( 153 | normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y, 154 | squared_norm_y, squared_norm_y, n)) 155 | 156 | return dot_product_similarity / (normalization_x * normalization_y) -------------------------------------------------------------------------------- /utils/utils_NMT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | 5 | def get_ckpt_folder(experiment_type, dataset, num_samples, width, training_type, directory_depth=0, 6 | lr_factor=None, folder_suffix='', depth=None): 7 | 8 | """ 9 | This function gets the checkpoint folder based on the type of experiment 10 | """ 11 | 12 | data_folder = dataset 13 | if experiment_type=="sample": 14 | data_folder += f'_sample_{num_samples}_new{folder_suffix}' 15 | width_folder = f'w{width}' 16 | training_folder = training_type 17 | if experiment_type=="lr": 18 | training_folder += f'_lr_factor_{lr_factor}{folder_suffix}' 19 | 20 | if experiment_type=="depth": 21 | width_folder += f'_depth_{depth}{folder_suffix}' 22 | elif experiment_type=="width": 23 | width_folder += folder_suffix 24 | 25 | if directory_depth==0: 26 | directory_suffix = '.' 
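    # The prefix chosen here only shifts where '../checkpoint/...' is resolved from: depth 0 yields
    # './../checkpoint/NMT_epochs/...', depth 1 yields '../../checkpoint/NMT_epochs/...'
    # (presumably matching how many directory levels below the repo root the calling script runs from).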
27 | if directory_depth==1: 28 | directory_suffix = '../' 29 | 30 | ckpt_folder = os.path.join(directory_suffix, f'../checkpoint/NMT_epochs/{data_folder}/{width_folder}/{training_folder}') 31 | 32 | return ckpt_folder 33 | 34 | 35 | def get_epochs(args, num_samples=0): 36 | 37 | """ 38 | This function gets the epochs based on the number of samples 39 | """ 40 | if args.experiment_type == 'sample': 41 | 42 | epochs = range(21) 43 | 44 | elif args.experiment_type in ['width', 'lr', 'depth']: 45 | 46 | epochs = range(20) 47 | 48 | else: 49 | 50 | raise ValueError("Not trained yet.") 51 | 52 | return epochs 53 | 54 | 55 | def mk_metrics_folder(args, ckpt_folder): 56 | 57 | """ 58 | This function makes a folder to store the ww metrics 59 | """ 60 | 61 | print(ckpt_folder) 62 | assert os.path.exists(ckpt_folder) 63 | metric_folder = os.path.join(ckpt_folder, args.metric_folder) 64 | if not os.path.exists(metric_folder): 65 | os.mkdir(metric_folder) 66 | return 67 | 68 | 69 | def get_experiment_folders_and_epochs(args): 70 | 71 | """ 72 | This function gets checkpoint folders and epochs 73 | """ 74 | 75 | ckpt_folders = [] 76 | ckpt_epochs = [] 77 | widths = [] 78 | samples = [] 79 | lr_factors = [] 80 | depths = [] 81 | 82 | if args.experiment_type == 'width': 83 | #width_list = [128, 256, 512, 768, 1024, 1536] 84 | width_list = [128, 192, 256, 384, 512] 85 | 86 | for width in width_list: 87 | ckpt_folder = get_ckpt_folder(args.experiment_type, args.dataset, 0, width, args.training_type, 88 | directory_depth=args.directory_depth, 89 | folder_suffix=args.folder_suffix) 90 | if args.mkdir: 91 | if args.script_type=='train' and not os.path.exists(ckpt_folder): 92 | os.makedirs(ckpt_folder) 93 | mk_metrics_folder(args, ckpt_folder) 94 | ckpt_folders.append(ckpt_folder) 95 | ckpt_epochs.append(get_epochs(args)) 96 | widths.append(width) 97 | if args.dataset=='IWSLT': 98 | samples.append(0) 99 | else: 100 | samples.append(1280000) 101 | lr_factors.append(1) 102 | depths.append(6) 103 | 104 | if args.experiment_type == 'depth': 105 | if args.exclude_standard: 106 | depth_list = [2,3,4,5] 107 | else: 108 | depth_list = [2,3,4,5,None] 109 | width = args.IWSLT_width 110 | for depth in depth_list: 111 | ckpt_folder = get_ckpt_folder(args.experiment_type, args.dataset, 0, width, args.training_type, 112 | directory_depth=args.directory_depth, 113 | folder_suffix=args.folder_suffix, depth=depth) 114 | if args.mkdir: 115 | if args.script_type=='train' and not os.path.exists(ckpt_folder): 116 | os.makedirs(ckpt_folder) 117 | mk_metrics_folder(args, ckpt_folder) 118 | ckpt_folders.append(ckpt_folder) 119 | ckpt_epochs.append(get_epochs(args)) 120 | widths.append(width) 121 | if args.dataset=='IWSLT': 122 | samples.append(0) 123 | else: 124 | samples.append(1280000) 125 | lr_factors.append(1) 126 | depths.append(depth) 127 | 128 | elif args.experiment_type == 'sample': 129 | if args.dataset == 'IWSLT': 130 | sample_list = [10000, 20000, 40000, 80000, 160000] 131 | width = args.IWSLT_width 132 | elif args.dataset == 'WMT': 133 | sample_list = [160000, 320000, 640000, 1280000] 134 | width = args.IWSLT_width 135 | for sample in sample_list: 136 | ckpt_folder = get_ckpt_folder(args.experiment_type, args.dataset, sample, width, args.training_type, 137 | directory_depth=args.directory_depth, folder_suffix=args.folder_suffix) 138 | if args.mkdir: 139 | if args.script_type=='train' and not os.path.exists(ckpt_folder): 140 | os.makedirs(ckpt_folder) 141 | mk_metrics_folder(args, ckpt_folder) 142 | 
ckpt_folders.append(ckpt_folder) 143 | ckpt_epochs.append(get_epochs(args, num_samples=sample)) 144 | widths.append(width) 145 | samples.append(sample) 146 | lr_factors.append(1) 147 | depths.append(6) 148 | 149 | elif args.experiment_type == 'lr': 150 | if args.exclude_standard: 151 | lr_factor_list = ['0.25', '0.375', '0.5', '0.75', '2'] 152 | else: 153 | lr_factor_list = ['0.25', '0.375', '0.5', '0.75', None, '2'] 154 | for lr_factor in lr_factor_list: 155 | width = args.IWSLT_width 156 | ckpt_folder = get_ckpt_folder(args.experiment_type, args.dataset, 0, width, args.training_type, lr_factor=lr_factor, 157 | directory_depth=args.directory_depth, folder_suffix=args.folder_suffix) 158 | print(ckpt_folder) 159 | if args.mkdir: 160 | if args.script_type=='train' and not os.path.exists(ckpt_folder): 161 | os.makedirs(ckpt_folder) 162 | mk_metrics_folder(args, ckpt_folder) 163 | ckpt_folders.append(ckpt_folder) 164 | ckpt_epochs.append(get_epochs(args)) 165 | widths.append(width) 166 | if args.dataset=='IWSLT': 167 | samples.append(0) 168 | else: 169 | samples.append(1280000) 170 | lr_factors.append(lr_factor) 171 | depths.append(6) 172 | 173 | return ckpt_folders, ckpt_epochs, widths, samples, lr_factors, depths 174 | 175 | 176 | 177 | def get_val_loss(plot_result, loss_result): 178 | 179 | """ 180 | This function gets validation loss and saves in the plot_result dictionary 181 | """ 182 | 183 | plot_result['validation_loss'] = [] 184 | with open(loss_result, 'r') as file: 185 | lines = file.readlines() 186 | for x in lines: 187 | x = x.split(' ') 188 | if 'validation' and 'loss' in x: 189 | plot_result['validation_loss'].append(float(x[-1])) 190 | 191 | 192 | def get_bleu_score(plot_result, loss_result, divide): 193 | 194 | """ 195 | This function gets the bleu score and saves in the plot_result di 196 | """ 197 | 198 | plot_result['bleu_score'] = [] 199 | with open(loss_result, 'r') as file: 200 | lines = file.readlines() 201 | for x in lines: 202 | x = x.split(' ') 203 | if 'BLEU' in x and 'at' in x and 'score' in x: 204 | plot_result['bleu_score'].append(float(x[-1])*100/divide) 205 | 206 | 207 | def get_bleu_gap(plot_result, loss_result, divide): 208 | 209 | """ 210 | This function gets the bleu score and saves in the plot_result di 211 | """ 212 | 213 | dict0 = pickle.load(open(loss_result, 'rb')) 214 | epochs = list(dict0.keys()) 215 | epochs.sort() 216 | plot_result['bleu_score'] = [dict0[epoch]['bleu_score_generalization']/divide for epoch in epochs] 217 | 218 | 219 | def get_train_loss(plot_result, loss_result): 220 | 221 | plot_result['training_loss'] = [] 222 | with open(loss_result, 'r') as file: 223 | lines = file.readlines() 224 | for x in lines: 225 | x = x.split(' ') 226 | if 'training' and 'loss=' in x: 227 | plot_result['training_loss'].append(float(x[-1])) 228 | plot_result['training_loss'] = smooth(plot_result['training_loss'], box_pts=20) 229 | 230 | 231 | def smooth(y, box_pts): 232 | box = np.ones(box_pts)/box_pts 233 | y_smooth = np.convolve(y, box, mode='same') 234 | return y_smooth -------------------------------------------------------------------------------- /utils/utils_analyze_plots.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pickle 3 | from scipy import stats 4 | import numpy as np 5 | import os 6 | 7 | def change_dict_form(mypath): 8 | 9 | with open(mypath, 'rb') as f: 10 | dict0 = pickle.load(f) 11 | 12 | dict1 = {} 13 | for key in dict0[0].keys(): 14 | dict1[key] = [dict0[epoch][key] for 
epoch in dict0.keys()] 15 | 16 | return dict1 17 | 18 | 19 | def compute_correlation_single(a,b,correlation_type='pearsonr'): 20 | 21 | if correlation_type=='pearsonr': 22 | rho, pval = stats.pearsonr(a, b) 23 | elif correlation_type=='spearmanr': 24 | rho, pval = stats.spearmanr(a, b) 25 | 26 | return rho 27 | 28 | 29 | def compute_rank_correlation(plot_result, different_metrics, exclude_metrics=[], correlation_type='pearsonr', single_features=[], bleu_score_type='test'): 30 | 31 | rank_correlations = {} 32 | if bleu_score_type=='test': 33 | bleu = [-x for x in plot_result['bleu_score']] 34 | elif bleu_score_type == 'gap': 35 | bleu = plot_result['bleu_score'] 36 | else: 37 | raise ValueError('BLEU score type not implemented!') 38 | 39 | for key in different_metrics.keys(): 40 | if key in exclude_metrics: 41 | continue 42 | 43 | rho = compute_correlation_single(bleu,different_metrics[key],correlation_type=correlation_type) 44 | rank_correlations[key] = rho 45 | 46 | rho = compute_correlation_single(bleu, plot_result['alpha'],correlation_type=correlation_type) 47 | rank_correlations['alpha'] = rho 48 | 49 | for key in single_features: 50 | rho = compute_correlation_single(bleu, plot_result[key],correlation_type=correlation_type) 51 | rank_correlations[key] = rho 52 | if key == 'rand_distance': 53 | rank_correlations[key] = -rho 54 | 55 | return rank_correlations 56 | 57 | 58 | def create_plot_result(keys, single_features, exclude_metrics): 59 | 60 | plot_result = {key:[] for key in keys} 61 | plot_result.update({key:[] for key in ['exp_dist_exponent', 'lognormal_sigma']}) 62 | plot_result.update({key:[] for key in single_features if key not in exclude_metrics}) 63 | 64 | return plot_result 65 | 66 | 67 | def aggregate_rank_correlations(rank_correlations, rank_correlations_min, rank_correlations_ave): 68 | 69 | if len(rank_correlations_min)==0: 70 | 71 | for key in rank_correlations.keys(): 72 | rank_correlations_min[key] = rank_correlations[key] 73 | rank_correlations_ave[key] = rank_correlations[key] 74 | 75 | return 76 | 77 | for key in rank_correlations.keys(): 78 | rank_correlations_min[key] = min(rank_correlations_min[key], rank_correlations[key]) 79 | rank_correlations_ave[key] = rank_correlations_ave[key] + rank_correlations[key] 80 | 81 | def average_rank_correlations(rank_correlations_ave): 82 | 83 | for key in rank_correlations_ave: 84 | rank_correlations_ave[key] /= 18 85 | 86 | 87 | def plot_rank_correlations(rank_correlations, ax): 88 | 89 | plot_keys = [k.upper() for k in rank_correlations.keys()] 90 | values = rank_correlations.values() 91 | 92 | y_pos = np.arange(len(plot_keys)) 93 | 94 | ax.barh(y_pos, values, align='center') 95 | ax.set_yticks(y_pos) 96 | ax.set_yticklabels(plot_keys) 97 | 98 | 99 | def get_ESD(ckpt_folder, epoch, ESD_type=1, layer_id=105, metric_folder_name='metrics_expcutoff'): 100 | 101 | if ESD_type == 1: 102 | esd_suffix = f'esd' 103 | elif ESD_type in [2,3,4]: 104 | esd_suffix = f'esd{ESD_type}' 105 | elif ESD_type in [5,6]: 106 | esd_suffix = f'mpfit{ESD_type-4}' 107 | elif ESD_type in [7,8]: 108 | esd_suffix = f'randesd.{ESD_type-6}' 109 | 110 | ESD_result = os.path.join(ckpt_folder, metric_folder_name, f'epoch_{epoch}', f'ww.layer{layer_id}.{esd_suffix}.png') 111 | 112 | return ESD_result 113 | -------------------------------------------------------------------------------- /utils/utils_huggingface.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def preprocess_layers(args, 
model, model_type="distilbert-base-uncased"): 4 | 5 | if model_type=="distilbert-base-uncased": 6 | if args.train_one_layer: 7 | # Use this when fine-tuning only the last layer 8 | for name, param in model.distilbert.named_parameters(): 9 | param.requires_grad = False 10 | for name, param in model.distilbert.transformer.layer[5].named_parameters(): 11 | print(f"Only allow the training in layer 5-{name}") 12 | param.requires_grad = True 13 | 14 | if args.randomize_layers: 15 | # Use this when initializing the weights of the last layer 16 | for layer_id in range(args.randomize_layers_num): 17 | for name, param in model.distilbert.transformer.layer[5-layer_id].named_modules(): 18 | if isinstance(param, (nn.Linear, nn.Embedding, nn.LayerNorm)): 19 | model._init_weights(param) 20 | print(f"Randomizing the weights of layer {5-layer_id}-{name}") 21 | 22 | elif model_type == "distilroberta-base": 23 | if args.train_one_layer: 24 | # Use this when fine-tuning only the last layer 25 | raise NameError 26 | 27 | if args.randomize_layers: 28 | # Use this when initializing the weights of the last several layers 29 | for name, param in model.lm_head.named_modules(): 30 | if isinstance(param, (nn.Linear, nn.Embedding, nn.LayerNorm)): 31 | model._init_weights(param) 32 | print(f"Randomizing the weights of layer {name}") 33 | 34 | for layer_id in range(args.randomize_layers_num): 35 | for name, param in model.roberta.encoder.layer[5-layer_id].named_modules(): 36 | if isinstance(param, (nn.Linear, nn.Embedding, nn.LayerNorm)): 37 | model._init_weights(param) 38 | print(f"Randomizing the weights of layer {5-layer_id}-{name}") 39 | 40 | else: 41 | # It looks like randomizing layers is not the right way to go 42 | # We thus do not pursue this direction further 43 | if args.train_one_layer: 44 | raise NameError 45 | 46 | if args.randomize_layers: 47 | raise NameError -------------------------------------------------------------------------------- /utils/utils_ww_results.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from os import listdir 4 | from os.path import isfile, join 5 | 6 | 7 | def get_ww_layer_feature(ww_result, key): 8 | 9 | """ 10 | Returns a particular key from the results 11 | """ 12 | 13 | with open(ww_result, 'rb') as f: 14 | dict0 = pickle.load(f) 15 | feature_values = [dict0['details'][key][layer] for layer in dict0['details'][key].keys()] 16 | 17 | return feature_values 18 | 19 | 20 | def get_alphas_Ds(ww_result): 21 | 22 | """ 23 | Return both alphas and Ds 24 | """ 25 | 26 | with open(ww_result, 'rb') as f: 27 | dict0 = pickle.load(f) 28 | 29 | assert 'alphas' in dict0 and 'Ds' in dict0, "This ww result does not contain alphas and/or Ds" 30 | 31 | alphas = dict0['alphas'] 32 | Ds = dict0['Ds'] 33 | 34 | return alphas, Ds 35 | 36 | 37 | def get_ww_alpha(ww_result): 38 | 39 | """ 40 | Return one single alpha from the summary dictionary 41 | """ 42 | 43 | with open(ww_result, 'rb') as f: 44 | dict0 = pickle.load(f) 45 | alpha_value = dict0['summary']['alpha'] 46 | 47 | return alpha_value 48 | 49 | 50 | def compute_alpha_from_Ds(alphas, Ds, alpha_threshold=6, softmin=False, all_alpha=False, return_D=False, spectral_norms=None): 51 | 52 | temperature = 100 53 | layers = alphas.keys() 54 | # remove the last 10 points 55 | remove_last = 1 56 | 57 | alpha_layers = [] 58 | D_layers = [] 59 | for layer in layers: 60 | 61 | Ds_this_layer = Ds[layer]#[:-remove_last] 62 | alphas_this_layer = alphas[layer]#[:-remove_last] 
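        # NOTE: alphas[layer] / Ds[layer] appear to hold, for this layer, the fitted PL exponents and
        # the corresponding KS distances D across candidate xmin values of the WeightWatcher fit
        # (see get_alphas_Ds above); the outlier filtering below uses alpha_threshold to drop unstable fits.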
63 | 64 | # remove the outliers 65 | mask = alphas_this_layer < alpha_threshold -------------------------------------------------------------------------------- /utils/visualization_utils.py: -------------------------------------------------------------------------------- 34 | # Remove the end of sentence token as we never attend to it, it's produced at the output and we stop 35 | target_sentence_tokens = target_sentence_tokens[0][:-1] 36 | 37 | # Visualize encoder attention weights 38 | for layer_id, encoder_layer in enumerate(encoder.encoder_layers): 39 | mha = encoder_layer.multi_headed_attention # Every encoder layer has 1 MHA module 40 | 41 | # attention_weights shape = (B, NH, S, S), extract 0th batch and loop over NH (number of heads) MHA heads 42 | # S stands for maximum source token-sequence length 43 | attention_weights = mha.attention_weights.cpu().numpy()[0] 44 | 45 | title = f'Encoder layer {layer_id + 1}' 46 | visualize_attention_helper(attention_weights, source_sentence_tokens, title=title) 47 | 48 | # Visualize decoder attention weights 49 | for layer_id, decoder_layer in enumerate(decoder.decoder_layers): 50 | mha_trg = decoder_layer.trg_multi_headed_attention # Extract the self-attention MHA 51 | mha_src = decoder_layer.src_multi_headed_attention # Extract the source attending MHA 52 | 53 | # attention_weights shape = (B, NH, T, T), T stands for maximum target token-sequence length 54 | attention_weights_trg = mha_trg.attention_weights.cpu().numpy()[0] 55 | # shape = (B, NH, T, S), target token representations create queries and keys/values come from the encoder 56 | attention_weights_src = mha_src.attention_weights.cpu().numpy()[0] 57 | 58 | title = f'Decoder layer {layer_id + 1}, self-attention MHA' 59 | visualize_attention_helper(attention_weights_trg, target_sentence_tokens=target_sentence_tokens, title=title) 60 | 61 | title = f'Decoder layer {layer_id + 1}, source-attending MHA' 62 | visualize_attention_helper(attention_weights_src, source_sentence_tokens, target_sentence_tokens, title) -------------------------------------------------------------------------------- /visualization/Best_ETPL_Lambda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nsfzyzz/Generalization_metrics_for_NLP/e7a991f8baa15a1651a53016cff532d9d99256fc/visualization/Best_ETPL_Lambda.png -------------------------------------------------------------------------------- /visualization/Model_quality_vs_generalization_gap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nsfzyzz/Generalization_metrics_for_NLP/e7a991f8baa15a1651a53016cff532d9d99256fc/visualization/Model_quality_vs_generalization_gap.png -------------------------------------------------------------------------------- /visualization/TPL_vs_PL_mediocre.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nsfzyzz/Generalization_metrics_for_NLP/e7a991f8baa15a1651a53016cff532d9d99256fc/visualization/TPL_vs_PL_mediocre.png -------------------------------------------------------------------------------- /visualization/TPL_vs_PL_mediocre_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nsfzyzz/Generalization_metrics_for_NLP/e7a991f8baa15a1651a53016cff532d9d99256fc/visualization/TPL_vs_PL_mediocre_1.pdf -------------------------------------------------------------------------------- /visualization/results/TPL_vs_PL_bad_evals.npy: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/nsfzyzz/Generalization_metrics_for_NLP/e7a991f8baa15a1651a53016cff532d9d99256fc/visualization/results/TPL_vs_PL_bad_evals.npy -------------------------------------------------------------------------------- /visualization/results/TPL_vs_PL_good_evals.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nsfzyzz/Generalization_metrics_for_NLP/e7a991f8baa15a1651a53016cff532d9d99256fc/visualization/results/TPL_vs_PL_good_evals.npy -------------------------------------------------------------------------------- /visualization/results/TPL_vs_PL_mediocre_evals.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nsfzyzz/Generalization_metrics_for_NLP/e7a991f8baa15a1651a53016cff532d9d99256fc/visualization/results/TPL_vs_PL_mediocre_evals.npy --------------------------------------------------------------------------------
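
As a quick usage illustration for `utils/utils_CKA.py` (a minimal sketch, not part of the repository; it assumes the snippet is run from the repository root so that `utils` is importable): for a linear kernel, CKA computed from Gram matrices should match the feature-space shortcut.
```
import numpy as np
from utils.utils_CKA import gram_linear, cka_compute, feature_space_linear_cka

rng = np.random.RandomState(0)
X = rng.randn(50, 10)  # 50 examples, 10 features (e.g. activations of one layer)
Y = rng.randn(50, 7)   # the same 50 examples under another representation

cka_from_grams = cka_compute(gram_linear(X), gram_linear(Y))
cka_from_feats = feature_space_linear_cka(X, Y)
print(cka_from_grams, cka_from_feats)  # both print the same CKA value in [0, 1]
assert np.isclose(cka_from_grams, cka_from_feats)
```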