├── __init__.py ├── tests ├── __init__.py ├── create_model.py ├── perplexity_eval.py ├── data_utils │ └── data_utils.py └── perplexity_dynamiceval.py ├── tint_main ├── utils │ ├── __init__.py │ ├── activation │ │ ├── __init__.py │ │ ├── backward.py │ │ └── forward.py │ ├── linear │ │ ├── __init__.py │ │ └── forward.py │ ├── layernorm │ │ ├── __init__.py │ │ ├── forward.py │ │ └── backward.py │ ├── self_attention │ │ ├── __init__.py │ │ ├── forward.py │ │ └── backward.py │ ├── config.py │ ├── activations.py │ └── all_arguments.py └── tint_creator.py ├── figures └── icl_overview.png ├── icl_eval ├── images │ └── single_FT_or_multiple_FT.png ├── metrics.py ├── scripts │ ├── others │ │ ├── print_result.py │ │ ├── print_result_new.py │ │ ├── print_dynamic_eval_results.py │ │ └── print_result_May10.py │ ├── zero_shot │ │ └── print_result_May10.py │ ├── few_shot │ │ └── print_result_May10.py │ ├── TINT │ │ ├── print_paper_result_zero_shot.py │ │ └── print_paper_result.py │ ├── dynamic_eval │ │ └── print_result.py │ └── run_eval.sh ├── README.md ├── models │ └── modeling_opt.py └── utils.py ├── .gitignore ├── env.yml └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tint_main/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tint_main/utils/activation/__init__.py: -------------------------------------------------------------------------------- 1 | from .forward import ActivationForward 2 | from .backward import ActivationBackward 3 | -------------------------------------------------------------------------------- /figures/icl_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhishekpanigrahi1996/transformer_in_transformer/HEAD/figures/icl_overview.png -------------------------------------------------------------------------------- /tint_main/utils/linear/__init__.py: -------------------------------------------------------------------------------- 1 | from .forward import LinearForward 2 | from .backward import LinearBackward, LinearDescent, Linear_Descent_Backward -------------------------------------------------------------------------------- /tint_main/utils/layernorm/__init__.py: -------------------------------------------------------------------------------- 1 | from .forward import LayerNormForward 2 | from .backward import LayerNormBackward, LayerNormDescent, LayerNormDescent_Backward 3 | -------------------------------------------------------------------------------- /icl_eval/images/single_FT_or_multiple_FT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhishekpanigrahi1996/transformer_in_transformer/HEAD/icl_eval/images/single_FT_or_multiple_FT.png -------------------------------------------------------------------------------- /tint_main/utils/self_attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .forward import AttentionForward 2 | from .backward import AttentionBackward, AttentionDescent, AttentionBackward_Descent 
3 | 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | not_firstversion/ 7 | *.sh 8 | Constructed_model/* 9 | log* 10 | -------------------------------------------------------------------------------- /tests/create_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from tint_main.tint_creator import * 4 | from transformers import AutoTokenizer 5 | from datasets import load_dataset 6 | from transformers import HfArgumentParser 7 | from filelock import Timeout, FileLock 8 | from .data_utils.data_utils import * 9 | from tint_main.utils.all_arguments import * 10 | 11 | parser = HfArgumentParser((ModelArguments, ConstructionArguments, DynamicArguments,)) 12 | model_args, construction_args, data_args, = parser.parse_args_into_dataclasses() 13 | 14 | constructed_model, _, _, _ = TinT_Creator(model_args, construction_args) 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /icl_eval/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # For question answering 4 | import re 5 | import string 6 | 7 | def normalize_answer(s): 8 | """Lower text and remove punctuation, articles and extra whitespace.""" 9 | 10 | def remove_articles(text): 11 | return re.sub(r'\b(a|an|the)\b', ' ', text) 12 | 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return ''.join(ch for ch in text if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 24 | 25 | def calculate_metric(predictions, metric_name): 26 | if metric_name == "accuracy": 27 | if isinstance(predictions[0].correct_candidate, list): 28 | return np.mean([pred.predicted_candidate in pred.correct_candidate for pred in predictions]) 29 | else: 30 | return np.mean([pred.correct_candidate == pred.predicted_candidate for pred in predictions]) 31 | elif metric_name == "em": 32 | return np.mean([any([normalize_answer(ans) == normalize_answer(pred.predicted_candidate) for ans in pred.correct_candidate]) for pred in predictions]) 33 | elif metric_name == "substring_em": 34 | return np.mean([any([normalize_answer(ans) in normalize_answer(pred.predicted_candidate) for ans in pred.correct_candidate]) for pred in predictions]) 35 | 36 | -------------------------------------------------------------------------------- /icl_eval/scripts/others/print_result.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import numpy as np 4 | import os 5 | 6 | 7 | def print_numpy_result(result, models): 8 | for i in range(len(result)): 9 | # print(models[i], end="\t") 10 | for j in range(result.shape[1]): 11 | print(f"{result[i, j]:.4f}", end="\t") 12 | print() 13 | print() 14 | 15 | out_dir = Path("/n/fs/nlp-mengzhou/space9/out/llm_eval/constructed-lm-results") 16 | 17 | # rows 18 | models = ["gpt2", "constructed-gpt2-l12-ns1-lr1e-06", "constructed-gpt2-l12-ns1-lr1e-05", "constructed-gpt2-l12-ns1-lr1e-04", "opt-125m", 19 | "constructed-opt-l12-ns1-lr1e-06", 
"constructed-opt-l12-ns1-lr2e-06", "constructed-opt-l12-ns1-lr5e-07", "gpt2-large", "opt-1.3b"] 20 | 21 | shots = [0, 2, 4, 8] 22 | 23 | task = "AGNews" 24 | icl_sfc = False 25 | sfc = False 26 | metric = "accuracy" 27 | 28 | result = np.zeros((len(models), len(shots))) 29 | for i, model in enumerate(models): 30 | for j, shot in enumerate(shots): 31 | if sfc: tag = "-sfc" 32 | elif icl_sfc: tag = "-icl_sfc" 33 | else: tag = "" 34 | if shot == 0: file = out_dir / model / f"{task}-{model}{tag}-sampleeval200-onetrainpereval.json" 35 | else: file = out_dir / model / f"{task}-{model}{tag}-sampleeval200-ntrain{shot}-onetrainpereval.json" 36 | if os.path.exists(file): 37 | re = json.load(open(file, "r")) 38 | result[i, j] = re[metric] 39 | 40 | print(f"{task} {metric}") 41 | print_numpy_result(result, models) 42 | 43 | -------------------------------------------------------------------------------- /icl_eval/scripts/zero_shot/print_result_May10.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import numpy as np 4 | import os 5 | 6 | def read_jsonl(file): 7 | ds = [] 8 | try: 9 | with open(file) as f: 10 | for i, line in enumerate(f): 11 | if line.strip() == "": 12 | continue 13 | d = json.loads(line.strip()) 14 | ds.append(d) 15 | except: 16 | import pdb 17 | pdb.set_trace() 18 | return ds 19 | 20 | def print_numpy_result(result, models): 21 | for i in range(len(result)): 22 | # print(models[i], end="\t") 23 | for j in range(result.shape[1]): 24 | print(f"{result[i, j]:.4f}", end="\t") 25 | print() 26 | print() 27 | 28 | out_dir = Path("/n/fs/nlp-mengzhou/space9/out/llm_eval/zeroshot_eval_May15") 29 | 30 | 31 | shots = 0 32 | task = "MPQA" 33 | metric = "accuracy" 34 | eval_type = "eval" 35 | seed = 1 36 | # result = np.zeros((len(models), len(shots))) 37 | 38 | print_model_name = False 39 | for tmp in [0]: 40 | for icl in ["plain", "icl_sfc"]: 41 | count = 0 42 | model = f"opt-1.3b-ntrain{shots}-{task}-seed{seed}-tmp{tmp}-{icl}-eval" 43 | task_dir = out_dir / model 44 | file = task_dir / "metrics.jsonl" 45 | if os.path.exists(file): 46 | re = read_jsonl(file) 47 | acc = re[0]["accuracy"] 48 | count += 1 49 | else: 50 | print(file) 51 | print(acc, end=" ") 52 | # print(seed_accs) 53 | print() 54 | 55 | 56 | 57 | # print(f"{task} {metric}") 58 | # print_numpy_result(result, models) 59 | 60 | -------------------------------------------------------------------------------- /icl_eval/scripts/few_shot/print_result_May10.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import numpy as np 4 | import os 5 | 6 | def read_jsonl(file): 7 | ds = [] 8 | try: 9 | with open(file) as f: 10 | for i, line in enumerate(f): 11 | if line.strip() == "": 12 | continue 13 | d = json.loads(line.strip()) 14 | ds.append(d) 15 | except: 16 | import pdb 17 | pdb.set_trace() 18 | return ds 19 | 20 | def print_numpy_result(result, models): 21 | for i in range(len(result)): 22 | # print(models[i], end="\t") 23 | for j in range(result.shape[1]): 24 | print(f"{result[i, j]:.4f}", end="\t") 25 | print() 26 | print() 27 | 28 | out_dir = Path("/n/fs/nlp-mengzhou/space9/out/llm_eval/fewshot_eval_May15") 29 | 30 | shots = 32 31 | # task = sys.argv[1] 32 | metric = "accuracy" 33 | eval_type = "eval" 34 | # result = np.zeros((len(models), len(shots))) 35 | 36 | all_seed_accs = [] 37 | for task in ["Subj", "AGNews", "SST2", "CR", "MR", "MPQA", "AmazonPolarity"]: 38 | # 
print(task) 39 | print_model_name = False 40 | for tmp in [0]: 41 | for icl in ["icl_sfc"]: 42 | seed_accs = [] 43 | seed_models = [] 44 | for seed in range(1, 4): 45 | count = 0 46 | model = f"opt-1.3b-ntrain{shots}-{task}-seed{seed}-tmp{tmp}-{icl}-eval" 47 | task_dir = out_dir / model 48 | file = task_dir / "metrics.jsonl" 49 | if os.path.exists(file): 50 | re = read_jsonl(file) 51 | acc = re[0]["accuracy"] 52 | count += 1 53 | seed_accs.append(acc * 100) 54 | else: 55 | print(file) 56 | all_seed_accs.append(np.array(seed_accs)) 57 | acc = np.std(seed_accs) 58 | if print_model_name: 59 | print(" ".join(seed_models), end=" ") 60 | else: 61 | print(acc, end=" ") 62 | # print(seed_accs) 63 | print() 64 | 65 | all_seed_accs = np.stack(all_seed_accs, axis=1) 66 | print("sigma:", all_seed_accs.mean(1).std()) 67 | 68 | # print(f"{task} {metric}") 69 | # print_numpy_result(result, models) 70 | 71 | -------------------------------------------------------------------------------- /icl_eval/scripts/others/print_result_new.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import numpy as np 4 | import os 5 | from cus_data.dataloader import read_jsonl 6 | 7 | 8 | def print_numpy_result(result, models): 9 | for i in range(len(result)): 10 | # print(models[i], end="\t") 11 | for j in range(result.shape[1]): 12 | print(f"{result[i, j]:.4f}", end="\t") 13 | print() 14 | print() 15 | 16 | out_dir = Path("/scratch/gpfs/mengzhou/space9/out/llm_eval/dynamic_icl_eval_May8_test") 17 | 18 | # rows 19 | models = ["constructed-opt-backprop-6-nstep-2-lr-1e-04", "constructed-opt-backprop-6-nstep-2-lr-1e-05", "constructed-opt-backprop-6-nstep-2-lr-1e-06", 20 | "constructed-opt-backprop-12-nstep-1-lr-1e-04", "constructed-opt-backprop-12-nstep-1-lr-1e-05", "constructed-opt-backprop-12-nstep-1-lr-1e-06", 21 | "constructed-opt-backprop-3-nstep-4-lr-1e-04", "constructed-opt-backprop-3-nstep-4-lr-1e-05", "constructed-opt-backprop-3-nstep-4-lr-1e-06"] 22 | 23 | shots = 16 24 | task = "MR" 25 | metric = "accuracy" 26 | eval_type = "eval" 27 | 28 | # result = np.zeros((len(models), len(shots))) 29 | 30 | for single_FT in [True, False]: 31 | for label_type in ["context", "label_only"]: 32 | for icl in ["plain", "icl_sfc"]: 33 | seed_accs = [] 34 | for seed in range(1, 3): 35 | acc = np.zeros((len(models), 2)) 36 | for i, model in enumerate(models): 37 | model_dir = out_dir / model 38 | lr = "-".join(model.split("-")[-2:]) 39 | layer = model.split("-")[3] 40 | task_dir = model_dir / f"ntrain{shots}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{layer}-{label_type}-{icl}-{eval_type}" 41 | file = task_dir / "metrics.jsonl" 42 | if os.path.exists(file): 43 | re = read_jsonl(file) 44 | acc[i, 0] = re[0]["accuracy"] 45 | # acc[i, 1] = re[1]["accuracy"] 46 | else: 47 | print(file) 48 | seed_acc = acc.max(axis=0) 49 | seed_accs.append(seed_acc) 50 | acc = np.stack(seed_accs).mean(axis=0) 51 | print(" ".join([str(a) for a in acc]), end = " ") 52 | # print(seed_accs) 53 | print() 54 | 55 | 56 | 57 | # print(f"{task} {metric}") 58 | # print_numpy_result(result, models) 59 | 60 | -------------------------------------------------------------------------------- /icl_eval/README.md: -------------------------------------------------------------------------------- 1 | # In-context Learning Evaluation 2 | 3 | We demonstrate the process of conducting in-context learning experiments with our TINT or other models for comparison. 
We support running experiments in the following three modes: `standard`, `dynamic evaluation`, and `TINT`. `TINT` effectively approximates `dynamic evaluation` using only forward passes, with no backward pass. `standard` is simply the standard in-context learning setting. 4 | 5 | We use `scripts/run_eval.sh` as our main script; you can pass arguments to it to enable the different types of evaluations. 6 | 7 | ``` 8 | bash scripts/run_eval.sh [ARGUMENTS] 9 | ``` 10 | 11 | We explain the arguments as follows: 12 | 13 | `model_path`: path of the model; it can be a Hugging Face model path, e.g., `facebook/opt-125m`, or a TINT model. 14 | `model_name`: name of the model, freely chosen by the user and used for logging purposes. 15 | `output_dir`: output directory. 16 | `train_num`: number of training examples for in-context learning, dynamic evaluation, or internal training for TINT. Supports both zero-shot (=0) and few-shot (>0) settings. 17 | `task`: task name; currently supports `AmazonPolarity`, `YelpPolarity`, `AGNews`, `SST2`, `MR`, `CR`, `MPQA`, and `Subj`. 18 | `lr`: learning rate for training. We grid search over learning rates of 1e-3, 1e-4, 1e-5, and 1e-6 in our experiments. 19 | `train_set_seed`: seed for selecting training examples. 20 | `single_FT`: whether to train on a concatenation of the in-context learning examples or to train on separate examples in each input. Please find the image below for illustration. 21 | `num_train_layers`: number of training layers (top layers). We experimented with 3, 6, and 12 training layers. 22 | `label_type`: whether we use the loss on all tokens in the context (`context`) or only the loss on the label words (`label_only`). 23 | `sfc_type`: whether or not to use the [surface form competition algorithm](https://arxiv.org/abs/2104.08315) for calibration. 24 | `plain`: we do not use any calibration method. 25 | `icl_sfc`: we use the sfc algorithm for calibration. 26 | `eval_type`: `eval` or `test`; we use `eval` as the default setting. 27 | `template`: we use the default template `0`; users can define templates in `templates.py` and add them to the specific task in `tasks.py`. 28 | `test`: if `True`, the evaluation enters a test mode for debugging. 29 | `dynamic_eval`: `True` for the `dynamic evaluation` and `TINT` modes, and `False` for `standard` mode.
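For illustration, a hypothetical invocation could look like the following; the positional order matches `scripts/run_eval.sh`, and the concrete values (paths, task, learning rate, etc.) are placeholders rather than recommended settings:

```
# positional arguments, in order:
# model_path model_name output_dir train_num task lr train_set_seed single_FT
# num_train_layers label_type sfc_type eval_type template test dynamic_eval
bash scripts/run_eval.sh facebook/opt-125m opt-125m out/icl_eval 32 SST2 1e-04 1 True 12 context icl_sfc eval 0 False True
```

Note that, as written, `scripts/run_eval.sh` expects `base_dir` to be set in the environment and overrides `model_path`/`model_name` with a hard-coded constructed-model path derived from `num_train_layers` and `lr`, so you will likely need to adjust that path for your own setup.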
30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /icl_eval/scripts/others/print_dynamic_eval_results.py: -------------------------------------------------------------------------------- 1 | model_name = "opt-125m" 2 | task = "MRPC" 3 | eval_type = "valid" 4 | 5 | # MRPC-train8-eval200-lr1e-03-labelcontext-dynamicFalse 6 | 7 | import json 8 | from pathlib import Path 9 | import numpy as np 10 | import os 11 | def read_jsonl(file): 12 | lines = open(file, "r").readlines() 13 | d = [] 14 | for line in lines: 15 | d.append(json.loads(line.strip())) 16 | return d 17 | 18 | output_dir = Path(f"/n/fs/nlp-mengzhou/space9/out/llm_eval/dynamic_icl_eval_May3") 19 | # for sfc_type in ["plain", "icl", "icl_sfc"]: 20 | # for shot in [16]: 21 | # for single_FT in [True, False]: 22 | # for label_type in ["label_only", "context"]: 23 | # for num_train_layers in [3, 6, 12]: 24 | # accuracies_all_seeds = [] 25 | # for seed in [1, 2, 3]: 26 | # accuracies = [] 27 | # for lr in ["1e-05", "1e-04", "1e-03", "1e-02"]: 28 | # file_name=f"{model_name}-ntrain{shot}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{num_train_layers}-{label_type}-{sfc_type}-{eval_type}.json" 29 | # # file_name=f"{model_name}-ntrain{shot}-{task}-seed{seed}-{sfc_type}-{eval_type}.json" 30 | # file = output_dir / file_name 31 | # if not os.path.exists(file): 32 | # continue 33 | # d = read_jsonl(file) 34 | # accuracies.append(np.array([dd["accuracy"] for dd in d])) 35 | # accuracies = np.stack(accuracies).max(axis=0).tolist() 36 | # accuracies_all_seeds.append(accuracies) 37 | # accuracies_all_seeds = np.stack(accuracies_all_seeds) 38 | # accuracies = np.mean(accuracies_all_seeds, axis=0) 39 | # re = " ".join(str(acc) for acc in accuracies) 40 | # print(re) 41 | 42 | accuracies = [] 43 | shot=16 44 | for sfc_type in ["plain", "icl", "icl_sfc"]: 45 | for seed in [1, 2, 3]: 46 | file = output_dir / f"{model_name}-ntrain{shot}-{task}-seed{seed}-{sfc_type}-{eval_type}.json" 47 | if not os.path.exists(file): 48 | continue 49 | d = read_jsonl(file) 50 | accuracies.append(np.array([dd["accuracy"] for dd in d])) 51 | accuracies = np.stack(accuracies) 52 | accuracies = np.mean(accuracies, axis=0).tolist() 53 | re = " ".join(str(acc) for acc in accuracies) 54 | print(re) 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /icl_eval/scripts/TINT/print_paper_result_zero_shot.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import numpy as np 4 | import os 5 | 6 | def read_jsonl(file): 7 | ds = [] 8 | try: 9 | with open(file) as f: 10 | for i, line in enumerate(f): 11 | d = json.loads(line.strip()) 12 | ds.append(d) 13 | except: 14 | import pdb 15 | pdb.set_trace() 16 | return ds 17 | 18 | def print_numpy_result(result, models): 19 | for i in range(len(result)): 20 | # print(models[i], end="\t") 21 | for j in range(result.shape[1]): 22 | print(f"{result[i, j]:.4f}", end="\t") 23 | print() 24 | print() 25 | 26 | out_dir = Path("/scratch/gpfs/smalladi/mengzhou/out/llm_eval/dynamic_icl_eval_zero-shot") 27 | 28 | # rows 29 | def get_models(label_type): 30 | models = [] 31 | for layer in [12, 6, 3]: 32 | steps = 12 // layer 33 | if label_type == "context": lrs = ["1e-05", "1e-06", "1e-07"] 34 | else: lrs = ["1e-03", "1e-04", "1e-05"] 35 | for lr in lrs: 36 | models.append(f"constructed-opt-backprop-{layer}-nstep-{steps}-lr-{lr}") 37 | return models 38 | 39 | shots = 0 
40 | task = "AmazonPolarity" 41 | metric = "accuracy" 42 | eval_type = "eval" 43 | # result = np.zeros((len(models), len(shots))) 44 | icl = "plain" 45 | tmp = 0 46 | seed = 1 47 | 48 | print_model_name = False 49 | for label_type in ["context"]: 50 | for single_FT in [True]: 51 | seed_accs = [] 52 | seed_models = [] 53 | sub_models = get_models(label_type) 54 | acc = np.zeros((len(sub_models), 2)) 55 | for i, model in enumerate(sub_models): 56 | model_dir = out_dir / model 57 | lr = "-".join(model.split("-")[-2:]) 58 | layer = model.split("-")[3] 59 | task_dir = model_dir / f"ntrain{shots}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{layer}-{label_type}-tmp{tmp}-{icl}-{eval_type}" 60 | file = task_dir / "metrics.jsonl" 61 | if os.path.exists(file): 62 | re = read_jsonl(file) 63 | acc[i, 0] = re[0]["accuracy"] 64 | # acc[i, 1] = re[1]["accuracy"] 65 | else: 66 | pass 67 | # print(file) 68 | model_name = sub_models[acc[:, 0].argmax(axis=0)] 69 | acc = acc[:, 0].max(axis=0) 70 | if print_model_name: 71 | print(" ".join(seed_models)) 72 | else: 73 | print(acc) 74 | # print(seed_accs) 75 | 76 | 77 | 78 | # print(f"{task} {metric}") 79 | # print_numpy_result(result, models) 80 | 81 | -------------------------------------------------------------------------------- /icl_eval/scripts/dynamic_eval/print_result.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import numpy as np 4 | import os 5 | 6 | def read_jsonl(file): 7 | ds = [] 8 | try: 9 | with open(file) as f: 10 | for i, line in enumerate(f): 11 | if line.strip() == "": 12 | continue 13 | d = json.loads(line.strip()) 14 | ds.append(d) 15 | except: 16 | import pdb 17 | pdb.set_trace() 18 | return ds 19 | 20 | def read_json(file): 21 | return json.load(open(file)) 22 | 23 | def print_numpy_result(result, models): 24 | for i in range(len(result)): 25 | # print(models[i], end="\t") 26 | for j in range(result.shape[1]): 27 | print(f"{result[i, j]:.4f}", end="\t") 28 | print() 29 | print() 30 | 31 | out_dir = Path("/n/fs/nlp-mengzhou/space9/out/llm_eval/dynamic_icl_eval_May10") 32 | 33 | 34 | shots = 32 35 | task = "Subj" 36 | metric = "accuracy" 37 | eval_type = "eval" 38 | # result = np.zeros((len(models), len(shots))) 39 | 40 | print_model_name = False 41 | all_accs = [] 42 | for task in ["Subj", "AGNews", "SST2", "CR", "MR", "MPQA", "AmazonPolarity"]: 43 | print(task) 44 | for tmp in [0, 1]: 45 | for single_FT in [True, False]: 46 | for label_type in ["context", "label_only"]: 47 | for icl in ["plain"]: 48 | seed_accs = [] 49 | seed_models = [] 50 | for seed in range(1, 4): 51 | acc = np.zeros((9, 2)) 52 | count = 0 53 | for lr in ["1e-03", "1e-04", "1e-05"]: 54 | for layer in [3, 6, 12]: 55 | model = f"opt-125m-ntrain{shots}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{layer}-{label_type}-tmp{tmp}-{icl}-{eval_type}" 56 | task_dir = out_dir / model 57 | file = task_dir / "metrics.jsonl" 58 | if os.path.exists(file): 59 | re = read_json(file) 60 | acc[count, 0] = re["accuracy"] 61 | else: 62 | print(file) 63 | count += 1 64 | model_name = f"lr{lr}-seed{seed}" 65 | seed_models.append(model_name) 66 | seed_acc = acc.max(axis=0) 67 | seed_accs.append(seed_acc * 100) 68 | acc = np.stack(seed_accs).std(axis=0) 69 | if print_model_name: 70 | print(" ".join(seed_models), end=" ") 71 | else: 72 | print(" ".join([str(a) for a in acc[0:1]]), end = " ") 73 | # print(seed_accs) 74 | print() 75 | all_accs.append(seed_accs) 76 | 77 | all_accs = 
np.stack(all_accs) 78 | print("std:", all_accs[:, :, 0].mean(0).std()) 79 | 80 | 81 | 82 | # print(f"{task} {metric}") 83 | # print_numpy_result(result, models) 84 | 85 | -------------------------------------------------------------------------------- /tint_main/utils/layernorm/forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import math 4 | import os 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | from torch import nn 11 | from torch.cuda.amp import autocast 12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 13 | from ..modules import * 14 | from ..linear import * 15 | 16 | 17 | 18 | 19 | class LayerNormForward(nn.Module): 20 | def __init__(self, config, din, use_softmax, memory_index=-1): 21 | super(LayerNormForward, self).__init__() 22 | assert use_softmax==False ,\ 23 | "Currently I only use linear attention in this module" 24 | 25 | 26 | 27 | self.linear=LinearForward ( config, din=din, dout=din, use_softmax=use_softmax, memory_index=memory_index ) 28 | self.din=din 29 | self.epsilon = config.epsilon 30 | self.config=config 31 | self.memory_index = memory_index 32 | 33 | #w acts like a gate to decide what portion of the embedding we apply layernorm on 34 | self.w = torch.zeros (( 1, 1, config.hidden_size )) 35 | self.w [:, :, :din] += config.gate_scale 36 | self.gate = torch.nn.Tanh() 37 | 38 | 39 | #mask out normalization on prefixes 40 | self.normalization_gates = Gates (config) 41 | #Initialize Gates 42 | #Ignore the changes for the prefixes! 43 | #w, u, v, w_bias, u_bias, v_bias 44 | w = torch.zeros((1, 2*config.hidden_size)) 45 | u = torch.zeros((1, 2*config.hidden_size)) 46 | v = torch.zeros((1, 2*config.position_dim)) 47 | w_bias = torch.zeros(2) 48 | u_bias = torch.zeros(2) 49 | v_bias = torch.zeros(2) 50 | 51 | #Input Gate is 1 on prefixes and 0 for non-prefixes 52 | v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.num_prefixes) 53 | 54 | 55 | #Change Gate is 0 on prefixes and 1 for non-prefixes 56 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.num_prefixes) 57 | v_bias [1] += config.gate_scale 58 | 59 | self.normalization_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias) 60 | 61 | 62 | 63 | 64 | def forward(self, hidden_states, position_embeddings): 65 | 66 | 67 | weights = self.gate ( self.w ).to(hidden_states.device) 68 | mean = torch.sum(hidden_states * weights, dim=-1, keepdim=True) / torch.sum(weights, dim=-1, keepdim=True) 69 | 70 | var = ( self.epsilon + torch.sum( (weights * (hidden_states - mean)) ** 2, dim=-1, keepdim=True) / torch.sum(weights, dim=-1, keepdim=True) ) ** 0.5 71 | 72 | normalized_states = (hidden_states - mean) / var 73 | normalized_states = weights * normalized_states + (1. - weights) * hidden_states 74 | 75 | gated_output = self.normalization_gates.forward (hidden_states, normalized_states, position_embeddings) 76 | 77 | output = self.linear.forward ( gated_output, position_embeddings ) 78 | 79 | 80 | #store [(x-\mu)/\sigma, x] for memory in backward pass 81 | assert torch.sum( torch.absolute( output[:, self.config.num_prefixes:, self.memory_index+self.din:]) ).item() < 1e-10,\ 82 | "Memory portion not empty!" 
83 | output[:, self.config.num_prefixes:, self.memory_index+self.din: self.memory_index+2*self.din] += hidden_states[:, self.config.num_prefixes:, :self.din] 84 | 85 | return output 86 | 87 | 88 | -------------------------------------------------------------------------------- /icl_eval/scripts/TINT/print_paper_result.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import numpy as np 4 | import os 5 | 6 | def read_jsonl(file): 7 | ds = [] 8 | try: 9 | with open(file) as f: 10 | for i, line in enumerate(f): 11 | d = json.loads(line.strip()) 12 | ds.append(d) 13 | except: 14 | import pdb 15 | pdb.set_trace() 16 | return ds 17 | 18 | def print_numpy_result(result, models): 19 | for i in range(len(result)): 20 | # print(models[i], end="\t") 21 | for j in range(result.shape[1]): 22 | print(f"{result[i, j]:.4f}", end="\t") 23 | print() 24 | print() 25 | 26 | out_dir = Path("/scratch/gpfs/smalladi/mengzhou/out/llm_eval/dynamic_icl_eval_May10") 27 | 28 | # rows 29 | def get_models(label_type, layer, steps): 30 | models = [] 31 | if label_type == "context": lrs = ["1e-05", "1e-06", "1e-07"] 32 | else: lrs = ["1e-03", "1e-04", "1e-05"] 33 | for lr in lrs: 34 | models.append(f"constructed-opt-backprop-{layer}-nstep-{steps}-lr-{lr}") 35 | return models 36 | 37 | shots = 32 38 | metric = "accuracy" 39 | eval_type = "eval" 40 | # result = np.zeros((len(models), len(shots))) 41 | icl = "icl_sfc" 42 | tmp = 0 43 | 44 | print_model_name = False 45 | 46 | all_res = [] 47 | print_res = [] 48 | for task in ["Subj", "AGNews", "SST2", "CR", "MR", "MPQA", "AmazonPolarity"]: 49 | print_re = [] 50 | task_re = [] 51 | for label_type in ["context", "label_only"]: 52 | for single_FT in [True, False]: 53 | for layer in [12, 6, 3]: 54 | step = 12 // layer 55 | sub_models = get_models(label_type, layer, step) 56 | seed_accs = [] 57 | seed_models = [] 58 | for seed in range(1, 4): 59 | acc = np.zeros((len(sub_models), 2)) 60 | for i, model in enumerate(sub_models): 61 | model_dir = out_dir / model 62 | lr = "-".join(model.split("-")[-2:]) 63 | layer = model.split("-")[3] 64 | task_dir = model_dir / f"ntrain{shots}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{layer}-{label_type}-tmp{tmp}-{icl}-{eval_type}" 65 | file = task_dir / "metrics.jsonl" 66 | if os.path.exists(file): 67 | re = read_jsonl(file) 68 | acc[i, 0] = re[0]["accuracy"] 69 | # acc[i, 1] = re[1]["accuracy"] 70 | else: 71 | pass 72 | # print(file) 73 | model_name = sub_models[acc[:, 0].argmax(axis=0)] 74 | seed_models.append(model_name) 75 | seed_acc = acc.max(axis=0) 76 | seed_accs.append(seed_acc * 100) 77 | task_re.append(seed_acc * 100) 78 | acc = np.stack(seed_accs).std(axis=0) 79 | if print_model_name: 80 | print(" ".join(seed_models)) 81 | else: 82 | print(" ".join([str(a) for a in acc[0:1]])) 83 | print_res.append(print_re) 84 | # print(seed_accs) 85 | all_res.append(task_re) 86 | 87 | print_res = np.array(print_res).transpose() 88 | for line in print_res: 89 | print(' '.join(map(str, line))) 90 | 91 | # print(f"{task} {metric}") 92 | # print_numpy_result(result, models) 93 | 94 | import pdb; pdb.set_trace() 95 | all_res = np.array(all_res) 96 | ree = all_res.reshape(7, 12, 3, -1)[:, :, :, 0] 97 | 98 | for n in ree.mean(0).std(-1): print(n) -------------------------------------------------------------------------------- /tint_main/utils/activation/backward.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| 3 | import math 4 | import os 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | from torch import nn 11 | from torch.cuda.amp import autocast 12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 13 | from ..modules import * 14 | from ..linear import * 15 | from ..activations import * 16 | 17 | 18 | 19 | class ActivationBackward (nn.Module): 20 | def __init__ (self, config, din, input_projection=None, projection_matrix=None, memory_index=-1, retain_og_act=False): 21 | super(ActivationBackward, self).__init__() 22 | 23 | 24 | 25 | assert memory_index == -1 or memory_index >= din, \ 26 | "memory crosses current signal" 27 | 28 | 29 | 30 | self.epsilon = config.epsilon 31 | self.memory_index = memory_index 32 | self.config = config 33 | 34 | head_dim = config.hidden_size // config.num_attention_heads 35 | self.c_fc = Conv2D(config.num_attention_heads, head_dim, transpose=True, use_einsum=self.config.use_einsum) 36 | self.proj_fc = Conv2D(config.num_attention_heads, head_dim, transpose=True, use_einsum=self.config.use_einsum) 37 | 38 | self.config = config 39 | self.din = din 40 | self.act = ACT2FN[config.activation_function] 41 | 42 | 43 | c_fc_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 44 | c_proj_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 45 | 46 | 47 | 48 | assert din % head_dim == 0, \ 49 | " 'din' should be a multiple of head_dim! " 50 | 51 | num_partitions = din // head_dim 52 | 53 | 54 | 55 | assert self.memory_index % head_dim == 0, \ 56 | "Memory should start at a multiple of head_dim!" 57 | 58 | mem_head_start = self.memory_index // head_dim 59 | 60 | 61 | start_shift = 0 62 | c_fc_init[:, start_shift: start_shift + num_partitions, start_shift: start_shift + num_partitions] = 1. / config.scale_embeddings * torch.eye(num_partitions) 63 | c_fc_init[:, start_shift: start_shift + num_partitions, mem_head_start: mem_head_start + num_partitions] = torch.eye(num_partitions) 64 | 65 | #pass x as well 66 | c_fc_init[:, mem_head_start: mem_head_start + num_partitions, mem_head_start: mem_head_start + num_partitions] = torch.eye(num_partitions) 67 | 68 | 69 | #Compute GeLU(x + 1/N \nabla y) - GeLU(x) 70 | 71 | c_proj_init[:, start_shift: start_shift + num_partitions, start_shift: start_shift + num_partitions] = config.scale_embeddings * torch.eye(num_partitions) 72 | c_proj_init[:, start_shift: start_shift + num_partitions, mem_head_start: mem_head_start + num_partitions] = -config.scale_embeddings * torch.eye(num_partitions) 73 | 74 | #retain Act (x) for future purposes? 
75 | if retain_og_act: 76 | c_proj_init[:, mem_head_start: mem_head_start + num_partitions, mem_head_start: mem_head_start + num_partitions] = torch.eye(num_partitions) 77 | 78 | 79 | with torch.no_grad(): 80 | self.c_fc.weight.copy_(torch.swapaxes(c_fc_init, axis0=-1, axis1=-2)) 81 | self.proj_fc.weight.copy_(torch.swapaxes(c_proj_init, axis0=-1, axis1=-2)) 82 | 83 | 84 | 85 | def forward(self, hidden_states, position_embeddings, attention_mask=None, activation_memory=None, icl_mask=None): 86 | output = self.proj_fc ( self.act( self.c_fc(hidden_states) ) ) 87 | return output 88 | 89 | -------------------------------------------------------------------------------- /icl_eval/scripts/run_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH --job-name=icl 3 | #SBATCH --time=1:00:00 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --mem=30gb 6 | 7 | # $model_path $model_name $output_dir $ntrain $task $lr $dynamic_eval $seed $sf $layer $sfc_type $eval_type $label_type 8 | 9 | script=$base_dir/llm_eval_construction/run.py 10 | model_path=$1 11 | model_name=$2 12 | output_dir=$3 13 | ntrain=$4 14 | task=$5 15 | lr=$6 16 | train_set_seed=$7 17 | single_FT=$8 18 | num_train_layers=${9} 19 | label_type=${10} 20 | sfc_type=${11} 21 | eval_type=${12} 22 | template=${13} 23 | test=${14} 24 | num_epochs=$(( 12 / $num_train_layers )) 25 | dynamic_eval=${15} 26 | 27 | # change the path of the constructed models 28 | model_path=/scratch/gpfs/ap34/icl-as-ft/Dynamic_initialization/Constructed_models_withlnupdate_fastlin/model_opt_backprop_${num_train_layers}_nstep_${num_epochs}_lr_${lr} 29 | model_name=constructed-opt-backprop-${num_train_layers}-nstep-${num_epochs}-lr-${lr} 30 | 31 | restrict_attention_demonstration=False # whether to use restrict the attention of each demonstration to itself in icl experiments 32 | position_modify_demonstration=False # whether to change the position ids of each demonstration to top; only relevant if restrict_attention_demonstration is true 33 | 34 | if [[ $single_FT == True ]]; then 35 | restrict_attention_demonstration=True 36 | position_modify_demonstration=True 37 | else 38 | restrict_attention_demonstration=False 39 | position_modify_demonstration=False 40 | fi 41 | 42 | if [[ $label_type == "label_only" ]]; then loss_label_only=True; 43 | else loss_label_only=False; fi 44 | 45 | sub_dir=${model_name}/ntrain${ntrain}-${task}-lr${lr}-seed${train_set_seed}-single_FT${single_FT}-layers${num_train_layers}-${label_type}-tmp${template}-${sfc_type}-${eval_type} 46 | output_dir=$output_dir/$sub_dir 47 | mkdir -p $output_dir 48 | 49 | echo "********** inside single script **********" 50 | echo model_path=$model_path 51 | echo model_name=$model_name 52 | echo output_dir=$output_dir 53 | echo ntrain=$ntrain 54 | echo task=$task 55 | echo train_set_seed=$train_set_seed 56 | echo single_FT=$single_FT 57 | echo label_type=$label_type 58 | echo sfc_type=${sfc_type} 59 | echo eval_type=${eval_type} 60 | echo restrict_attention_demonstration=${restrict_attention_demonstration} 61 | echo position_modify_demonstration=${position_modify_demonstration} 62 | echo template=${template} 63 | echo "Outputting to $output_dir/${file_name}" 64 | echo "********** inside single script **********" 65 | 66 | 67 | if [[ -f $output_dir/metrics.jsonl ]]; then 68 | echo "File exists: $output_dir/metrics.jsonl" 69 | exit 0 70 | fi 71 | 72 | if [[ $test == True ]]; then num_eval=4; else num_eval=200; fi 73 | 74 | cmd="python3 $script \ 75 | --task_name 
$task \ 76 | --num_train ${ntrain} \ 77 | --num_eval ${num_eval} \ 78 | --model_path $model_path \ 79 | --model_name $model_name \ 80 | --load_float16 True \ 81 | --pruned False \ 82 | --output_dir=$output_dir \ 83 | --train_set_seed $train_set_seed \ 84 | --loss_type $label_type \ 85 | --result_file $output_dir/metrics.jsonl \ 86 | --exclusive_training True \ 87 | --single_FT ${single_FT} \ 88 | --num_train_layers ${num_train_layers} \ 89 | --num_epochs ${num_epochs} \ 90 | --eval_set_seed 0 \ 91 | --restrict_attention_demonstration ${restrict_attention_demonstration} \ 92 | --position_modify_demonstration ${position_modify_demonstration} \ 93 | --loss_label_only ${loss_label_only} \ 94 | --test_mode False \ 95 | --template_id $template" 96 | 97 | if [[ $eval_type == "test" ]]; then 98 | cmd="$cmd --test_set_seed 10" 99 | fi 100 | if [[ $sfc_type == "sfc" ]]; then 101 | cmd="$cmd --sfc True" 102 | fi 103 | if [[ $sfc_type == "icl_sfc" ]]; then 104 | cmd="$cmd --icl_sfc True" 105 | fi 106 | $cmd 2>&1 | tee $output_dir/log.txt 107 | -------------------------------------------------------------------------------- /icl_eval/scripts/others/print_result_May10.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import numpy as np 4 | import os 5 | 6 | def read_jsonl(file): 7 | ds = [] 8 | try: 9 | with open(file) as f: 10 | for i, line in enumerate(f): 11 | d = json.loads(line.strip()) 12 | ds.append(d) 13 | except: 14 | import pdb 15 | pdb.set_trace() 16 | return ds 17 | 18 | def print_numpy_result(result, models): 19 | for i in range(len(result)): 20 | # print(models[i], end="\t") 21 | for j in range(result.shape[1]): 22 | print(f"{result[i, j]:.4f}", end="\t") 23 | print() 24 | print() 25 | 26 | out_dir = Path("/scratch/gpfs/smalladi/mengzhou/out/llm_eval/dynamic_icl_eval_May10") 27 | 28 | # rows 29 | def get_models(label_type): 30 | models = [] 31 | for layer in [12, 6, 3]: 32 | steps = 12 // layer 33 | if label_type == "context": lrs = ["1e-05", "1e-06", "1e-07"] 34 | else: lrs = ["1e-03", "1e-04", "1e-05"] 35 | for lr in lrs: 36 | models.append(f"constructed-opt-backprop-{layer}-nstep-{steps}-lr-{lr}") 37 | return models 38 | 39 | shots = 32 40 | task = "AmazonPolarity" 41 | metric = "accuracy" 42 | eval_type = "eval" 43 | # result = np.zeros((len(models), len(shots))) 44 | 45 | print_model_name = False 46 | 47 | print_res = [] 48 | all_accs = [] 49 | for task in ["Subj", "AGNews", "SST2", "CR", "MR", "MPQA", "AmazonPolarity"]: 50 | task_acc = [] 51 | print(task) 52 | print_re = [] 53 | for tmp in [0]: 54 | for single_FT in [True, False]: 55 | for label_type in ["context", "label_only"]: 56 | sub_models = get_models(label_type) 57 | for icl in ["icl_sfc"]: 58 | seed_accs = [] 59 | seed_models = [] 60 | all_seed_accs = [] 61 | for seed in range(1, 4): 62 | acc = np.zeros((len(sub_models), 2)) 63 | for i, model in enumerate(sub_models): 64 | model_dir = out_dir / model 65 | lr = "-".join(model.split("-")[-2:]) 66 | layer = model.split("-")[3] 67 | task_dir = model_dir / f"ntrain{shots}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{layer}-{label_type}-tmp{tmp}-{icl}-{eval_type}" 68 | file = task_dir / "metrics.jsonl" 69 | if os.path.exists(file): 70 | re = read_jsonl(file) 71 | acc[i, 0] = re[0]["accuracy"] 72 | # acc[i, 1] = re[1]["accuracy"] 73 | else: 74 | pass 75 | # print(file) 76 | model_name = sub_models[acc[:, 0].argmax(axis=0)] 77 | seed_models.append(model_name) 78 | seed_acc = acc.max(axis=0) 
79 | seed_accs.append(seed_acc * 100) 80 | all_seed_accs.append(seed_acc[0]) 81 | task_acc.append(all_seed_accs) 82 | acc = np.stack(seed_accs).std(axis=0) 83 | print_re.append(acc[0]) 84 | if print_model_name: 85 | print(" ".join(seed_models), end=" ") 86 | else: 87 | print(" ".join([str(a) for a in acc[0:1]]), end = " ") 88 | # print(seed_accs) 89 | print() 90 | print_res.append(print_re) 91 | all_accs.append(task_acc) 92 | # print(seed_accs) 93 | 94 | print_res = np.array(print_res).transpose() 95 | for line in print_res: 96 | print(' '.join(map(str, line))) 97 | 98 | all_accs = np.array(all_accs) # num_tasks * num_settings * num_seeds 99 | import pdb; pdb.set_trace() 100 | print((np.array(all_accs) * 100).max(1).mean(0).std()) 101 | 102 | 103 | # print(f"{task} {metric}") 104 | # print_numpy_result(result, models) 105 | 106 | -------------------------------------------------------------------------------- /tint_main/utils/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Construct_config: 5 | def __init__(self, model_config, construction_args): 6 | #Sequence length of auxiliary model 7 | self.seq_length = construction_args.seq_length 8 | #Sequence length of auxiliary model + Number of prefix tokens 9 | self.position_dim = construction_args.seq_length + construction_args.num_prefixes 10 | #Number of prefix tokens 11 | self.num_prefixes = construction_args.num_prefixes 12 | #Number of attention heads in TinT 13 | self.num_attention_heads = construction_args.num_attention_heads 14 | #Scale_embeddings (for softmax attention, unnecessary argument for now) 15 | self.scale_embeddings = construction_args.scale_embeddings 16 | #inner_lr for dynamic evaluation 17 | self.inner_lr = construction_args.inner_lr 18 | #Scaling for sigmoid gates to behave as hard gates 19 | self.gate_scale = construction_args.gate_scale 20 | #Embedding dimension of TinT 21 | self.hidden_size = construction_args.hidden_size 22 | #A max bound on the final sequence length, for initialization purposes 23 | self.max_position_embeddings = construction_args.max_position_embeddings 24 | 25 | #Following three arguments will be useful only when we pre-train 26 | #Dropout rate of embedding 27 | self.embd_pdrop=construction_args.embd_pdrop 28 | #Dropout rate of attention 29 | self.attn_pdrop=construction_args.attn_pdrop 30 | #Dropout rate of residual connection 31 | self.resid_pdrop=construction_args.resid_pdrop 32 | 33 | #Auxiliary's activation function 34 | self.activation_function=construction_args.activation_function 35 | #Auxiliary's error term for layernorm 36 | self.epsilon=construction_args.epsilon 37 | #Whether Attention score scales before softmax, as determined by Auxiliary model 38 | self.scale_attn_weights=construction_args.scale_attn_weights 39 | #Attention score appropriate scaling, as determined by Auxiliary model 40 | self.initial_scale=np.sqrt( (construction_args.hidden_size / model_config.hidden_size) * (model_config.num_attention_heads / construction_args.num_attention_heads) ) 41 | 42 | #Number of layers to involve in SGD 43 | self.n_simulation_layers=construction_args.n_simulation_layers 44 | #Number of SGD steps 45 | self.n_forward_backward=construction_args.n_forward_backward 46 | #Unnecessary argument, was used for debugging 47 | self.n_debug_layers=construction_args.n_debug_layers 48 | #Unnecessary argument, was used for projection 49 | self.projection_paths=construction_args.projection_paths 50 | #We never backprop through 
attention, hence unnecessary argument for now 51 | self.backprop_through_attention=construction_args.backprop_through_attention 52 | #Whether to restrict attention between prefix and non-prefix tokens for linear operations 53 | self.restrict_prefixes=construction_args.restrict_prefixes 54 | #Whether to use einsum, which speeds up inference 55 | self.use_einsum=construction_args.use_einsum 56 | #Whether to use classification loss with softmax, for computing gradients 57 | self.use_prediction_loss=construction_args.use_prediction_loss 58 | #Whether to use quad loss from Saunshi et al., for computing gradients 59 | self.use_quad=construction_args.use_quad 60 | #self.n_gpus = construction_args.n_gpus 61 | #For multiple gpus, we can further partition the model across multiple gpus 62 | self.n_layers_pergpu = construction_args.n_layers_pergpu 63 | #'cuda'/'cpu' 64 | self.device = construction_args.device 65 | 66 | #Whether to reuse forward blocks, when we do multiple forward passes 67 | self.reuse_forward_blocks = construction_args.reuse_forward_blocks 68 | #Whether to reuse backward blocks, when we do multiple forward passes 69 | self.reuse_backward_blocks = construction_args.reuse_backward_blocks 70 | 71 | #Whether to only update biases in layernorm. 72 | self.ln_update_bias_only = construction_args.ln_update_bias_only 73 | -------------------------------------------------------------------------------- /tests/perplexity_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | 5 | from tint_main.tint_creator import * 6 | from transformers import AutoTokenizer 7 | from datasets import load_dataset 8 | from transformers import HfArgumentParser 9 | from filelock import Timeout, FileLock 10 | from .data_utils.data_utils import * 11 | from tint_main.utils.all_arguments import * 12 | 13 | parser = HfArgumentParser((ModelArguments, ConstructionArguments, DynamicArguments,)) 14 | model_args, construction_args, data_args, = parser.parse_args_into_dataclasses() 15 | 16 | 17 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 18 | if 'gpt' in model_args.model_name_or_path: 19 | tokenizer.pad_token = tokenizer.eos_token 20 | pad_token_id=tokenizer.convert_tokens_to_ids(tokenizer.pad_token) 21 | elif 'opt' in model_args.model_name_or_path: 22 | tokenizer.bos_token_id = 0 23 | 24 | dataset = preprocess_dataset(data_args, tokenizer) 25 | batch_size=data_args.batch_size 26 | constructed_model = load_construction(model_args.model_name_or_path, model_args.cache_dir, model_args.construct_model_path) 27 | 28 | 29 | 30 | constructed_model.eval() 31 | if construction_args.device == 'cuda': device='cuda:0' 32 | else: device='cpu' 33 | 34 | num_valid_batches = len(dataset) // batch_size 35 | train_fraction = data_args.train_fraction 36 | 37 | if data_args.data_subset == 0: 38 | exit(0) 39 | 40 | if data_args.data_subset != -1: 41 | num_valid_batches = min(data_args.data_subset // batch_size, num_valid_batches) 42 | 43 | 44 | avg_model_test_perplexity = 0. 45 | avg_eval_test_perplexity = 0. 46 | total_test_words = 0. 
47 | 48 | 49 | 50 | for batch_id in tqdm( range( num_valid_batches ), desc='Inference' ): 51 | 52 | 53 | 54 | data = dataset [ batch_id * batch_size : (batch_id + 1) * batch_size ] 55 | batch_sentences = torch.tensor( data ['input_ids'] ) 56 | attention_mask = torch.tensor( data ['attention_mask'] ) 57 | labels = torch.tensor( data ['labels'] ) 58 | 59 | if len(batch_sentences.shape) == 1: 60 | batch_sentences = batch_sentences.view((1, -1)) 61 | attention_mask = attention_mask.view((1, -1)) 62 | labels = labels.view((1, -1)) 63 | 64 | 65 | 66 | mask = torch.zeros_like(attention_mask) 67 | target = batch_sentences.detach().clone() 68 | target [ torch.where(attention_mask == 0.) ] = -100 69 | 70 | batch_seq_lengths = torch.sum(attention_mask, dim=-1) 71 | for i in range(len(batch_seq_lengths)): 72 | mask[i, :int(batch_seq_lengths[i] * train_fraction)] = 1. 73 | target[ i, :int(batch_seq_lengths[i] * train_fraction) ] = -100 74 | 75 | bidirection_mask = mask.float() 76 | gradient_mask = None 77 | 78 | with torch.no_grad(): 79 | results = constructed_model.forward(batch_sentences.to(device), bidirection_mask.to(device), gradient_mask=gradient_mask, test_backward_pass=True, continue_from_first_forward_pass=False, labels=target.to(device)) 80 | original_loss, final_loss = results.original_loss, results.final_loss 81 | 82 | 83 | total_terms = torch.ne(target, -100).float().sum() 84 | 85 | avg_model_test_perplexity += original_loss.item() * total_terms 86 | avg_eval_test_perplexity += final_loss.item() * total_terms 87 | total_test_words += total_terms 88 | 89 | 90 | final_result = {} 91 | final_result[ 'Validation Dynamic eval acc (on test)' ] = np.exp(avg_eval_test_perplexity / total_test_words) 92 | final_result[ 'Validation Model acc (on test)' ] = np.exp(avg_model_test_perplexity / total_test_words) 93 | 94 | 95 | 96 | with FileLock('log_exp_construct.lock'): 97 | with open('log_exp_construct', 'a') as f: 98 | final_result.update(vars(model_args)) 99 | final_result.update(vars(data_args)) 100 | final_result.update(vars(construction_args)) 101 | f.write(str(final_result) + '\n') 102 | 103 | import torch.utils.benchmark as benchmark 104 | def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False, 105 | amp_dtype=torch.float16, **kwinputs): 106 | """ Use Pytorch Benchmark on the forward pass of an arbitrary function. 
""" 107 | if verbose: 108 | print(desc, '- Forward pass') 109 | def fn_amp(*inputs, **kwinputs): 110 | with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): 111 | fn(*inputs, **kwinputs) 112 | for _ in range(repeats): # warmup 113 | fn_amp(*inputs, **kwinputs) 114 | t = benchmark.Timer( 115 | stmt='fn_amp(*inputs, **kwinputs)', 116 | globals={'fn_amp': fn_amp, 'inputs': inputs, 'kwinputs': kwinputs}, 117 | num_threads=torch.get_num_threads(), 118 | ) 119 | m = t.timeit(repeats) 120 | if verbose: 121 | print(m) 122 | return t, m -------------------------------------------------------------------------------- /tests/data_utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import copy 5 | from filelock import FileLock 6 | from itertools import chain 7 | 8 | from accelerate import Accelerator 9 | from datasets import load_dataset, concatenate_datasets 10 | logger = logging.getLogger(__name__) 11 | 12 | def get_raw_data(args): 13 | if args.dataset == 'wikitext-103': 14 | dataset = load_dataset('wikitext', 'wikitext-103-v1', cache_dir=args.data_cache_dir) 15 | if args.use_eval_set: 16 | dataset = dataset['validation'] 17 | elif args.use_test_set: 18 | dataset = dataset['test'] 19 | else: 20 | dataset = dataset['train'] 21 | return dataset 22 | elif args.dataset == 'wikitext-2': 23 | print (args.data_cache_dir) 24 | dataset = load_dataset('wikitext', 'wikitext-2-v1', cache_dir=args.data_cache_dir) 25 | if args.use_eval_set: 26 | dataset = dataset['validation'] 27 | elif args.use_test_set: 28 | dataset = dataset['test'] 29 | else: 30 | dataset = dataset['train'] 31 | return dataset 32 | elif args.dataset == 'c4': 33 | print (args.data_cache_dir) 34 | dataset = load_dataset('c4', 'realnewslike', cache_dir=args.data_cache_dir) 35 | if args.use_eval_set: 36 | dataset = dataset['validation'] 37 | elif args.use_test_set: 38 | dataset = dataset['validation'] 39 | else: 40 | dataset = dataset['train'] 41 | return dataset 42 | raise NotImplementedError 43 | 44 | def preprocess_dataset(args, tokenizer): 45 | raw_datasets = get_raw_data(args) 46 | 47 | text_column_name = "text" 48 | column_names = raw_datasets.column_names 49 | 50 | 51 | accelerator = Accelerator() 52 | 53 | def tokenize_function(examples): 54 | return tokenizer([i for i in examples[text_column_name] if len(i)>0]) 55 | 56 | with accelerator.main_process_first(): 57 | datasets = raw_datasets.map( 58 | tokenize_function, 59 | batched=True, 60 | num_proc=args.num_workers, 61 | remove_columns=column_names, 62 | load_from_cache_file=True, 63 | desc="Running tokenizer on dataset", 64 | ) 65 | 66 | block_size = args.block_size 67 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 68 | def group_texts(examples): 69 | # Concatenate all texts. 70 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 71 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 72 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 73 | # customize this part to your needs. 74 | if total_length >= block_size: 75 | total_length = (total_length // block_size) * block_size 76 | # Split by chunks of max_len. 
77 | result = { 78 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 79 | for k, t in concatenated_examples.items() 80 | } 81 | result["labels"] = result["input_ids"].copy() 82 | return result 83 | 84 | def concat_texts(examples): 85 | # Concatenate all texts. 86 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 87 | concatenated_examples['labels'] = concatenated_examples['input_ids'].copy() 88 | return concatenated_examples 89 | 90 | if args.chunk_data: 91 | with accelerator.main_process_first(): 92 | datasets = datasets.map( 93 | group_texts, 94 | batched=True, 95 | load_from_cache_file=True, 96 | num_proc=args.num_workers, 97 | desc=f"Grouping texts in chunks of {block_size}", 98 | ) 99 | else: 100 | with accelerator.main_process_first(): 101 | datasets = datasets.map( 102 | concat_texts, 103 | batched=True, 104 | load_from_cache_file=True, 105 | num_proc=args.num_workers, 106 | desc=f"Concatenating the texts", 107 | ) 108 | 109 | return datasets 110 | 111 | 112 | def prepare_inputs(batch, padding=None): 113 | for k in batch: 114 | if len(batch[k].shape) == 1: 115 | batch[k] = batch[k].reshape(1, -1) 116 | if padding is not None: 117 | for k in batch: 118 | batch[k] = torch.cat((batch[k], padding)) 119 | batch[k] = batch[k].cuda() 120 | 121 | return batch 122 | 123 | def init_model_and_optimizer(model_init, args): 124 | model = copy.deepcopy(model_init) 125 | if args.freeze_embs: 126 | if 'opt' in args.model_name: 127 | model.model.decoder.embed_tokens.weight.requires_grad = False 128 | model.model.decoder.embed_positions.weight.requires_grad = False 129 | else: 130 | model.transformer.wte.weight.requires_grad = False 131 | model.transformer.wpe.weight.requires_grad = False 132 | optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate) 133 | return model, optimizer -------------------------------------------------------------------------------- /tint_main/tint_creator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | from .tint_gpt import * 5 | from .tint_opt import * 6 | from transformers import GPT2LMHeadModel, OPTForCausalLM 7 | from transformers import HfArgumentParser 8 | from transformers import AutoConfig 9 | from .utils.config import Construct_config 10 | 11 | 12 | #Creates the tree view of the model layers 13 | def nested_children(m: torch.nn.Module): 14 | children = dict(m.named_children()) 15 | output = {} 16 | if children == {}: 17 | # if module has no children; m is last child! :O 18 | return m 19 | else: 20 | # look for children from children... to the last child! 
21 | for name, child in children.items(): 22 | try: 23 | output[name] = nested_children(child) 24 | except TypeError: 25 | output[name] = nested_children(child) 26 | return output 27 | 28 | 29 | #Creates TinT on a given auxiliary model 30 | #model_args contains the necessary config for the auxiliary model 31 | #construction_args contains the necessary config for TinT 32 | #The module calls TinT_gpt for gpt models and TinT_opt for opt models 33 | def TinT_Creator(model_args, construction_args): 34 | model_config = AutoConfig.from_pretrained( 35 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 36 | cache_dir=model_args.cache_dir 37 | ) 38 | 39 | 40 | construction_args.activation_function = model_config.activation_function 41 | construction_args.seq_length = model_config.max_position_embeddings 42 | 43 | config = Construct_config(model_config, construction_args) 44 | 45 | if 'gpt' in model_args.model_name_or_path: 46 | model_fn = GPT2LMHeadModel 47 | elif 'opt' in model_args.model_name_or_path: 48 | model_fn = OPTForCausalLM 49 | else: 50 | raise NotImplementedError 51 | 52 | model = model_fn.from_pretrained( 53 | model_args.model_name_or_path, 54 | config=model_config, 55 | cache_dir=model_args.cache_dir, 56 | ) 57 | 58 | 59 | print ("....Constructing the model....") 60 | 61 | if model_args.construct_load_model: 62 | print ("....Load constructed model....") 63 | checkpoint = torch.load(model_args.construct_model_path) 64 | model_config = checkpoint['model_config'] 65 | config = checkpoint['construction_config'] 66 | 67 | print (nested_children(model)) 68 | if 'gpt' in model_args.model_name_or_path: 69 | constructed_model = TinT_gpt(config, model_config, nested_children(model)) 70 | elif 'opt' in model_args.model_name_or_path: 71 | constructed_model = TinT_opt(config, model_config, nested_children(model)) 72 | else: 73 | raise NotImplementedError 74 | 75 | if model_args.construct_load_model: 76 | constructed_model.load_state_dict(checkpoint['model_state_dict']) 77 | 78 | 79 | if model_args.construct_save_model: 80 | print ("....Store constructed model....") 81 | torch.save({'model_state_dict': constructed_model.state_dict(),\ 82 | 'model_config': model_config,\ 83 | 'construction_config': config},\ 84 | model_args.construct_model_path,\ 85 | ) 86 | 87 | 88 | tot = 0 89 | for parameter in constructed_model.parameters(): 90 | tot += parameter.numel() 91 | print ("Total trainable parameters in constructed model", tot) 92 | 93 | return (constructed_model, model, model_config, config) 94 | 95 | 96 | 97 | #Creates TinT from a config checkpoint 98 | #model_path: auxiliary model name gpt2 or facebook/opt-125m 99 | #cache_dir is where the auxiliary model has been saved from huggingface 100 | #construction_path: path of the saved TinT checkpoint 101 | #The module calls TinT_gpt for gpt models and TinT_opt for opt models 102 | def load_construction(model_path, \ 103 | cache_dir, \ 104 | construction_path, \ 105 | ): 106 | 107 | model_config = AutoConfig.from_pretrained( 108 | model_path, 109 | cache_dir=cache_dir 110 | ) 111 | 112 | 113 | if 'gpt' in model_path: 114 | model_fn = GPT2LMHeadModel 115 | elif 'opt' in model_path: 116 | model_fn = OPTForCausalLM 117 | else: 118 | raise NotImplementedError 119 | 120 | model = model_fn.from_pretrained( 121 | model_path, 122 | config=model_config, 123 | cache_dir=cache_dir, 124 | ) 125 | 126 | import time 127 | start = time.time() 128 | print ("....Load constructed model....") 129 | checkpoint = torch.load(construction_path) 
130 | model_config = checkpoint['model_config'] 131 | config = checkpoint['construction_config'] 132 | 133 | if 'gpt' in model_path: 134 | constructed_model = TinT_gpt(config, model_config, nested_children(model)) 135 | elif 'opt' in model_path: 136 | constructed_model = TinT_opt(config, model_config, nested_children(model)) 137 | else: 138 | raise NotImplmentedError 139 | 140 | constructed_model.load_state_dict(checkpoint['model_state_dict']) 141 | end = time.time() 142 | print("Time for loading constructed model: ", round(end - start, 2)) 143 | return constructed_model 144 | 145 | 146 | if __name__ == "__main__": 147 | parser = HfArgumentParser((ModelArguments,)) 148 | model_args, = parser.parse_args_into_dataclasses() 149 | 150 | constructed_model = load_construction(model_args.model_name_or_path, \ 151 | model_args.cache_dir, \ 152 | model_args.construct_model_path, \ 153 | ) 154 | 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /tint_main/utils/activation/forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import math 4 | import os 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | from torch import nn 11 | from torch.cuda.amp import autocast 12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 13 | from ..modules import * 14 | from ..linear import * 15 | from ..activations import * 16 | 17 | 18 | 19 | 20 | class ActivationForward (nn.Module): 21 | def __init__ (self, config, din, projection_matrix=None, memory_index=-1): 22 | super(ActivationForward, self).__init__() 23 | 24 | self.din=din 25 | self.config=config 26 | self.projection_matrix = projection_matrix 27 | self.memory_index = memory_index 28 | 29 | 30 | if projection_matrix is not None: 31 | self.dout = projection_matrix.shape[0] 32 | else: 33 | self.dout = din 34 | 35 | 36 | assert memory_index == -1 or memory_index >= self.dout,\ 37 | "Memory interacts with final signal" 38 | 39 | assert memory_index == -1 or memory_index <= config.hidden_size - self.din, \ 40 | "not enough space to store memory" 41 | 42 | if projection_matrix is not None: 43 | head_dim = self.dout 44 | num_channels = config.hidden_size // head_dim 45 | else: 46 | num_channels = config.num_attention_heads 47 | head_dim = config.hidden_size // num_channels 48 | 49 | self.mlp_module = MLP (config.hidden_size, \ 50 | config, \ 51 | conv2d=True, \ 52 | transpose_intermediate=True, \ 53 | transpose_proj=False, \ 54 | conv_proj_features=num_channels, \ 55 | ) 56 | 57 | self.mlp_gates = Gates (config) 58 | self.projection_ = None 59 | 60 | 61 | if projection_matrix is not None: 62 | 63 | assert projection_matrix.shape[1] == din,\ 64 | "Projection matrix must have 'din' in second coordinate" 65 | assert projection_matrix.shape[1] >= head_dim, \ 66 | "Currently, this projection only works when we project down to a lower dimension" 67 | assert projection_matrix.shape[1] % head_dim == 0, \ 68 | "Perfect division into channels" 69 | 70 | c_proj_init = torch.zeros((num_channels, head_dim, head_dim), dtype=self.mlp_module.c_proj.weight.dtype) 71 | num_useful_channels = projection_matrix.shape[1] // head_dim 72 | for i in range (num_useful_channels): 73 | c_proj_init[i] = torch.tensor(projection_matrix[:, i*head_dim: (i+1)*head_dim], dtype=self.mlp_module.c_proj.weight.dtype) 74 | self.mlp_module.initialize_weights(c_proj_init=c_proj_init) 
75 | 76 | self.projection_ = Conv2D( nf=num_channels, nx=head_dim, transpose=True, use_einsum=self.config.use_einsum ) 77 | with torch.no_grad(): 78 | self.projection_.weight.copy_(torch.zeros(head_dim, num_channels, num_channels)) 79 | self.projection_.weight[:, :num_useful_channels, 0] = 1. 80 | 81 | else: 82 | c_proj_init = torch.zeros((num_channels, head_dim, head_dim), dtype=self.mlp_module.c_proj.weight.dtype) 83 | 84 | if self.memory_index != -1: 85 | assert memory_index % head_dim == 0, \ 86 | "Memory should be divisible by the number of channels!" 87 | 88 | mem_head_start = memory_index // head_dim 89 | 90 | c_proj_init[:mem_head_start] = torch.eye(head_dim) 91 | self.mlp_module.initialize_weights(c_proj_init=c_proj_init) 92 | else: 93 | c_proj_init[:] = torch.eye(head_dim) 94 | self.mlp_module.initialize_weights(c_proj_init=c_proj_init) 95 | 96 | #Initialize Gates 97 | #Ignore the changes for the prefixes! 98 | #w, u, v, w_bias, u_bias, v_bias 99 | w = torch.zeros((1, 2*config.hidden_size)) 100 | u = torch.zeros((1, 2*config.hidden_size)) 101 | v = torch.zeros((1, 2*config.position_dim)) 102 | w_bias = torch.zeros(2) 103 | u_bias = torch.zeros(2) 104 | v_bias = torch.zeros(2) 105 | 106 | #Input Gate is 1 on prefixes and 0 for non-prefixes 107 | v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.position_dim-config.seq_length) 108 | 109 | 110 | #Change Gate is 0 on prefixes and 1 for non-prefixes 111 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.position_dim-config.seq_length) 112 | v_bias [1] += config.gate_scale 113 | 114 | self.mlp_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias) 115 | 116 | 117 | def forward(self, hidden_states, position_embeddings): 118 | mlp_out = self.mlp_module.forward(hidden_states) 119 | if self.projection_ is not None: 120 | mlp_out = self.projection_(mlp_out) 121 | 122 | if self.memory_index != -1: 123 | assert torch.sum(torch.absolute(mlp_out[:, self.config.num_prefixes:, self.memory_index:])).item() < 1e-10,\ 124 | "Memory portion not empty!" 
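#The addition below appears to implement the memory mechanism: for the non-prefix tokens, the
#pre-activation input x is copied into the reserved slice [memory_index: memory_index+din] of the
#output, presumably so that the backward/descent modules can reuse it to recompute the activation
#derivative without re-running this layer.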
125 | 126 | mlp_out[:, self.config.num_prefixes:, self.memory_index: self.memory_index+self.din] += hidden_states[:, self.config.num_prefixes:, :self.din] 127 | 128 | gate_out = self.mlp_gates.forward(hidden_states, mlp_out, position_embeddings) 129 | 130 | return gate_out 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: icl_as_ft 2 | channels: 3 | - pytorch 4 | - soumith 5 | - nvidia 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - blas=1.0=mkl 11 | - brotlipy=0.7.0=py310h7f8727e_1002 12 | - bzip2=1.0.8=h7b6447c_0 13 | - ca-certificates=2022.10.11=h06a4308_0 14 | - certifi=2022.9.24=py310h06a4308_0 15 | - cffi=1.15.1=py310h74dc2b5_0 16 | - cryptography=38.0.1=py310h9ce1e76_0 17 | - cuda=11.7.1=0 18 | - cuda-cccl=11.7.91=0 19 | - cuda-command-line-tools=11.7.1=0 20 | - cuda-compiler=11.7.1=0 21 | - cuda-cudart=11.7.99=0 22 | - cuda-cudart-dev=11.7.99=0 23 | - cuda-cuobjdump=11.7.91=0 24 | - cuda-cupti=11.7.101=0 25 | - cuda-cuxxfilt=11.7.91=0 26 | - cuda-demo-suite=11.8.86=0 27 | - cuda-documentation=11.8.86=0 28 | - cuda-driver-dev=11.7.99=0 29 | - cuda-gdb=11.8.86=0 30 | - cuda-libraries=11.7.1=0 31 | - cuda-libraries-dev=11.7.1=0 32 | - cuda-memcheck=11.8.86=0 33 | - cuda-nsight=11.8.86=0 34 | - cuda-nsight-compute=11.8.0=0 35 | - cuda-nvcc=11.7.99=0 36 | - cuda-nvdisasm=11.8.86=0 37 | - cuda-nvml-dev=11.7.91=0 38 | - cuda-nvprof=11.8.87=0 39 | - cuda-nvprune=11.7.91=0 40 | - cuda-nvrtc=11.7.99=0 41 | - cuda-nvrtc-dev=11.7.99=0 42 | - cuda-nvtx=11.7.91=0 43 | - cuda-nvvp=11.8.87=0 44 | - cuda-runtime=11.7.1=0 45 | - cuda-sanitizer-api=11.8.86=0 46 | - cuda-toolkit=11.7.1=0 47 | - cuda-tools=11.7.1=0 48 | - cuda-visual-tools=11.7.1=0 49 | - cuda80=1.0=0 50 | - cudatoolkit=11.1.74=h6bb024c_0 51 | - ffmpeg=4.3=hf484d3e_0 52 | - freetype=2.12.1=h4a9f257_0 53 | - gds-tools=1.4.0.31=0 54 | - giflib=5.2.1=h7b6447c_0 55 | - gmp=6.2.1=h295c915_3 56 | - gnutls=3.6.15=he1e5248_0 57 | - idna=3.4=py310h06a4308_0 58 | - intel-openmp=2021.4.0=h06a4308_3561 59 | - jpeg=9e=h7f8727e_0 60 | - lame=3.100=h7b6447c_0 61 | - lcms2=2.12=h3be6417_0 62 | - ld_impl_linux-64=2.38=h1181459_1 63 | - lerc=3.0=h295c915_0 64 | - libcublas=11.11.3.6=0 65 | - libcublas-dev=11.11.3.6=0 66 | - libcufft=10.9.0.58=0 67 | - libcufft-dev=10.9.0.58=0 68 | - libcufile=1.4.0.31=0 69 | - libcufile-dev=1.4.0.31=0 70 | - libcurand=10.3.0.86=0 71 | - libcurand-dev=10.3.0.86=0 72 | - libcusolver=11.4.1.48=0 73 | - libcusolver-dev=11.4.1.48=0 74 | - libcusparse=11.7.5.86=0 75 | - libcusparse-dev=11.7.5.86=0 76 | - libdeflate=1.8=h7f8727e_5 77 | - libffi=3.3=he6710b0_2 78 | - libgcc-ng=11.2.0=h1234567_1 79 | - libgomp=11.2.0=h1234567_1 80 | - libiconv=1.16=h7f8727e_2 81 | - libidn2=2.3.2=h7f8727e_0 82 | - libnpp=11.8.0.86=0 83 | - libnpp-dev=11.8.0.86=0 84 | - libnvjpeg=11.9.0.86=0 85 | - libnvjpeg-dev=11.9.0.86=0 86 | - libpng=1.6.37=hbc83047_0 87 | - libstdcxx-ng=11.2.0=h1234567_1 88 | - libtasn1=4.16.0=h27cfd23_0 89 | - libtiff=4.4.0=hecacb30_2 90 | - libunistring=0.9.10=h27cfd23_0 91 | - libuuid=1.41.5=h5eee18b_0 92 | - libwebp=1.2.4=h11a3e52_0 93 | - libwebp-base=1.2.4=h5eee18b_0 94 | - lz4-c=1.9.3=h295c915_1 95 | - mkl=2021.4.0=h06a4308_640 96 | - mkl-service=2.4.0=py310h7f8727e_0 97 | - mkl_fft=1.3.1=py310hd6ae3a3_0 98 | - mkl_random=1.2.2=py310h00e6091_0 99 | - ncurses=6.3=h5eee18b_3 100 | - nettle=3.7.3=hbbd107a_1 101 | - 
nsight-compute=2022.3.0.22=0 102 | - numpy=1.23.4=py310hd5efca6_0 103 | - numpy-base=1.23.4=py310h8e6c178_0 104 | - openh264=2.1.1=h4ff587b_0 105 | - openssl=1.1.1s=h7f8727e_0 106 | - pillow=9.2.0=py310hace64e9_1 107 | - pip=22.2.2=py310h06a4308_0 108 | - pycparser=2.21=pyhd3eb1b0_0 109 | - pyopenssl=22.0.0=pyhd3eb1b0_0 110 | - pysocks=1.7.1=py310h06a4308_0 111 | - python=3.10.8=haa1d7c7_0 112 | - pytorch=1.13.0=py3.10_cuda11.7_cudnn8.5.0_0 113 | - pytorch-cuda=11.7=h67b0de4_0 114 | - pytorch-mutex=1.0=cuda 115 | - readline=8.2=h5eee18b_0 116 | - requests=2.28.1=py310h06a4308_0 117 | - setuptools=65.5.0=py310h06a4308_0 118 | - six=1.16.0=pyhd3eb1b0_1 119 | - sqlite=3.39.3=h5082296_0 120 | - tk=8.6.12=h1ccaba5_0 121 | - torchaudio=0.13.0=py310_cu117 122 | - torchvision=0.14.0=py310_cu117 123 | - typing_extensions=4.3.0=py310h06a4308_0 124 | - tzdata=2022f=h04d1e81_0 125 | - urllib3=1.26.12=py310h06a4308_0 126 | - wheel=0.37.1=pyhd3eb1b0_0 127 | - xz=5.2.6=h5eee18b_0 128 | - zlib=1.2.13=h5eee18b_0 129 | - zstd=1.5.2=ha4553b6_0 130 | - pip: 131 | - accelerate==0.14.0 132 | - aiohttp==3.8.3 133 | - aiosignal==1.3.1 134 | - anyio==3.6.2 135 | - argon2-cffi==21.3.0 136 | - argon2-cffi-bindings==21.2.0 137 | - asttokens==2.2.0 138 | - async-timeout==4.0.2 139 | - attrs==22.1.0 140 | - backcall==0.2.0 141 | - beautifulsoup4==4.11.1 142 | - bleach==5.0.1 143 | - charset-normalizer==2.1.1 144 | - contourpy==1.0.6 145 | - cycler==0.11.0 146 | - datasets==2.7.1 147 | - debugpy==1.6.4 148 | - decorator==5.1.1 149 | - defusedxml==0.7.1 150 | - dill==0.3.6 151 | - entrypoints==0.4 152 | - executing==1.2.0 153 | - fastjsonschema==2.16.2 154 | - filelock==3.8.0 155 | - fonttools==4.38.0 156 | - frozenlist==1.3.3 157 | - fsspec==2022.11.0 158 | - huggingface-hub==0.11.0 159 | - ipykernel==6.17.1 160 | - ipython==8.7.0 161 | - ipython-genutils==0.2.0 162 | - jedi==0.18.2 163 | - jinja2==3.1.2 164 | - jsonschema==4.17.3 165 | - jupyter-client==7.4.7 166 | - jupyter-core==5.1.0 167 | - jupyter-server==1.23.3 168 | - jupyterlab-pygments==0.2.2 169 | - kiwisolver==1.4.4 170 | - markupsafe==2.1.1 171 | - matplotlib==3.6.2 172 | - matplotlib-inline==0.1.6 173 | - mistune==2.0.4 174 | - multidict==6.0.2 175 | - multiprocess==0.70.14 176 | - nbclassic==0.4.8 177 | - nbclient==0.7.2 178 | - nbconvert==7.2.5 179 | - nbformat==5.7.0 180 | - nest-asyncio==1.5.6 181 | - notebook==6.5.2 182 | - notebook-shim==0.2.2 183 | - packaging==21.3 184 | - pandas==1.5.1 185 | - pandocfilters==1.5.0 186 | - parso==0.8.3 187 | - pexpect==4.8.0 188 | - pickleshare==0.7.5 189 | - platformdirs==2.5.4 190 | - prometheus-client==0.15.0 191 | - prompt-toolkit==3.0.33 192 | - psutil==5.9.4 193 | - ptyprocess==0.7.0 194 | - pure-eval==0.2.2 195 | - pyarrow==10.0.0 196 | - pygments==2.13.0 197 | - pyparsing==3.0.9 198 | - pyrsistent==0.19.2 199 | - python-dateutil==2.8.2 200 | - pytz==2022.6 201 | - pyyaml==6.0 202 | - pyzmq==24.0.1 203 | - regex==2022.10.31 204 | - responses==0.18.0 205 | - send2trash==1.8.0 206 | - sniffio==1.3.0 207 | - soupsieve==2.3.2.post1 208 | - stack-data==0.6.2 209 | - terminado==0.17.0 210 | - tinycss2==1.2.1 211 | - tokenizers==0.13.2 212 | - tornado==6.2 213 | - tqdm==4.64.1 214 | - traitlets==5.6.0 215 | - transformers==4.24.0 216 | - wcwidth==0.2.5 217 | - webencodings==0.5.1 218 | - websocket-client==1.4.2 219 | - xxhash==3.1.0 220 | - yarl==1.8.1 221 | -------------------------------------------------------------------------------- /tint_main/utils/activations.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from collections import OrderedDict 17 | 18 | import torch 19 | from packaging import version 20 | from torch import Tensor, nn 21 | 22 | 23 | 24 | class PytorchGELUTanh(nn.Module): 25 | """ 26 | A fast C implementation of the tanh approximation of the GeLU activation function. See 27 | https://arxiv.org/abs/1606.08415. 28 | 29 | This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical 30 | match due to rounding errors. 31 | """ 32 | 33 | def __init__(self): 34 | super().__init__() 35 | if version.parse(torch.__version__) < version.parse("1.12.0"): 36 | raise ImportError( 37 | f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use " 38 | "PytorchGELUTanh. Please upgrade torch." 39 | ) 40 | 41 | def forward(self, input: Tensor) -> Tensor: 42 | return nn.functional.gelu(input, approximate="tanh") 43 | 44 | 45 | class NewGELUActivation(nn.Module): 46 | """ 47 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see 48 | the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 49 | """ 50 | 51 | def forward(self, input: Tensor) -> Tensor: 52 | return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) 53 | 54 | 55 | class GELUActivation(nn.Module): 56 | """ 57 | Original Implementation of the GELU activation function in Google BERT repo when initially created. For 58 | information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + 59 | torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional 60 | Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 61 | """ 62 | 63 | def __init__(self, use_gelu_python: bool = False): 64 | super().__init__() 65 | if use_gelu_python: 66 | self.act = self._gelu_python 67 | else: 68 | self.act = nn.functional.gelu 69 | 70 | def _gelu_python(self, input: Tensor) -> Tensor: 71 | return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0))) 72 | 73 | def forward(self, input: Tensor) -> Tensor: 74 | return self.act(input) 75 | 76 | 77 | class FastGELUActivation(nn.Module): 78 | """ 79 | Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs 80 | """ 81 | 82 | def forward(self, input: Tensor) -> Tensor: 83 | return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) 84 | 85 | 86 | class QuickGELUActivation(nn.Module): 87 | """ 88 | Applies GELU approximation that is fast but somewhat inaccurate. 
See: https://github.com/hendrycks/GELUs 89 | """ 90 | 91 | def forward(self, input: Tensor) -> Tensor: 92 | return input * torch.sigmoid(1.702 * input) 93 | 94 | 95 | class ClippedGELUActivation(nn.Module): 96 | """ 97 | Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as 98 | it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to 99 | https://arxiv.org/abs/2004.09602. 100 | 101 | Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when 102 | initially created. 103 | 104 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + 105 | torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 106 | """ 107 | 108 | def __init__(self, min: float, max: float): 109 | if min > max: 110 | raise ValueError(f"min should be < max (got min: {min}, max: {max})") 111 | 112 | super().__init__() 113 | self.min = min 114 | self.max = max 115 | 116 | def forward(self, x: Tensor) -> Tensor: 117 | return torch.clip(gelu(x), self.min, self.max) 118 | 119 | 120 | class SiLUActivation(nn.Module): 121 | """ 122 | See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear 123 | Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function 124 | Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated 125 | Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with 126 | later. 127 | """ 128 | 129 | def forward(self, input: Tensor) -> Tensor: 130 | return nn.functional.silu(input) 131 | 132 | 133 | class MishActivation(nn.Module): 134 | """ 135 | See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also 136 | visit the official repository for the paper: https://github.com/digantamisra98/Mish 137 | """ 138 | 139 | def __init__(self): 140 | super().__init__() 141 | if version.parse(torch.__version__) < version.parse("1.9.0"): 142 | self.act = self._mish_python 143 | else: 144 | self.act = nn.functional.mish 145 | 146 | def _mish_python(self, input: Tensor) -> Tensor: 147 | return input * torch.tanh(nn.functional.softplus(input)) 148 | 149 | def forward(self, input: Tensor) -> Tensor: 150 | return self.act(input) 151 | 152 | 153 | class LinearActivation(nn.Module): 154 | """ 155 | Applies the linear activation function, i.e. forwarding input directly to output. 
156 | """ 157 | 158 | def forward(self, input: Tensor) -> Tensor: 159 | return input 160 | 161 | 162 | class ClassInstantier(OrderedDict): 163 | def __getitem__(self, key): 164 | content = super().__getitem__(key) 165 | cls, kwargs = content if isinstance(content, tuple) else (content, {}) 166 | return cls(**kwargs) 167 | 168 | 169 | ACT2CLS = { 170 | "gelu": GELUActivation, 171 | "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), 172 | "gelu_fast": FastGELUActivation, 173 | "gelu_new": NewGELUActivation, 174 | "gelu_python": (GELUActivation, {"use_gelu_python": True}), 175 | "gelu_pytorch_tanh": PytorchGELUTanh, 176 | "linear": LinearActivation, 177 | "mish": MishActivation, 178 | "quick_gelu": QuickGELUActivation, 179 | "relu": nn.ReLU, 180 | "relu6": nn.ReLU6, 181 | "sigmoid": nn.Sigmoid, 182 | "silu": SiLUActivation, 183 | "swish": SiLUActivation, 184 | "tanh": nn.Tanh, 185 | } 186 | ACT2FN = ClassInstantier(ACT2CLS) 187 | 188 | 189 | def get_activation(activation_string): 190 | if activation_string in ACT2FN: 191 | return ACT2FN[activation_string] 192 | else: 193 | raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") 194 | 195 | 196 | # For backwards compatibility with: from activations import gelu_python 197 | gelu_python = get_activation("gelu_python") 198 | gelu_new = get_activation("gelu_new") 199 | gelu = get_activation("gelu") 200 | gelu_fast = get_activation("gelu_fast") 201 | quick_gelu = get_activation("quick_gelu") 202 | silu = get_activation("silu") 203 | mish = get_activation("mish") 204 | linear_act = get_activation("linear") -------------------------------------------------------------------------------- /tint_main/utils/all_arguments.py: -------------------------------------------------------------------------------- 1 | from transformers import HfArgumentParser 2 | from dataclasses import dataclass, field 3 | from typing import Callable, Dict, Optional, Union, List 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
9 | """
10 | model_name_or_path: str = field(
11 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
12 | )
13 | config_name: Optional[str] = field(
14 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
15 | )
16 | cache_dir: Optional[str] = field(
17 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
18 | )
19 | 
20 | 
21 | construct_save_model: Optional[bool] = field(
22 | default=False, metadata={"help": "Whether to save constructed model"}
23 | )
24 | 
25 | 
26 | construct_load_model: Optional[bool] = field(
27 | default=False, metadata={"help": "Whether to load constructed model from a specific path"}
28 | )
29 | 
30 | construct_model_path: Optional[str] = field(
31 | default="", metadata={"help": "Path to save or load the constructed model"}
32 | )
33 | 
34 | 
35 | 
36 | 
37 | 
38 | 
39 | @dataclass
40 | class ConstructionArguments:
41 | """
42 | Arguments pertaining to the TinT construction.
43 | """
44 | 
45 | device: Optional[str] = field(
46 | default="cuda", metadata={"help": "cuda/cpu"}
47 | )
48 | 
49 | seq_length: Optional[int] = field(
50 | default=1024, metadata={"help": "Sequence length for the smaller model"}
51 | )
52 | 
53 | #position_dim: Optional[int] = field(
54 | # default=1024 + 256, metadata={"help": ""}
55 | #)
56 | 
57 | num_prefixes: Optional[int] = field(
58 | default=256, metadata={"help": "Number of prefixes to encode at the start of the sequence"}
59 | )
60 | 
61 | num_attention_heads: Optional[int] = field(
62 | default=12, metadata={"help": "Number of attention heads"}
63 | )
64 | 
65 | scale_embeddings: Optional[float] = field(
66 | default=1000., metadata={"help": "Scale factor to minimize error introduced by multiplications with GeLU"}
67 | )
68 | 
69 | inner_lr: Optional[float] = field(
70 | default=0.000001, metadata={"help": "Learning rate to simulate SGD inside the model"}
71 | )
72 | 
73 | gate_scale: Optional[float] = field(
74 | default=10., metadata={"help": "Initial scale inside gate weights to simulate 0/1 switch"}
75 | )
76 | hidden_size: Optional[int] = field(
77 | default=4, metadata={"help": "Multiple of Embedding size (of the smaller model) for the construction"}
78 | )
79 | 
80 | max_position_embeddings: Optional[int] = field(
81 | default=2048, metadata={"help": "Max sequence length"}
82 | )
83 | 
84 | embd_pdrop: Optional[float] = field(
85 | default=0.1, metadata={"help": "Dropout on embeddings"}
86 | )
87 | 
88 | attn_pdrop: Optional[float] = field(
89 | default=0.1, metadata={"help": "Dropout on attention weights"}
90 | )
91 | resid_pdrop: Optional[float] = field(
92 | default=0.1, metadata={"help": "Dropout on residual connections"}
93 | )
94 | 
95 | activation_function: Optional[str] = field(
96 | default="gelu", metadata={"help": "Activation: gelu/relu"}
97 | )
98 | epsilon: Optional[float] = field(
99 | default=1e-05, metadata={"help": "Epsilon for layernorm computation"}
100 | )
101 | 
102 | scale_attn_weights: Optional[bool] = field(
103 | default=True, metadata={"help": "Whether to scale attention weights"}
104 | )
105 | 
106 | n_simulation_layers: Optional[int] = field(
107 | default=-1, metadata={"help": "Number of layers to simulate forward-backward passes"}
108 | )
109 | 
110 | n_forward_backward: Optional[int] = field(
111 | default=1, metadata={"help": "Number of forward-backward passes"}
112 | )
113 | 
114 | n_debug_layers: Optional[int] = field(
115 | default=-1, metadata={"help": "Number of layers of smaller model that we use (for debugging purposes)"}
116 | )
117 | 
118 | projection_paths: Optional[str] = field(
119 | default="./projections", metadata={"help": "Path to all the projection matrices"}
120 | )
121 | 
122 | backprop_through_attention: Optional[bool] = field(
123 | default=True, metadata={"help": "Whether to look ahead during backprop through attention"}
124 | )
125 | 
126 | restrict_prefixes: Optional[bool] = field(
127 | default=False, metadata={"help": "Whether to restrict attention to blank tokens only in linear forward/backward"}
128 | )
129 | 
130 | use_einsum: Optional[bool] = field(
131 | default=True, metadata={"help": "Whether to use einsum in dimension-wise linear convolutions"}
132 | )
133 | 
134 | use_prediction_loss: Optional[bool] = field(
135 | default=True, metadata={"help": "If true, gradient w.r.t. loss is E(p-q) with p being the softmax prediction; if false (and use_quad is false), gradient w.r.t. loss is E(1 - q)"}
136 | )
137 | 
138 | use_quad: Optional[bool] = field(
139 | default=False, metadata={"help": "If true, we use the quad loss to compute the gradient from Saunshi & Malladi et al. (2020)"}
140 | )
141 | 
142 | 
143 | n_layers_pergpu: Optional[int] = field(
144 | default=100000, metadata={"help": "Number of layers simulated per gpu"}
145 | )
146 | 
147 | reuse_forward_blocks: Optional[bool] = field(
148 | default=False, metadata={"help": "Whether to re-use forward blocks"}
149 | )
150 | 
151 | reuse_backward_blocks: Optional[bool] = field(
152 | default=False, metadata={"help": "Whether to re-use backward blocks"}
153 | )
154 | 
155 | ln_update_bias_only: Optional[bool] = field(
156 | default=True, metadata={"help": "Whether to update only biases in layernorm modules; current weight update of layernorm modules is very noisy, hence it's best to avoid!"}
157 | )
158 | 
159 | 
160 | @dataclass
161 | class DynamicArguments:
162 | 
163 | data_cache_dir: Optional[str] = field(
164 | default='/scratch/gpfs/ap34/icl-as-ft/Dynamic_initialization/data',
165 | metadata={
166 | 'help': 'where to store downloaded model, datasets, and tokenizer'
167 | }
168 | )
169 | 
170 | log_dir: Optional[str] = field(
171 | default='/scratch/gpfs/smalladi/icl_as_ft/logs'
172 | )
173 | 
174 | dataset: Optional[str] = field(
175 | default='wikitext-103',
176 | metadata={
177 | 'help': 'dataset to use, should be in HF datasets repo'
178 | }
179 | )
180 | 
181 | incontext_dataset: Optional[str] = field(
182 | default="", metadata={"help": "Dataset for in-context experiments!"}
183 | )
184 | 
185 | incontext_stem: Optional[str] = field(
186 | default="", metadata={"help": "Directory containing the files!"}
187 | )
188 | 
189 | incontext_n_shot: Optional[int] = field(
190 | default=4, metadata={"help": "Number of demonstrations in the in-context prompt!"}
191 | )
192 | 
193 | prompt_variant: Optional[int] = field(
194 | default=0, metadata={"help": "Variants to try for sst-2!"}
195 | )
196 | 
197 | 
198 | use_eval_set: bool = field(
199 | default=True,
200 | metadata={'help': 'when on, uses the eval dataset only for ppl measurement/in-context expts.'}
201 | )
202 | 
203 | use_test_set: bool = field(
204 | default=False,
205 | metadata={'help': 'when on, uses the test dataset only for ppl measurement/in-context expts.'}
206 | )
207 | 
208 | ### finetuning strategies ###
209 | 
210 | chunk_data: bool = field(
211 | default=True,
212 | metadata={'help': 'when on, chunks the data into block_size chunks'}
213 | )
214 | 
215 | baseline: bool = field(
216 | default=False
217 | )
218 | 
219 | num_subsets: int = field(
220 | default=16
221 | )
222 | data_chunk_index: int = field( 223 | default=0 224 | ) 225 | 226 | block_size: Optional[int] = field( 227 | default=1024, 228 | metadata={'help': 'size of chunks to break corpus into'} 229 | ) 230 | 231 | num_workers: Optional[int] = field( 232 | default=4, 233 | metadata={'help': 'number of workers'} 234 | ) 235 | 236 | data_subset: Optional[int] = field( 237 | default=-1, metadata={"help": "Number of training examples to use in the subset!"} 238 | ) 239 | 240 | train_fraction: Optional[float] = field( 241 | default=0.5, metadata={"help": "Fraction of sentence for training!"} 242 | ) 243 | 244 | 245 | batch_size: Optional[int] = field( 246 | default=1, metadata={"help": "Batch size for training!"} 247 | ) 248 | 249 | -------------------------------------------------------------------------------- /tint_main/utils/layernorm/backward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import math 4 | import os 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | from torch import nn 11 | from torch.cuda.amp import autocast 12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 13 | from ..modules import * 14 | from ..linear import * 15 | 16 | 17 | #Assumption on memory 18 | #It should contain [(x-\mu)/\sigma, x] 19 | 20 | class LayerNormBackward(nn.Module): 21 | def __init__(self, config, din, use_softmax, retain_nablay=False, memory_index=-1): 22 | super(LayerNormBackward, self).__init__() 23 | 24 | assert use_softmax==False ,\ 25 | "Currently I only use linear attention in this module" 26 | 27 | assert memory_index == -1 or memory_index >= din, \ 28 | "memory crosses current signal" 29 | 30 | 31 | self.linear = LinearBackward(config, \ 32 | din=din, \ 33 | dout=din, \ 34 | use_softmax=use_softmax, \ 35 | retain_nablay=retain_nablay, \ 36 | memory_index=memory_index, \ 37 | ) 38 | self.epsilon = config.epsilon 39 | self.memory_index = memory_index 40 | self.config = config 41 | 42 | head_dim = config.hidden_size // config.num_attention_heads 43 | self.c_fc = Conv2D(config.num_attention_heads, head_dim, transpose=True, use_einsum=self.config.use_einsum) 44 | self.proj_fc = Conv2D(config.num_attention_heads, head_dim, transpose=True, use_einsum=self.config.use_einsum) 45 | 46 | self.config = config 47 | self.din = din 48 | 49 | 50 | c_fc_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 51 | c_proj_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 52 | 53 | 54 | 55 | 56 | assert din % head_dim == 0, \ 57 | " 'din' should be a multiple of head_dim! " 58 | 59 | num_partitions = din // head_dim 60 | 61 | 62 | 63 | assert self.memory_index % head_dim == 0, \ 64 | "Memory should start at a multiple of head_dim!" 65 | 66 | mem_head_start = self.memory_index // head_dim 67 | 68 | if retain_nablay: 69 | start_shift = num_partitions 70 | else: 71 | start_shift = 0 72 | 73 | c_fc_init[:, start_shift: start_shift + num_partitions, start_shift: start_shift + num_partitions] = 1. / config.scale_embeddings * torch.eye(num_partitions) 74 | #1. 
/ config.scale_embeddings 75 | c_fc_init[:, start_shift: start_shift + num_partitions, mem_head_start + num_partitions: mem_head_start + 2*num_partitions] = torch.eye(num_partitions) 76 | 77 | 78 | #Compute GeLU(x + 1/N \nabla y) - GeLU(x) 79 | 80 | c_proj_init[:, start_shift: start_shift + num_partitions, start_shift: start_shift + num_partitions] = config.scale_embeddings * torch.eye(num_partitions) 81 | c_proj_init[:, start_shift: start_shift + num_partitions, mem_head_start: mem_head_start + num_partitions] = -config.scale_embeddings * torch.eye(num_partitions) 82 | 83 | 84 | with torch.no_grad(): 85 | 86 | self.c_fc.weight.copy_(torch.swapaxes(c_fc_init, axis0=-1, axis1=-2)) 87 | self.proj_fc.weight.copy_(torch.swapaxes(c_proj_init, axis0=-1, axis1=-2)) 88 | 89 | #w acts like a gate to decide what portion of the embedding we apply layernorm on 90 | self.w = torch.zeros (( 1, 1, config.hidden_size )) 91 | if retain_nablay: 92 | self.w [:, :, din : 2*din] += config.gate_scale 93 | else: 94 | self.w [:, :, : din] += config.gate_scale 95 | self.gate = torch.nn.Tanh() 96 | 97 | 98 | #mask out normalization on prefixes 99 | self.normalization_gates = Gates (config) 100 | #Initialize Gates 101 | #Ignore the changes for the prefixes! 102 | #w, u, v, w_bias, u_bias, v_bias 103 | w = torch.zeros((1, 2*config.hidden_size)) 104 | u = torch.zeros((1, 2*config.hidden_size)) 105 | v = torch.zeros((1, 2*config.position_dim)) 106 | w_bias = torch.zeros(2) 107 | u_bias = torch.zeros(2) 108 | v_bias = torch.zeros(2) 109 | 110 | #Input Gate is 1 on prefixes and 0 for non-prefixes 111 | v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.num_prefixes) 112 | 113 | 114 | #Change Gate is 0 on prefixes and 1 for non-prefixes 115 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.num_prefixes) 116 | v_bias [1] += config.gate_scale 117 | 118 | self.normalization_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias) 119 | 120 | 121 | 122 | def forward(self, hidden_states, position_states, attention_mask=None, icl_mask=None): 123 | 124 | weights = self.gate ( self.w ).to(hidden_states.device) 125 | 126 | back_gradient = self.linear.forward(hidden_states, position_states) 127 | 128 | #print (back_gradient[0, 12, :72], back_gradient[0, 12, -72:]) 129 | ####################################################################### 130 | #Next few lines compute the operation: 131 | # f(x) = (x - \mu(x)) / \nabla(x) 132 | # N (f(x + 1/N \nabla y) - f(x)) 133 | ####################################################################### 134 | first_layer = self.c_fc.forward ( back_gradient ) 135 | first_layer = weights * first_layer + (1. - weights) * back_gradient 136 | 137 | #print (first_layer[0, 12, :72]) 138 | 139 | mean = torch.sum(first_layer * weights, dim=-1, keepdim=True) / torch.sum(weights, dim=-1, keepdim=True) 140 | var = ( self.epsilon + torch.sum( (weights * (first_layer - mean)) ** 2, dim=-1, keepdim=True) / torch.sum(weights, dim=-1, keepdim=True) ) ** 0.5 141 | 142 | #print (var) 143 | normalized_states = (first_layer - mean) / var 144 | #print (normalized_states[:, 192, :64]) 145 | 146 | normalized_states = weights * normalized_states + (1. - weights) * first_layer 147 | 148 | second_layer = self.proj_fc.forward ( normalized_states ) 149 | 150 | second_layer = weights * second_layer + (1. 
- weights) * normalized_states 151 | 152 | ####################################################################### 153 | 154 | gated_output = self.normalization_gates.forward ( hidden_states, second_layer, position_states) 155 | 156 | return gated_output 157 | 158 | 159 | 160 | class LayerNormDescent(nn.Module): 161 | def __init__ (self, config, din, use_softmax, memory_index=-1, debug_zero=False): 162 | super(LayerNormDescent, self).__init__() 163 | self.config=config 164 | self.linear = LinearDescent(config, din=din, dout=din, use_softmax=use_softmax, memory_index=memory_index, debug_zero=debug_zero, update_bias_only=self.config.ln_update_bias_only) 165 | 166 | def forward(self, hidden_states, position_states, attention_mask, activation_memory=None, icl_mask=None): 167 | return self.linear.forward(hidden_states, position_states, attention_mask) 168 | 169 | 170 | 171 | class LayerNormDescent_Backward(nn.Module): 172 | def __init__(self, config, din, use_softmax, debug_zero=False, retain_nablay=False, projection_matrix=None, memory_index=-1): 173 | super(LayerNormDescent_Backward, self).__init__() 174 | 175 | 176 | self.config = config 177 | self.backward = LayerNormBackward(config, \ 178 | din=din, \ 179 | use_softmax=use_softmax, \ 180 | retain_nablay=retain_nablay, \ 181 | memory_index=memory_index, \ 182 | ) 183 | 184 | self.descent = LayerNormDescent(config, \ 185 | din=din, \ 186 | use_softmax=use_softmax, \ 187 | memory_index=memory_index,\ 188 | debug_zero=debug_zero, \ 189 | ) 190 | 191 | def forward(self, hidden_states, position_embeddings, attention_mask, activation_memory=None, icl_mask=None): 192 | backward_out = self.backward(hidden_states, position_embeddings) 193 | descent_out = self.descent (hidden_states, position_embeddings, attention_mask) 194 | 195 | return torch.cat( [ descent_out[:, :self.config.num_prefixes], backward_out[:, self.config.num_prefixes:] ], axis=1) 196 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## TinT: Trainable Transformer in Transformer 2 | 3 | This repository contains the code for our paper Trainable Transformer in Transformer (TinT). 4 | 5 | ## Quick Links 6 | 7 | - [TinT: Trainable Transformer in Transformer](#tint-trainable-transformer-in-transformer) 8 | - [Quick Links](#quick-links) 9 | - [Overview](#overview) 10 | - [Structure of TinT](#structure-of-tint) 11 | - [Creating TinT](#creating-tint) 12 | - [Requirements](#requirements) 13 | - [Perplexity Evaluation](#perplexity-evaluation) 14 | - [Downstream Evaluation](#downstream-evaluation) 15 | - [Hyperparameter Considerations](#hyperparameter-considerations) 16 | - [Unavailable features](#unavailable-features) 17 | - [Bugs or Questions](#bugs-or-questions) 18 | 19 | 20 | 21 | ## Overview 22 | 23 | We propose an efficient construction, Transformer in Transformer (in short, TinT), that allows a transformer to simulate and fine-tune complex models internally during inference (e.g., pre-trained language models). In particular, we introduce innovative approximation techniques that allow a TinT model with less than 2 billion parameters to simulate and fine-tune a 125 million parameter transformer model within a single forward pass. TinT accommodates many common transformer variants and its design ideas also improve the efficiency of past instantiations of simple models inside transformers. 
We conduct end-to-end experiments to validate the internal fine-tuning procedure of TinT on various language modeling and downstream tasks. For example, even with a limited one-step budget, we observe that TinT for an OPT-125M model improves performance by 4-16% absolute on average compared to OPT-125M. These findings suggest that large pre-trained language models are capable of performing intricate subroutines.
24 | 
25 | ### Structure of TinT
26 | 
27 | Each Forward, Backward, and Descent module is represented using combinations of linear, self-attention, layernorm, and activation layers. The input consists of prefix embeddings that represent the relevant auxiliary model parameters in each layer, input token embeddings, and a binary prefix mask to separate the train and evaluation segments of the input. The auxiliary model parameters are updated in the descent module using the training part of the segment, and the updated prefix tokens are transferred to the forward modules via residual connections for evaluating the rest of the segment.
28 | 
29 | ![](figures/icl_overview.png)
30 | 
31 | 
32 | ## Creating TinT
33 | 
34 | In the following section, we provide instructions on creating and evaluating TinT models with our code.
35 | 
36 | ### Requirements
37 | Install the necessary conda environment using
38 | 
39 | ```bash
40 | conda env create -n icl_as_ft --file env.yml
41 | ```
42 | 
43 | ### Create and store TinT
44 | 
45 | ```bash
46 | python -m tests.create_model \
47 | --model_name_or_path $model_path \
48 | --cache_dir $cache_dir \
49 | --construct_model_path $construct_model_path \
50 | --n_simulation_layers $nsim_layers \
51 | --n_forward_backward $n_forward_backward \
52 | --inner_lr $lr \
53 | --n_layers_pergpu $n_layers_pergpu \
54 | --num_attention_heads $num_attention_heads \
55 | --hidden_size $hidden_size \
56 | --num_prefixes $num_prefixes \
57 | --construct_save_model $construct_save_model \
58 | --reuse_forward_blocks $reuse_forward_blocks \
59 | --reuse_backward_blocks $reuse_backward_blocks \
60 | --restrict_prefixes $restrict_prefixes;
61 | ```
62 | 
63 | 
64 | * `model_path`: facebook/opt-125m or gpt2; the auxiliary model used to create the TinT model
65 | * `cache_dir`: Directory to store and load opt/gpt2 models
66 | * `construct_save_model`: Whether to save the constructed model
67 | * `construct_model_path`: Path to load or save the constructed model
68 | * `n_simulation_layers`: Number of layers to update during dynamic evaluation
69 | * `n_forward_backward`: Number of SGD steps
70 | * `num_attention_heads`: Number of attention heads in the constructed model
71 | * `hidden_size`: Embedding size of the constructed model
72 | * `num_prefixes`: Number of prefix tokens
73 | * `inner_lr`: Learning rate for dynamic evaluation; note that in our construction, gradients are summed over tokens (and not averaged)
74 | * `n_layers_pergpu`: When using multiple GPUs, partition the layers across GPUs, with n_layers_pergpu layers per GPU
75 | * `reuse_forward_blocks`: True/False; for multi-step SGD, reuse the transformer blocks that simulate the forward pass
76 | * `reuse_backward_blocks`: True/False; for multi-step SGD, reuse the transformer blocks that simulate the backward pass
77 | * `restrict_prefixes`: For linear operations, restrict the linear attention heads to interactions between prefix tokens and input embeddings
78 | 
79 | 
80 | An example that creates a TinT model from the auxiliary model gpt2 is as follows:
81 | 
82 | ```bash
83 | python -m tests.create_model \
84 | --model_name_or_path gpt2 \
85 | --cache_dir "cache/" \
86 | --construct_model_path "Constructed_model/TinT_gpt2_innerlr04_ngradlayers12_sgdsteps1" \
87 | --n_simulation_layers 12 \
88 | --n_forward_backward 1 \
89 | --inner_lr 1e-04 \
90 | --n_layers_pergpu 36 \
91 | --num_attention_heads 12 \
92 | --hidden_size 3072 \
93 | --num_prefixes 256 \
94 | --construct_save_model True \
95 | --reuse_forward_blocks True \
96 | --reuse_backward_blocks True \
97 | --restrict_prefixes True;
98 | ```
99 | 
100 | 
101 | ### Perplexity Evaluation
102 | Use the following command to run perplexity evaluation on wikitext-2, wikitext-103, and c4.
103 | 
104 | 
105 | ```bash
106 | python -m tests.perplexity_eval \
107 | --dataset $dataset \
108 | --model_name_or_path $model_path \
109 | --cache_dir $cache_dir \
110 | --construct_model_path $construct_model_path \
111 | --train_fraction $train_fraction \
112 | --batch_size $batch_size \
113 | --use_eval_set $use_eval_set \
114 | --use_test_set $use_test_set \
115 | --data_subset $data_subset;
116 | ```
117 | 
118 | * `dataset`: c4/wikitext-2/wikitext-103
119 | * `model_path`: facebook/opt-125m or gpt2; the auxiliary model used to create the TinT model
120 | * `cache_dir`: Directory to store and load opt/gpt2 models
121 | * `construct_model_path`: Path to load the constructed model
122 | * `train_fraction`: Fraction of input to use for training (float between 0 and 1)
123 | * `batch_size`: Batch size for the forward passes
124 | * `use_eval_set`: True/False, Use validation set?
125 | * `use_test_set`: True/False, Use test set? (if both use_eval_set and use_test_set are True, test set is used for evaluation)
126 | * `data_subset`: Evaluate on a subset of the data (the subset size must be a multiple of the batch size).
127 | 
128 | 
129 | The results, along with all the arguments, are stored in JSON format in a file named **log_exp_construct**. An example of perplexity evaluation for the TinT model "Constructed_model/TinT_gpt2_innerlr04_ngradlayers12_sgdsteps1" on wikitext-103 is as follows:
130 | 
131 | ```bash
132 | python -m tests.perplexity_eval \
133 | --dataset wikitext-103 \
134 | --model_name_or_path gpt2 \
135 | --cache_dir "cache/" \
136 | --construct_model_path "Constructed_model/TinT_gpt2_innerlr04_ngradlayers12_sgdsteps1" \
137 | --train_fraction 0.5 \
138 | --batch_size 4 \
139 | --use_eval_set True \
140 | --use_test_set False;
141 | ```
142 | 
143 | 
144 | ### Downstream Evaluation
145 | 
146 | Please refer to the README file in the [**icl_eval**](https://github.com/abhishekpanigrahi1996/transformer_in_transformer/tree/main/icl_eval) folder.
147 | 
148 | 
149 | ### Hyperparameter considerations
150 | 
151 | The embedding size and the number of attention heads in the TinT model depend on the number of weight rows that we stack in each prefix token, the number of prefix tokens, and the dimensions of the auxiliary model. Multiple assertions in the code enforce these inter-dependencies. We give a set of general rules below for choosing the hyperparameters (a minimal sanity-check sketch of these rules appears at the end of this README) and list the hyperparameters that we used to construct the TinT models.
152 | 
153 | 
154 | The important dependencies to consider are the following.
155 | * Embedding size of TinT (given by the hidden_size argument) must be equal to the embedding size of the auxiliary model times (the number of weight rows that we stack per prefix token + 1). The addition of 1 is to include the bias terms in the first prefix token. E.g. for gpt2, whose embedding dimension is 768, if we decide to stack 3 weight rows per prefix token, the embedding dimension of TinT should be 768 * 4.
156 | * The number of weight rows that we stack per prefix token is equal to the number of weight rows divided by the number of prefixes (given by the num_prefixes argument). E.g. for gpt2, whose embedding dimension is 768, if we decide to stack 3 weight rows per prefix token, the number of prefix tokens should be 256.
157 | * hidden_size should be divisible by the number of attention heads (given by num_attention_heads).
158 | * Attention head dimension (given by hidden_size // num_attention_heads) should be a factor of the auxiliary model's embedding dimension. This is to ensure that we can partition the embeddings of the auxiliary model equally across a few attention heads.
159 | * hidden_size must be divisible by num_prefixes. Our current implementation allows an unequal number of attention heads in linear attention, used to simulate linear operations, and softmax attention, used to simulate operations involving self-attention. The number of attention heads in linear attention is given by (hidden_size // num_prefixes).
160 | 
161 | 
162 | We use the following hyperparameters to create the TinT models and to run the perplexity/downstream evaluations. We report inner_lr (the learning rate for dynamic evaluation) for the models whose numbers we report.
163 | 
164 | | | gpt2 | gpt2-medium | gpt2-large | gpt2-xl |
165 | |:--------------|:-----------:|:--------------:|:---------:|:---------:|
166 | | Auxiliary model embedding size | 768 | 1024 | 1280 | 1600 |
167 | | Auxiliary model attention heads | 12 | 16 | 20 | 25 |
168 | | Number of layers | 12 | 24 | 36 | 48 |
169 | | TinT hidden_size | 3072 | 5120 | 6400 | 9600 |
170 | | TinT num_prefixes | 256 | 256 | 320 | 320 |
171 | | TinT num_attention_heads | 12 | 20 | 20 | 30 |
172 | | Inner LR (dynamic eval) | 1e-3, 5e-4, 1e-4, 1e-5 | - | - | - |
173 | 
174 | 
175 | 
176 | | | facebook/opt-125m | facebook/opt-350m* | facebook/opt-1.3b | facebook/opt-2.7b |
177 | |:--------------|:-----------:|:--------------:|:---------:|:---------:|
178 | | Auxiliary model embedding size | 768 | 1024 | 2048 | 2560 |
179 | | Auxiliary model attention heads | 12 | 16 | 32 | 32 |
180 | | Number of layers | 12 | 24 | 24 | 32 |
181 | | TinT hidden_size | 3072 | 5120 | 10240 | 12800 |
182 | | TinT num_prefixes | 256 | 256 | 512 | 640 |
183 | | TinT num_attention_heads | 12 | 20 | 40 | 40 |
184 | | Inner LR (dynamic eval) | 1e-5, 1e-6, 1e-7 | - | - | - |
185 | 
186 | *We can't handle post layer norm in facebook/opt-350m in the current code.
187 | 
188 | ### Unavailable features
189 | 
190 | The current code doesn't contain the following features, which we plan to integrate gradually in the future.
191 | * `Post layer norm`: Currently, our code doesn't handle post layer norm and hence can't create TinT for facebook/opt-350m.
192 | * `Cross self attention`: The self attention module hasn't been modified to handle cross attention.
193 | * `TinT modules for gated linear units (GLUs)`: We will integrate the modules for GLUs soon.
194 | * `Attention variants`: We will integrate attention variants like AliBi and relative attention soon.
195 | * `RMSnorm`: We will integrate modules for RMSnorm soon.
196 | * `TinT for GPT-J, BLOOM, LLaMA`: We will include TinT creators for these models soon.
197 | 
198 | 
199 | ## Bugs or Questions
200 | If you have any questions related to the code, feel free to email Abhishek or Mengzhou (`{ap34,mengzhou}@cs.princeton.edu`). If you encounter a problem or bug when using the code, you can also open an issue.
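### Appendix: hyperparameter sanity check (illustrative)

The rules in the Hyperparameter considerations section interact, so it can help to verify a candidate configuration before launching `tests.create_model`. The sketch below is a minimal illustration and is not part of the repository: the helper name `check_tint_hyperparams`, its return values, and the simplification of treating the number of weight rows as the auxiliary embedding size (following the gpt2 example above) are our own.

```python
# Hypothetical helper (not in this repo): encodes the dependencies listed under
# "Hyperparameter considerations" and returns the derived quantities.
def check_tint_hyperparams(aux_embed_dim, hidden_size, num_prefixes, num_attention_heads):
    # Rule 1: hidden_size = auxiliary embedding size * (weight rows stacked per prefix token + 1).
    assert hidden_size % aux_embed_dim == 0, "hidden_size must be a multiple of the auxiliary embedding size"
    rows_per_prefix = hidden_size // aux_embed_dim - 1
    assert rows_per_prefix > 0, "hidden_size must be at least twice the auxiliary embedding size"

    # Rule 2: num_prefixes = (number of weight rows) / (rows stacked per prefix token).
    # Following the README example, the weight-row count is taken to be aux_embed_dim.
    assert aux_embed_dim % rows_per_prefix == 0 and aux_embed_dim // rows_per_prefix == num_prefixes, \
        "num_prefixes should equal the number of weight rows divided by rows stacked per prefix token"

    # Rule 3: hidden_size must be divisible by num_attention_heads.
    assert hidden_size % num_attention_heads == 0, "hidden_size must be divisible by num_attention_heads"
    head_dim = hidden_size // num_attention_heads

    # Rule 4: the TinT head dimension must divide the auxiliary embedding dimension.
    assert aux_embed_dim % head_dim == 0, "head_dim must be a factor of the auxiliary embedding size"

    # Rule 5: hidden_size must be divisible by num_prefixes (the ratio is the number of linear-attention heads).
    assert hidden_size % num_prefixes == 0, "hidden_size must be divisible by num_prefixes"
    return {"rows_per_prefix": rows_per_prefix,
            "head_dim": head_dim,
            "linear_attention_heads": hidden_size // num_prefixes}

# The gpt2 column of the table above: 768-dim auxiliary model, hidden_size 3072, 256 prefixes, 12 heads.
print(check_tint_hyperparams(aux_embed_dim=768, hidden_size=3072, num_prefixes=256, num_attention_heads=12))
# -> {'rows_per_prefix': 3, 'head_dim': 256, 'linear_attention_heads': 12}
```

The other columns of the tables above pass the same checks (e.g., facebook/opt-1.3b with hidden_size 10240, num_prefixes 512, and num_attention_heads 40).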
201 | 202 | 203 | -------------------------------------------------------------------------------- /tint_main/utils/self_attention/forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import math 4 | import os 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | from torch import nn 11 | from torch.cuda.amp import autocast 12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 13 | from ..modules import * 14 | from ..linear import * 15 | 16 | 17 | 18 | 19 | 20 | 21 | #The following module simulates the forward operation of a softmax self-attention in the Auxiliary model. 22 | #Output: A module that first computes Query, Key, Value using a single LinearForward module, followed by a softmax attention layer 23 | 24 | #All arguments 25 | #self, \ 26 | #config, \ #TinT config file 27 | #din, \ #input dimension 28 | #num_attnt_heads, \ #number of attention heads in the Auxiliary's self-attention 29 | #use_softmax=False, \ #use_softmax=False, implying we use linear attention head 30 | #projection_matrix=None, \ #projection_matrix to project the linear operation (not used in the current code) 31 | #separate_QK,\ #Unnecessary argument, to be removed 32 | #memory_index=-1, \ #memory_index to store activations for the backward and descent pass! 33 | class AttentionForward (nn.Module): 34 | def __init__ (self, \ 35 | config, \ 36 | din, \ 37 | num_attnt_heads, \ 38 | use_softmax=False, \ 39 | projection_matrix=None, \ 40 | separate_QK=False, \ 41 | memory_index=0,\ 42 | ): 43 | super(AttentionForward, self).__init__() 44 | 45 | assert use_softmax==False ,\ 46 | "Currently I only use linear attention for linear operations in this module" 47 | 48 | assert num_attnt_heads <= config.num_attention_heads,\ 49 | "Number of attention heads should be at least the number of attention heads necessary to simulate" 50 | 51 | self.separate_QK = separate_QK 52 | if projection_matrix is not None: 53 | dout = projection_matrix.shape[1] 54 | else: 55 | if separate_QK: dout = 2*din 56 | else: dout = din 57 | 58 | self.linear = LinearForward(config, \ 59 | din=din, \ 60 | dout=dout, \ 61 | use_softmax=use_softmax, \ 62 | projection_matrix=projection_matrix, \ 63 | memory_index=-1,\ 64 | ) 65 | 66 | self.key_linear = self.linear 67 | self.value_linear = self.linear 68 | 69 | 70 | 71 | self.gates = Gates (config) 72 | 73 | self.din = din 74 | self.num_attnt_heads = num_attnt_heads 75 | self.config = config 76 | self.memory_index = memory_index 77 | 78 | 79 | head_dim = config.hidden_size // config.num_attention_heads 80 | basemodel_head_dim = din // num_attnt_heads 81 | 82 | self.attnt_module = Attention (config, normalize=True, proj_conv2d=True, proj_conv_dim=head_dim, proj_transpose=True) 83 | 84 | assert din % head_dim == 0, \ 85 | "a bug! 
'din' should be divisible by head dimensions" 86 | 87 | num_partitions = din // head_dim 88 | 89 | assert num_attnt_heads % num_partitions == 0, \ 90 | "Num of attention heads should be divisible by num of partitions" 91 | 92 | num_attnt_heads_per_partition = num_attnt_heads // num_partitions 93 | 94 | #--------------------------------#--------------------------------# 95 | #For all Attention heads on the embeddings 96 | #Query uses the first set of din coordinates and splits them among the first 'num_attnt_heads' attention heads 97 | #Key uses the second set of din coordinates and splits them among the first 'num_attnt_heads' attention heads 98 | #Value uses the third set of din coordinates and splits them among the first 'num_attnt_heads' attention heads 99 | #Key and Query of embeddings ignore the position dependence. 100 | #--------------------------------#--------------------------------# 101 | 102 | q_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 103 | for i in range(num_partitions): 104 | for j in range(num_attnt_heads_per_partition): 105 | q_attn_head[ :, i * num_attnt_heads_per_partition + j, i ] = 1. 106 | 107 | 108 | 109 | q_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim)) 110 | for i in range(num_attnt_heads): 111 | partition = i % num_attnt_heads_per_partition 112 | q_attn[ i, :basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim) 113 | 114 | 115 | k_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 116 | for i in range(num_partitions): 117 | for j in range(num_attnt_heads_per_partition): 118 | k_attn_head[ :, i * num_attnt_heads_per_partition + j, i + num_partitions] = 1. 119 | 120 | 121 | 122 | k_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim)) 123 | for i in range(num_attnt_heads): 124 | partition = i % num_attnt_heads_per_partition 125 | k_attn[ i, :basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim) 126 | 127 | 128 | 129 | v_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 130 | for i in range(num_partitions): 131 | for j in range(num_attnt_heads_per_partition): 132 | v_attn_head[ :, i * num_attnt_heads_per_partition + j, i + 2 * num_partitions] = 1. 133 | 134 | 135 | v_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim)) 136 | for i in range(num_attnt_heads): 137 | partition = i % num_attnt_heads_per_partition 138 | v_attn[ i, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim) 139 | 140 | 141 | #c_attn_init, c_attn_bias = torch.cat([query, key, value], axis=0), torch.zeros(5 * config.hidden_size) 142 | 143 | 144 | #--------------------------------#--------------------------------# 145 | #For all Attention heads on the positions 146 | #Query, Key are set such that we never attend to the blank tokens! 
147 | #--------------------------------#--------------------------------# 148 | 149 | #--------------------------------#--------------------------------# 150 | #The projection matrix takes the output of the attention heads, which has the required signal only in its first basemodel_head_dim coordiantes 151 | #We merge them together and return them at the head of the embedding 152 | #--------------------------------#--------------------------------# 153 | c_proj_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 154 | for i in range(num_partitions): 155 | c_proj_init[:, i, i*num_attnt_heads_per_partition: (i+1)*num_attnt_heads_per_partition] = 1. 156 | 157 | 158 | 159 | self.attnt_module.initialize_weights(q_attn_init=q_attn,\ 160 | q_attn_init_head=q_attn_head,\ 161 | k_attn_init=k_attn,\ 162 | k_attn_init_head=k_attn_head,\ 163 | v_attn_init=v_attn,\ 164 | v_attn_init_head=v_attn_head,\ 165 | c_proj_init=c_proj_init ) 166 | 167 | 168 | #Initialize Gates 169 | #Ignore the changes for the prefixes! 170 | #w, u, v, w_bias, u_bias, v_bias 171 | w = torch.zeros((1, 2*config.hidden_size)) 172 | u = torch.zeros((1, 2*config.hidden_size)) 173 | v = torch.zeros((1, 2*config.position_dim)) 174 | w_bias = torch.zeros(2) 175 | u_bias = torch.zeros(2) 176 | v_bias = torch.zeros(2) 177 | 178 | #Input Gate is 1 on prefixes and 0 for non-prefixes 179 | v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.num_prefixes) 180 | 181 | #Change Gate is 0 on prefixes and 1 for non-prefixes 182 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = - config.gate_scale * torch.ones(config.num_prefixes) 183 | v_bias [1] += config.gate_scale 184 | 185 | self.gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias) 186 | 187 | 188 | 189 | 190 | def forward(self, hidden_states, position_states, key_weights=None, value_weights=None, icl_mask=None): 191 | 192 | linear_output = self.linear.forward(hidden_states, position_states) 193 | 194 | 195 | if not self.separate_QK: 196 | inp_hidden_states = torch.cat( [key_weights, hidden_states[:, self.config.num_prefixes:] ], axis=1) 197 | key_out = self.key_linear(inp_hidden_states, position_states) 198 | assert torch.sum(linear_output[:, self.config.num_prefixes:, self.din:]).item() < 1e-10,\ 199 | "Key portion not empty!" 200 | linear_output[:, self.config.num_prefixes:, self.din:] += key_out[:, self.config.num_prefixes:, :-self.din] 201 | 202 | 203 | 204 | inp_hidden_states = torch.cat( [value_weights, hidden_states[:, self.config.num_prefixes:] ], axis=1) 205 | value_out = self.value_linear(inp_hidden_states, position_states) 206 | 207 | assert torch.sum(linear_output[:, self.config.num_prefixes:, 2*self.din:]).item() < 1e-10,\ 208 | "Value portion not empty!" 
209 | linear_output[:, self.config.num_prefixes:, 2*self.din:] += value_out[:, self.config.num_prefixes:, :-2*self.din] 210 | 211 | #Send a mask such that the tokens don't attend to the blank tokens 212 | normalization_mask = torch.zeros( (1, 1, len(hidden_states[0]), len(hidden_states[0]) ) ) 213 | normalization_mask[:, :, :, :self.config.num_prefixes] = torch.finfo(self.attnt_module.p_attn.weight.dtype).min 214 | 215 | #icl_mask needs to be a 3D tensor of shape (batch_size, seqlen, seqlen) 216 | #icl_mask[i, j] = 1 if token i tends to token j 217 | 218 | if icl_mask is not None: 219 | 220 | bt = icl_mask.shape[0] 221 | for i in range( bt ): 222 | sq1 = icl_mask[i].shape[0] 223 | sq2 = icl_mask[i].shape[1] 224 | nb = self.config.num_prefixes 225 | 226 | normalization_mask[i, :, nb: nb+sq1, nb: nb+sq2] = torch.tril( torch.round(torch.clip(1. - icl_mask[i], 0., 1.)) ) * torch.finfo(self.attnt_module.p_attn.weight.dtype).min 227 | 228 | #print ("------Attention-------") 229 | attnt_output = self.attnt_module.forward(linear_output, position_states, normalization_mask=normalization_mask) [0] 230 | 231 | if self.memory_index != -1: 232 | #keep Qx, Kx in memory! 233 | #Keep also x separately afterwards! 234 | assert torch.sum(attnt_output[:, self.config.num_prefixes:, self.memory_index:]).item() < 1e-10,\ 235 | "Memory portion not empty!" 236 | 237 | attnt_output[:, self.config.num_prefixes:, self.memory_index: self.memory_index+2*self.din] += linear_output[:, self.config.num_prefixes:, :2*self.din] 238 | attnt_output[:, self.config.num_prefixes:, self.memory_index+2*self.din: self.memory_index+3*self.din] += hidden_states[:, self.config.num_prefixes:, :self.din] 239 | 240 | gate_output = self.gates.forward(linear_output, attnt_output, position_states) 241 | 242 | return gate_output 243 | -------------------------------------------------------------------------------- /tint_main/utils/self_attention/backward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import math 4 | import os 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | from torch import nn 11 | from torch.cuda.amp import autocast 12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 13 | from ..modules import * 14 | from ..linear import * 15 | 16 | 17 | 18 | #Implements the stop-attention gradient, where we don't compute the gradient w.r.t. the attention scores 19 | #The module contains one attention module, where the attention scores are re-computed between query and key vectors, transposed before dispersing the gradients. 20 | 21 | #All arguments 22 | #self, \ 23 | #config, \ #TinT config file 24 | #din, \ #input dimension of auxiliary's linear layer 25 | #num_attnt_heads, \ #number of attention heads in the Auxiliary's self-attention 26 | #use_softmax=False, \ #linear self_attention used 27 | #retain_nablay=False, \ #Retain nablay for Descent pass 28 | #memory_index=-1,\ #Start index where activations are stored in Linear Forward. 
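#--------------------------------#--------------------------------#
#A minimal standalone sketch (not part of this file) of the stop-attention-gradient
#rule described above, assuming linear (softmax-free) attention and ignoring masking,
#scaling and the head/partition bookkeeping: the scores a_{j,i} are recomputed from
#the stored query/key projections Qx, Kx, and the gradient routed to position i is
#sum_j a_{j,i} * nabla_y_j, i.e. a transposed-attention pass; no gradient is
#propagated through the attention scores themselves.
#--------------------------------#--------------------------------#

import torch

def stop_attention_grad(qx, kx, grad_y):
    #qx, kx: (seq, d) query/key projections kept in memory from the forward pass
    #grad_y: (seq, d) upstream gradient w.r.t. the attention output
    scores = qx @ kx.transpose(-1, -2)        #scores[j, i] plays the role of a_{j, i}
    return scores.transpose(-1, -2) @ grad_y  #row i holds sum_j a_{j, i} * grad_y[j]

#e.g. stop_attention_grad(torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8))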
29 | 30 | class AttentionBackward(nn.Module): 31 | 32 | def __init__ (self, \ 33 | config, \ 34 | din, \ 35 | num_attnt_heads, \ 36 | use_softmax, \ 37 | retain_nablay=False, \ 38 | memory_index=-1,\ 39 | ): 40 | super(AttentionBackward, self).__init__() 41 | 42 | assert use_softmax==False ,\ 43 | "Currently I only use linear attention in this module" 44 | 45 | 46 | 47 | self.attnt_gates = Gates (config) 48 | 49 | self.retain_nablay = retain_nablay 50 | self.memory_index = memory_index 51 | self.config = config 52 | self.din = din 53 | 54 | ##### First attention module ####### 55 | ########### Assumption ############# 56 | #The memory part has the following format [Qx, Kx] which helps us to re-compute the attention scores 57 | #We compute \sum_j a_{j, i} \nabla y_j 58 | ########### Assumption ############# 59 | 60 | 61 | head_dim = config.hidden_size // config.num_attention_heads 62 | basemodel_head_dim = din // num_attnt_heads 63 | 64 | self.attnt_module = Attention (config, peak_into_future=True, normalize=True, attnt_back=True, proj_conv2d=True, proj_conv_dim=head_dim, proj_transpose=True) 65 | 66 | 67 | assert self.memory_index <= config.hidden_size - ( 3 * self.din ), \ 68 | "Not enough memory to simulate backward pass" 69 | assert self.memory_index == -1 or self.memory_index >= self.din, \ 70 | "Memory is crossing current signal (and additional computation space)!" 71 | 72 | 73 | #--------------------------------#--------------------------------# 74 | #For all Attention heads on the embeddings 75 | #Query uses the first set of din coordinates in memory and splits them among the first 'num_attnt_heads' attention heads 76 | #Key uses the second set of din coordinates in memory and splits them among the first 'num_attnt_heads' attention heads 77 | #Value uses the first set of din coordinates and splits them among the first 'num_attnt_heads' attention heads 78 | #Key and Query of embeddings ignore the position dependence. 79 | #--------------------------------#--------------------------------# 80 | 81 | num_partitions = din // head_dim 82 | 83 | assert num_attnt_heads % num_partitions == 0, \ 84 | "Num of attention heads should be divisible by num of partitions" 85 | 86 | num_attnt_heads_per_partition = num_attnt_heads // num_partitions 87 | 88 | 89 | q_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 90 | 91 | assert memory_index % head_dim == 0,\ 92 | "Memory index should be multiple of head_dim" 93 | mem_head_start = memory_index // head_dim 94 | 95 | 96 | for i in range(num_partitions): 97 | for j in range(num_attnt_heads_per_partition): 98 | q_attn_head[ :, i * num_attnt_heads_per_partition + j, i + mem_head_start ] = 1. 99 | 100 | q_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim)) 101 | for i in range(num_attnt_heads): 102 | partition = i % num_attnt_heads_per_partition 103 | q_attn[ i, :basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim) 104 | 105 | 106 | 107 | 108 | k_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 109 | for i in range(num_partitions): 110 | for j in range(num_attnt_heads_per_partition): 111 | k_attn_head[ :, i * num_attnt_heads_per_partition + j, i + num_partitions + mem_head_start] = 1. 
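#--------------------------------#--------------------------------#
#Standalone illustration (not part of this file; toy sizes assumed, and the
#mem_head_start offset used above for reading out of memory is ignored) of the
#head bookkeeping encoded by q_attn_head / q_attn and k_attn_head / k_attn:
#TinT head h reads the head_dim-wide partition h // num_attnt_heads_per_partition
#of the stored din-dimensional block and, inside it, the basemodel_head_dim-wide
#slice h % num_attnt_heads_per_partition, i.e. exactly the slice that belongs to
#auxiliary head h.
#--------------------------------#--------------------------------#

din, num_attnt_heads, head_dim = 8, 4, 4                    #toy sizes (assumptions)
basemodel_head_dim = din // num_attnt_heads                 #2
num_partitions = din // head_dim                            #2
per_partition = num_attnt_heads // num_partitions           #2

for h in range(num_attnt_heads):
    partition = h // per_partition
    offset = partition * head_dim + (h % per_partition) * basemodel_head_dim
    print(f"TinT head {h}: reads coords [{offset}, {offset + basemodel_head_dim}) -> auxiliary head {h}")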
112 | 113 | 114 | k_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim)) 115 | for i in range(num_attnt_heads): 116 | partition = i % num_attnt_heads_per_partition 117 | k_attn[ i, :basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim) 118 | 119 | 120 | value_head = 0 121 | 122 | 123 | v_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 124 | for i in range(num_partitions): 125 | for j in range(num_attnt_heads_per_partition): 126 | v_attn_head[ :, i * num_attnt_heads_per_partition + j, i + value_head ] = 1. 127 | 128 | 129 | 130 | 131 | v_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim)) 132 | for i in range(num_attnt_heads): 133 | partition = i % num_attnt_heads_per_partition 134 | v_attn[ i, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim) 135 | 136 | 137 | c_proj_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads)) 138 | 139 | for i in range(num_partitions): 140 | c_proj_init[:, i + value_head, i*num_attnt_heads_per_partition: (i+1)*num_attnt_heads_per_partition] = 1. 141 | 142 | 143 | 144 | self.attnt_module.initialize_weights(q_attn_init=q_attn,\ 145 | q_attn_init_head=q_attn_head,\ 146 | k_attn_init=k_attn,\ 147 | k_attn_init_head=k_attn_head,\ 148 | v_attn_init=v_attn,\ 149 | v_attn_init_head=v_attn_head,\ 150 | c_proj_init=c_proj_init ) 151 | 152 | #Initialize the first attention Gates 153 | #Ignore the changes for the prefixes! 154 | #w, u, v, w_bias, u_bias, v_bias 155 | w = torch.zeros((1, 2*config.hidden_size)) 156 | u = torch.zeros((1, 2*config.hidden_size)) 157 | v = torch.zeros((1, 2*config.position_dim)) 158 | w_bias = torch.zeros(2) 159 | u_bias = torch.zeros(2) 160 | v_bias = torch.zeros(2) 161 | 162 | #if self.retain_nablay: 163 | #Input Gate is 1 164 | # v_bias [0] += config.gate_scale 165 | #else: 166 | #Input Gate is 1 on prefixesß and 0 for non-prefixes 167 | v [0, config.seq_length:config.position_dim] = config.gate_scale * torch.ones(config.num_prefixes) 168 | 169 | #Change Gate is 0 on prefixes and 1 for non-prefixes 170 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.num_prefixes) 171 | v_bias [1] += config.gate_scale 172 | 173 | self.attnt_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias) 174 | 175 | 176 | 177 | def forward(self, hidden_states, position_states, attention_mask, icl_mask=None): 178 | 179 | 180 | #add a mask to avoid attention on blank tokens! 181 | 182 | 183 | normalization_mask = torch.zeros( (1, 1, len(hidden_states[0]), len(hidden_states[0]) ) ) 184 | normalization_mask[:, :, :, :self.config.num_prefixes] = torch.finfo(self.attnt_module.p_attn.weight.dtype).min 185 | 186 | #icl_mask needs to be a 3D tensor of shape (batch_size, seqlen, seqlen) 187 | #icl_mask[i, j] = 1 if token i tends to token j 188 | 189 | if icl_mask is not None: 190 | bt = icl_mask.shape[0] 191 | for i in range( bt ): 192 | sq1 = icl_mask[i].shape[0] 193 | sq2 = icl_mask[i].shape[1] 194 | nb = self.config.num_prefixes 195 | normalization_mask[i, :, nb: nb+sq1, nb: nb+sq2] = torch.tril( torch.round(torch.clip(1. 
- icl_mask[i], 0., 1.)) ) * torch.finfo(self.attnt_module.p_attn.weight.dtype).min 196 | 197 | #print ("----Mask----", attention_mask) 198 | modified_attention_mask = attention_mask.detach().clone() 199 | modified_attention_mask[:, :, :, :self.config.num_prefixes] = 0. 200 | 201 | attnt_output = self.attnt_module.forward(hidden_states, \ 202 | position_states, \ 203 | attention_mask=modified_attention_mask, \ 204 | normalization_mask=normalization_mask\ 205 | ) [0] 206 | 207 | 208 | end_dim = self.memory_index + 3*self.din 209 | 210 | attnt_output[:, self.config.num_prefixes:, self.memory_index: end_dim] += hidden_states[:, self.config.num_prefixes:, self.memory_index: end_dim] 211 | 212 | 213 | gate_output = self.attnt_gates.forward(hidden_states, \ 214 | attnt_output, \ 215 | position_states\ 216 | ) 217 | 218 | return gate_output 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | #Implements descent w.r.t. the value matrix 227 | #The module simply calls LinearDescent module on the current embeddings 228 | class AttentionDescent(nn.Module): 229 | def __init__ (self, config, din, num_attnt_heads, use_softmax, memory_index=-1, debug_zero=False, retain_nablay=False): 230 | super(AttentionDescent, self).__init__() 231 | self.linear = LinearDescent(config, din=din, dout=din, use_softmax=use_softmax, memory_index=memory_index+2*din, debug_zero=debug_zero) 232 | 233 | 234 | def forward(self, hidden_states, position_states, attention_mask, icl_mask=None): 235 | return self.linear.forward(hidden_states, position_states, attention_mask) 236 | 237 | 238 | 239 | #Combines Backward and Descent module, since Descent module uses the gradient from Backward pass. 240 | class AttentionBackward_Descent(nn.Module): 241 | def __init__ (self, config, din, num_attnt_heads, use_softmax, memory_index=-1, debug_zero=False, projection_matrix=None, retain_nablay=False): 242 | super(AttentionBackward_Descent, self).__init__() 243 | 244 | self.config = config 245 | self.memory_index = memory_index 246 | self.din = din 247 | self.attention_back = AttentionBackward(config, \ 248 | din=din, \ 249 | num_attnt_heads=num_attnt_heads, \ 250 | memory_index=memory_index, \ 251 | use_softmax=use_softmax, \ 252 | retain_nablay=retain_nablay,\ 253 | ) 254 | 255 | 256 | self.linearback_descent = Linear_Descent_Backward(config, \ 257 | din=din, \ 258 | dout=din, \ 259 | use_softmax=use_softmax, \ 260 | memory_index=memory_index+2*din, \ 261 | debug_zero=debug_zero, \ 262 | projection_matrix=projection_matrix, \ 263 | retain_nablay=retain_nablay, \ 264 | ) 265 | 266 | 267 | def forward(self, hidden_states, position_states, attention_mask, icl_mask=None): 268 | if self.config.backprop_through_attention: 269 | attention_backout = self.attention_back(hidden_states, position_states, attention_mask, icl_mask=icl_mask) 270 | else: 271 | attention_backout = hidden_states 272 | 273 | 274 | attention_descentout = self.linearback_descent(attention_backout, position_states, attention_mask) 275 | 276 | 277 | return attention_descentout -------------------------------------------------------------------------------- /icl_eval/models/modeling_opt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch OPT model.""" 16 | import random 17 | from typing import List, Optional, Tuple, Union 18 | 19 | import torch 20 | import torch.utils.checkpoint 21 | from torch import nn 22 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 23 | 24 | from transformers.activations import ACT2FN 25 | from transformers.modeling_outputs import ( 26 | BaseModelOutputWithPast, 27 | CausalLMOutputWithPast, 28 | QuestionAnsweringModelOutput, 29 | SequenceClassifierOutputWithPast, 30 | ) 31 | from transformers.modeling_utils import PreTrainedModel 32 | from transformers.utils import ( 33 | add_code_sample_docstrings, 34 | add_start_docstrings, 35 | add_start_docstrings_to_model_forward, 36 | logging, 37 | replace_return_docstrings, 38 | ) 39 | from transformers.models.opt.configuration_opt import OPTConfig 40 | from transformers.models.opt.modeling_opt import OPTPreTrainedModel, OPTDecoder, OPTModel, OPTForCausalLM 41 | 42 | logger = logging.get_logger(__name__) 43 | 44 | class LocalOPTDecoder(OPTDecoder): 45 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 46 | def forward( 47 | self, 48 | input_ids: torch.LongTensor = None, 49 | attention_mask: Optional[torch.Tensor] = None, 50 | head_mask: Optional[torch.Tensor] = None, 51 | past_key_values: Optional[List[torch.FloatTensor]] = None, 52 | inputs_embeds: Optional[torch.FloatTensor] = None, 53 | use_cache: Optional[bool] = None, 54 | output_attentions: Optional[bool] = None, 55 | output_hidden_states: Optional[bool] = None, 56 | return_dict: Optional[bool] = None, 57 | twod_attention_mask: Optional[torch.Tensor] = None 58 | ) -> Union[Tuple, BaseModelOutputWithPast]: 59 | 60 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 61 | output_hidden_states = ( 62 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 63 | ) 64 | use_cache = use_cache if use_cache is not None else self.config.use_cache 65 | 66 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 67 | 68 | # retrieve input_ids and inputs_embeds 69 | if input_ids is not None and inputs_embeds is not None: 70 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 71 | elif input_ids is not None: 72 | input_shape = input_ids.size() 73 | input_ids = input_ids.view(-1, input_shape[-1]) 74 | elif inputs_embeds is not None: 75 | input_shape = inputs_embeds.size()[:-1] 76 | else: 77 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 78 | 79 | if inputs_embeds is None: 80 | inputs_embeds = self.embed_tokens(input_ids) 81 | 82 | batch_size, seq_length = input_shape 83 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 84 | # required mask seq length can be calculated via length of past 85 | mask_seq_length = past_key_values_length + seq_length 86 | 87 | # embed positions 88 | if attention_mask is None: 89 | attention_mask = 
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) 90 | causal_attention_mask = self._prepare_decoder_attention_mask( 91 | attention_mask, input_shape, inputs_embeds, past_key_values_length 92 | ) 93 | if twod_attention_mask is not None: 94 | causal_attention_mask += twod_attention_mask 95 | pos_embeds = self.embed_positions(attention_mask, past_key_values_length) 96 | 97 | if self.project_in is not None: 98 | inputs_embeds = self.project_in(inputs_embeds) 99 | 100 | hidden_states = inputs_embeds + pos_embeds 101 | 102 | if self.gradient_checkpointing and self.training: 103 | if use_cache: 104 | logger.warning_once( 105 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 106 | ) 107 | use_cache = False 108 | 109 | # decoder layers 110 | all_hidden_states = () if output_hidden_states else None 111 | all_self_attns = () if output_attentions else None 112 | next_decoder_cache = () if use_cache else None 113 | 114 | # check if head_mask has a correct number of layers specified if desired 115 | for attn_mask, mask_name in zip([head_mask], ["head_mask"]): 116 | if attn_mask is not None: 117 | if attn_mask.size()[0] != (len(self.layers)): 118 | raise ValueError( 119 | f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" 120 | f" {head_mask.size()[0]}." 121 | ) 122 | 123 | for idx, decoder_layer in enumerate(self.layers): 124 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 125 | if output_hidden_states: 126 | all_hidden_states += (hidden_states,) 127 | 128 | dropout_probability = random.uniform(0, 1) 129 | if self.training and (dropout_probability < self.layerdrop): 130 | continue 131 | 132 | past_key_value = past_key_values[idx] if past_key_values is not None else None 133 | 134 | if self.gradient_checkpointing and self.training: 135 | 136 | def create_custom_forward(module): 137 | def custom_forward(*inputs): 138 | # None for past_key_value 139 | return module(*inputs, output_attentions, None) 140 | 141 | return custom_forward 142 | 143 | layer_outputs = torch.utils.checkpoint.checkpoint( 144 | create_custom_forward(decoder_layer), 145 | hidden_states, 146 | causal_attention_mask, 147 | head_mask[idx] if head_mask is not None else None, 148 | None, 149 | ) 150 | else: 151 | layer_outputs = decoder_layer( 152 | hidden_states, 153 | attention_mask=causal_attention_mask, 154 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 155 | past_key_value=past_key_value, 156 | output_attentions=output_attentions, 157 | use_cache=use_cache, 158 | ) 159 | 160 | hidden_states = layer_outputs[0] 161 | 162 | if use_cache: 163 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 164 | 165 | if output_attentions: 166 | all_self_attns += (layer_outputs[1],) 167 | 168 | if self.final_layer_norm is not None: 169 | hidden_states = self.final_layer_norm(hidden_states) 170 | 171 | if self.project_out is not None: 172 | hidden_states = self.project_out(hidden_states) 173 | 174 | # add hidden states from the last decoder layer 175 | if output_hidden_states: 176 | all_hidden_states += (hidden_states,) 177 | 178 | next_cache = next_decoder_cache if use_cache else None 179 | if not return_dict: 180 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 181 | return BaseModelOutputWithPast( 182 | last_hidden_state=hidden_states, 183 | past_key_values=next_cache, 184 | hidden_states=all_hidden_states, 185 | 
attentions=all_self_attns, 186 | ) 187 | 188 | 189 | class LocalOPTModel(OPTModel): 190 | def __init__(self, config: OPTConfig): 191 | super().__init__(config) 192 | self.decoder = LocalOPTDecoder(config) 193 | # Initialize weights and apply final processing 194 | self.post_init() 195 | 196 | def forward( 197 | self, 198 | input_ids: torch.LongTensor = None, 199 | attention_mask: Optional[torch.Tensor] = None, 200 | head_mask: Optional[torch.Tensor] = None, 201 | past_key_values: Optional[List[torch.FloatTensor]] = None, 202 | inputs_embeds: Optional[torch.FloatTensor] = None, 203 | use_cache: Optional[bool] = None, 204 | output_attentions: Optional[bool] = None, 205 | output_hidden_states: Optional[bool] = None, 206 | return_dict: Optional[bool] = None, 207 | twod_attention_mask: Optional[torch.Tensor] = None 208 | ) -> Union[Tuple, BaseModelOutputWithPast]: 209 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 210 | output_hidden_states = ( 211 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 212 | ) 213 | use_cache = use_cache if use_cache is not None else self.config.use_cache 214 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 215 | 216 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 217 | decoder_outputs = self.decoder( 218 | input_ids=input_ids, 219 | attention_mask=attention_mask, 220 | head_mask=head_mask, 221 | past_key_values=past_key_values, 222 | inputs_embeds=inputs_embeds, 223 | use_cache=use_cache, 224 | output_attentions=output_attentions, 225 | output_hidden_states=output_hidden_states, 226 | return_dict=return_dict, 227 | twod_attention_mask=twod_attention_mask 228 | ) 229 | 230 | if not return_dict: 231 | return decoder_outputs 232 | 233 | return BaseModelOutputWithPast( 234 | last_hidden_state=decoder_outputs.last_hidden_state, 235 | past_key_values=decoder_outputs.past_key_values, 236 | hidden_states=decoder_outputs.hidden_states, 237 | attentions=decoder_outputs.attentions, 238 | ) 239 | 240 | class LocalOPTForCausalLM(OPTForCausalLM): 241 | def __init__(self, config): 242 | super().__init__(config) 243 | self.model = LocalOPTModel(config) 244 | 245 | # the lm_head weight is automatically tied to the embed tokens weight 246 | self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) 247 | 248 | # Initialize weights and apply final processing 249 | self.post_init() 250 | 251 | def forward( 252 | self, 253 | input_ids: torch.LongTensor = None, 254 | attention_mask: Optional[torch.Tensor] = None, 255 | head_mask: Optional[torch.Tensor] = None, 256 | past_key_values: Optional[List[torch.FloatTensor]] = None, 257 | inputs_embeds: Optional[torch.FloatTensor] = None, 258 | labels: Optional[torch.LongTensor] = None, 259 | use_cache: Optional[bool] = None, 260 | output_attentions: Optional[bool] = None, 261 | output_hidden_states: Optional[bool] = None, 262 | return_dict: Optional[bool] = None, 263 | twod_attention_mask: Optional[torch.Tensor] = None 264 | ) -> Union[Tuple, CausalLMOutputWithPast]: 265 | 266 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 267 | output_hidden_states = ( 268 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 269 | ) 270 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 271 | 272 | # decoder outputs 
consists of (dec_features, layer_state, dec_hidden, dec_attn) 273 | 274 | outputs = self.model.decoder( 275 | input_ids=input_ids, 276 | attention_mask=attention_mask, 277 | head_mask=head_mask, 278 | past_key_values=past_key_values, 279 | inputs_embeds=inputs_embeds, 280 | use_cache=use_cache, 281 | output_attentions=output_attentions, 282 | output_hidden_states=output_hidden_states, 283 | return_dict=return_dict, 284 | twod_attention_mask=twod_attention_mask 285 | ) 286 | 287 | logits = self.lm_head(outputs[0]).contiguous() 288 | 289 | loss = None 290 | if labels is not None: 291 | # Shift so that tokens < n predict n 292 | shift_logits = logits[..., :-1, :].contiguous() 293 | shift_labels = labels[..., 1:].contiguous() 294 | # Flatten the tokens 295 | loss_fct = CrossEntropyLoss() 296 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) 297 | 298 | if not return_dict: 299 | output = (logits,) + outputs[1:] 300 | return (loss,) + output if loss is not None else output 301 | 302 | return CausalLMOutputWithPast( 303 | loss=loss, 304 | logits=logits, 305 | past_key_values=outputs.past_key_values, 306 | hidden_states=outputs.hidden_states, 307 | attentions=outputs.attentions, 308 | ) 309 | -------------------------------------------------------------------------------- /tests/perplexity_dynamiceval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from Conversion import * 4 | from transformers import AutoTokenizer, AutoConfig, OPTForCausalLM, GPT2LMHeadModel 5 | from datasets import load_dataset 6 | import copy 7 | from transformers import HfArgumentParser 8 | from filelock import Timeout, FileLock 9 | from data_utils import * 10 | from dataclasses import dataclass, field 11 | from typing import Callable, Dict, Optional, Union, List 12 | import numpy as np 13 | @dataclass 14 | class ModelArguments: 15 | """ 16 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
17 | """ 18 | model_name_or_path: str = field( 19 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 20 | ) 21 | config_name: Optional[str] = field( 22 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 23 | ) 24 | cache_dir: Optional[str] = field( 25 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 26 | ) 27 | 28 | device: Optional[str] = field( 29 | default="cuda", metadata={"help": "cuda/cpu"} 30 | ) 31 | 32 | 33 | final_layers_to_train: Optional[int] = field( 34 | default=-1, metadata={"help": "Number of layers to train"} 35 | ) 36 | 37 | gradient_steps : Optional[int] = field( 38 | default=1, metadata={"help": "Number of gradient steps"} 39 | ) 40 | 41 | learning_rate: Optional[float] = field( 42 | default=1e-04, metadata={"help": "Learning rate for projection!"} 43 | ) 44 | 45 | batch_size: Optional[int] = field( 46 | default=1, metadata={"help": "Batch size for training!"} 47 | ) 48 | 49 | train_fraction: Optional[float] = field( 50 | default=0.5, metadata={"help": "Fraction of sentence for training!"} 51 | ) 52 | 53 | dynamic_chunks: Optional[int] = field( 54 | default=1, metadata={"help": "Number of dynamic chunks (before the final test per sequence)!"} 55 | ) 56 | 57 | 58 | 59 | 60 | @dataclass 61 | class DynamicArguments: 62 | 63 | data_cache_dir: Optional[str] = field( 64 | default='/scratch/gpfs/ap34/icl-as-ft/Dynamic_initialization/data', 65 | metadata={ 66 | 'help': 'where to store downloaded model, datasets, and tokenizer' 67 | } 68 | ) 69 | 70 | log_dir: Optional[str] = field( 71 | default='/scratch/gpfs/smalladi/icl_as_ft/logs' 72 | ) 73 | 74 | dataset: Optional[str] = field( 75 | default='wikitext-103', 76 | metadata={ 77 | 'help': 'dataset to use, should be in HF datasets repo' 78 | } 79 | ) 80 | 81 | use_eval_set: bool = field( 82 | default=True, 83 | metadata={'help': 'when on, uses the eval dataset only for ppl measurement.'} 84 | ) 85 | 86 | use_test_set: bool = field( 87 | default=False, 88 | metadata={'help': 'when on, uses the test dataset only for ppl measurement.'} 89 | ) 90 | 91 | ### finetuning strategies ### 92 | 93 | chunk_data: bool = field( 94 | default=True, 95 | metadata={'help': 'when on, chunks the data into block_size chunks'} 96 | ) 97 | 98 | baseline: bool = field( 99 | default=False 100 | ) 101 | 102 | num_subsets: int = field( 103 | default=16 104 | ) 105 | data_chunk_index: int = field( 106 | default=0 107 | ) 108 | 109 | block_size: Optional[int] = field( 110 | default=1024, 111 | metadata={'help': 'size of chunks to break corpus into'} 112 | ) 113 | 114 | num_workers: Optional[int] = field( 115 | default=4, 116 | metadata={'help': 'number of workers'} 117 | ) 118 | 119 | data_subset: Optional[int] = field( 120 | default=-1, metadata={"help": "Number of training examples to use in the subset!"} 121 | ) 122 | 123 | 124 | 125 | 126 | parser = HfArgumentParser((ModelArguments, DynamicArguments)) 127 | model_args, data_args = parser.parse_args_into_dataclasses() 128 | 129 | learning_rate = model_args.learning_rate 130 | gradient_steps = model_args.gradient_steps 131 | final_layers_to_train = model_args.final_layers_to_train 132 | train_fraction = model_args.train_fraction 133 | dynamic_chunks = model_args.dynamic_chunks 134 | 135 | device = model_args.device 136 | model_config = AutoConfig.from_pretrained( 137 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 138 | 
cache_dir=model_args.cache_dir 139 | ) 140 | 141 | 142 | if 'gpt' in model_args.model_name_or_path: 143 | model_fn = GPT2LMHeadModel 144 | elif 'opt' in model_args.model_name_or_path: 145 | model_fn = OPTForCausalLM 146 | else: 147 | raise NotImplmentedError 148 | 149 | simulated_gpt2 = model_fn.from_pretrained( 150 | model_args.model_name_or_path, 151 | config=model_config, 152 | cache_dir=model_args.cache_dir, 153 | ) 154 | 155 | simulated_gpt2.to(device) 156 | simulated_gpt2.eval() 157 | 158 | 159 | #dataset = load_dataset("wikitext", "wikitext-2-raw-v1", cache_dir='data', download_mode='reuse_cache_if_exists') 160 | #test_data = dataset['test'] 161 | #test_data = dataset['test'] 162 | #valid_data = dataset['validation'] 163 | 164 | batch_size=model_args.batch_size 165 | 166 | assert batch_size == 1, \ 167 | "Assume batch size as 1 for proper perplexity computation (lazy to do for multi input batch)" 168 | 169 | # Download vocabulary from huggingface.co and cache. 170 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir="../..") 171 | if 'gpt' in model_args.model_name_or_path: 172 | tokenizer.pad_token = tokenizer.eos_token 173 | pad_token_id=tokenizer.convert_tokens_to_ids(tokenizer.pad_token) 174 | 175 | elif 'opt' in model_args.model_name_or_path: 176 | tokenizer.bos_token_id = 0 177 | 178 | dataset = preprocess_dataset(data_args, tokenizer) 179 | #_, simulated_gpt2, model_config, config = Construct_NASgpt() 180 | 181 | 182 | simulated_gpt2.eval() 183 | device=next(simulated_gpt2.parameters()).device 184 | 185 | 186 | 187 | num_valid_batches = len(dataset) // batch_size 188 | 189 | if data_args.data_subset != -1: 190 | num_valid_batches = min(data_args.data_subset, num_valid_batches) 191 | 192 | avg_model_perplexity = 0. 193 | avg_eval_perplexity = 0. 194 | total_words = 0. 195 | 196 | 197 | avg_model_test_perplexity = 0. 198 | avg_eval_test_perplexity = 0. 199 | total_test_words = 0. 200 | 201 | 202 | 203 | for batch_id in tqdm( range( num_valid_batches ) ): 204 | if data_args.dataset == 'c4' and batch_id == 100: break 205 | model_copy = copy.deepcopy(simulated_gpt2) 206 | model_copy.eval() 207 | 208 | trainable_parameters = [] 209 | trainable_name = [] 210 | if final_layers_to_train == -1: 211 | final_layers_to_train = 12 212 | 213 | #for i in range( 12 - final_layers_to_train, 12 ): 214 | for n, p in model_copy.named_parameters(): 215 | if 'ln_' not in n and '.h.' in n: 216 | layer_num = int(n.split('.h.')[-1].split('.')[0]) 217 | if layer_num >= 12 - final_layers_to_train: 218 | trainable_parameters += [ p ] 219 | trainable_name += [n] 220 | 221 | 222 | optimizer = torch.optim.SGD(model_copy.parameters(), lr=learning_rate) 223 | model_copy.zero_grad() 224 | 225 | 226 | 227 | data = dataset [ batch_id * batch_size : (batch_id + 1) * batch_size ] 228 | batch_sentences = torch.tensor( data ['input_ids'] ) 229 | attention_mask = torch.tensor( data ['attention_mask'] ) 230 | labels = torch.tensor( data ['labels'] ) 231 | 232 | 233 | 234 | 235 | 236 | 237 | if len(batch_sentences.shape) == 1: 238 | batch_sentences = batch_sentences.view((1, -1)) 239 | attention_mask = attention_mask.view((1, -1)) 240 | labels = labels.view((1, -1)) 241 | 242 | train_subchunk_fraction = train_fraction / dynamic_chunks 243 | 244 | #initialize the dynamic loss 245 | final_loss = 0. 246 | og_loss = 0. 247 | total_terms = 0. 248 | loss_fct = torch.nn.CrossEntropyLoss(reduction='sum') 249 | 250 | 251 | og_test_loss = 0. 252 | final_test_loss = 0. 
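#Standalone sketch (toy values; not part of this script) of the label masking used in
#the dynamic-evaluation loop below: each dynamic chunk trains only on its own slice of
#the sequence, so every label outside that slice is set to -100, the ignore_index of
#torch.nn.CrossEntropyLoss, and the held-out remainder is scored afterwards with the
#same masking trick.
import torch

toy_labels = torch.arange(10).view(1, -1)                      #toy label ids
toy_seq_len, toy_train_fraction, toy_dynamic_chunks = 10, 0.5, 2  #assumed values
toy_len_chunk = int(toy_seq_len * toy_train_fraction / toy_dynamic_chunks)

for toy_chunk_id in range(toy_dynamic_chunks):
    toy_mask = torch.zeros_like(toy_labels, dtype=torch.float)
    toy_mask[0, toy_chunk_id * toy_len_chunk: (toy_chunk_id + 1) * toy_len_chunk] = 1.
    toy_target = toy_labels.clone()
    toy_target[toy_mask == 0.] = -100                          #ignored by CrossEntropyLoss
    print(toy_chunk_id, toy_target.tolist())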
253 | #original model's loss 254 | #with torch.no_grad(): 255 | # train_seq_loss = model_copy(batch_sentences.cuda(), \ 256 | # output_hidden_states=False, \ 257 | # labels=labels.long().to(device))[0].item() 258 | 259 | #dynamic training 260 | for dynamic_chunk_id in range(dynamic_chunks): 261 | 262 | batch_seq_lengths = torch.sum(attention_mask.int(), dim=-1) 263 | mask = torch.zeros_like(attention_mask) 264 | 265 | for i in range(len(batch_seq_lengths)): 266 | len_chunk = int(batch_seq_lengths[i] * train_subchunk_fraction) 267 | actual_fraction = len_chunk / batch_seq_lengths[i] 268 | #print (actual_fraction, train_subchunk_fraction) 269 | mask[i, dynamic_chunk_id * len_chunk: dynamic_chunk_id * len_chunk + len_chunk] = 1. 270 | bidirection_mask = mask.float() 271 | 272 | target = labels.detach().clone() 273 | target[ torch.where(bidirection_mask == 0.) ] = -100 274 | 275 | 276 | #print (target) 277 | #with torch.no_grad(): 278 | input_ids = batch_sentences.to(device) 279 | target = target.to(device) 280 | 281 | #first a simple evaluation on the current model 282 | with torch.no_grad(): 283 | output = model_copy(input_ids, \ 284 | attention_mask=attention_mask.to(device), \ 285 | output_hidden_states=False, 286 | ) 287 | 288 | print (output[0].shape, input_ids.shape) 289 | final_loss += loss_fct( output[0][:, :-1].view((-1, model_config.vocab_size)), \ 290 | target[:, 1:].long().view((-1,)) \ 291 | ).item() 292 | 293 | gpt_output = simulated_gpt2(input_ids, \ 294 | attention_mask=attention_mask.to(device), \ 295 | output_hidden_states=False, \ 296 | ) 297 | og_loss += loss_fct( gpt_output[0][:, :-1].view((-1, model_config.vocab_size)), \ 298 | target[:, 1:].long().view((-1,)) \ 299 | ).item() 300 | total_terms += bidirection_mask.sum() 301 | 302 | 303 | for _ in range(gradient_steps): 304 | simulated_output = model_copy(input_ids, \ 305 | attention_mask=attention_mask.to(device), \ 306 | output_hidden_states=True, \ 307 | labels=target.long()\ 308 | ) 309 | small_model_loss = simulated_output[0] 310 | #print (small_model_loss.item()) 311 | small_model_loss.backward() 312 | optimizer.step() 313 | optimizer.zero_grad() 314 | 315 | for n, p in model_copy.named_parameters(): 316 | for n_, p_ in simulated_gpt2.named_parameters(): 317 | if n == n_ and n in trainable_name: 318 | print ( n, torch.max (torch.absolute(p - p_)) ) 319 | 320 | 321 | 322 | 323 | #print ([p for n, p in model_copy.named_parameters() if 'ln_f' in n]) 324 | #test on the remaining chunk 325 | 326 | batch_seq_lengths = torch.sum(attention_mask.int(), dim=-1) 327 | mask = torch.zeros_like(attention_mask) 328 | 329 | 330 | for i in range(len(batch_seq_lengths)): 331 | len_chunk = dynamic_chunks * int(batch_seq_lengths[i] * train_subchunk_fraction) 332 | test_fraction = 1. - len_chunk / (1. * batch_seq_lengths[i]) 333 | mask[i, len_chunk:] = 1. 334 | bidirection_mask = mask.float() 335 | 336 | 337 | with torch.no_grad(): 338 | target = labels.detach().clone() 339 | target[ torch.where(bidirection_mask == 0.) 
] = -100 340 | target = target.to(device) 341 | 342 | 343 | 344 | simulated_output = model_copy(input_ids, \ 345 | attention_mask=attention_mask.to(device), \ 346 | output_hidden_states=False, 347 | ) 348 | loss = loss_fct( simulated_output[0][:, :-1].view((-1, model_config.vocab_size)), \ 349 | target[:, 1:].long().view((-1,)) \ 350 | ).item() 351 | final_loss += loss 352 | avg_eval_test_perplexity += loss 353 | 354 | gpt_output = simulated_gpt2(input_ids, \ 355 | attention_mask=attention_mask.to(device), \ 356 | output_hidden_states=False, 357 | ) 358 | 359 | loss = loss_fct( gpt_output[0][:, :-1].view((-1, model_config.vocab_size)), \ 360 | target[:, 1:].long().view((-1,)) \ 361 | ).item() 362 | og_loss += loss 363 | avg_model_test_perplexity += loss 364 | 365 | total_terms += bidirection_mask.sum() 366 | total_test_words += bidirection_mask.sum() 367 | 368 | del (model_copy) 369 | avg_model_perplexity += og_loss 370 | avg_eval_perplexity += final_loss 371 | total_words += total_terms 372 | 373 | 374 | final_result = {} 375 | final_result[ 'Validation Dynamic eval acc' ] = np.exp(avg_eval_perplexity / total_words) 376 | final_result[ 'Validation Model acc' ] = np.exp(avg_model_perplexity / total_words) 377 | 378 | final_result[ 'Validation Dynamic eval acc (on test)' ] = np.exp(avg_eval_test_perplexity / total_test_words) 379 | final_result[ 'Validation Model acc (on test)' ] = np.exp(avg_model_test_perplexity / total_test_words) 380 | 381 | 382 | with FileLock('log_exp.lock'): 383 | with open('log_exp', 'a') as f: 384 | final_result.update(vars(model_args)) 385 | final_result.update(vars(data_args)) 386 | f.write(str(final_result) + '\n') -------------------------------------------------------------------------------- /tint_main/utils/linear/forward.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | 5 | import math 6 | import os 7 | from dataclasses import dataclass 8 | from typing import Optional, Tuple, Union 9 | 10 | import torch 11 | import torch.utils.checkpoint 12 | from torch import nn 13 | from torch.cuda.amp import autocast 14 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 15 | from ..modules import * 16 | import numpy as np 17 | 18 | 19 | 20 | #------------------------------------------------------------------# 21 | #LinearForward module computes Wx_i at every position i 22 | #Important arguments: input dimension (din), output dimension (dout) 23 | #output: TinT LinearForward module, primarily containing a linear self_attention layer. (See figure 2 in the paper) 24 | #We assume that the rows of W have been stacked onto the prefix tokens, 25 | #before calling this module. 26 | 27 | #We also allow projection matrix on the linear operations, however we donot 28 | #use them in the current version. 29 | #------------------------------------------------------------------# 30 | 31 | 32 | #All arguments 33 | #self, \ 34 | #config, \ #TinT config file 35 | #din, \ #input dimension 36 | #dout, \ #output dimension 37 | #use_softmax=False, \ #use_softmax=False, implying we use linear attention head 38 | #projection_matrix=None, \ #projection_matrix to project the linear operation (not used in the current code) 39 | #shift_top=0, \ #shifts the output to a shifted index if necessary 40 | #memory_index=-1, \ #memory_index to store activations for the backward and descent pass! 
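#--------------------------------#--------------------------------#
#Reference semantics only (a standalone sketch, not the TinT construction below):
#with the dout x din weight matrix W of the auxiliary linear layer stacked row-wise
#onto the config.num_prefixes prefix tokens (dout // num_prefixes rows per prefix),
#LinearForward is built so that every non-prefix position i ends up holding W @ x_i
#in its leading dout coordinates (assuming shift_top=0). The target computation in
#plain torch, with toy sizes, is:
#--------------------------------#--------------------------------#
import torch

toy_num_prefixes, toy_din, toy_dout, toy_seq = 4, 8, 8, 5    #assumed toy sizes
W = torch.randn(toy_dout, toy_din)                           #auxiliary weight matrix
x = torch.randn(toy_seq, toy_din)                            #non-prefix token embeddings

prefix_payload = W.view(toy_num_prefixes, toy_dout // toy_num_prefixes, toy_din)  #rows of W per prefix
expected = x @ W.t()                                         #what the attention construction reproduces
print(expected.shape)                                        #(toy_seq, toy_dout)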
41 | 42 | class LinearForward(nn.Module): 43 | 44 | def __init__(self, \ 45 | config, \ 46 | din, \ 47 | dout, \ 48 | use_softmax=False, \ 49 | projection_matrix=None, \ 50 | shift_top=0, \ 51 | memory_index=-1, \ 52 | ): 53 | super(LinearForward, self).__init__() 54 | 55 | self.attnt_module = None #initialized later 56 | #We use gates to differentiate the operations on prefix embeddings and non-prefix embeddings. 57 | self.gates = Gates (config) 58 | self.din = din 59 | self.config = config 60 | self.projection_matrix = projection_matrix 61 | self.memory_index = memory_index 62 | 63 | #initialized later 64 | self.permutation_conv = None 65 | self.bias_add_conv = None 66 | self.projection_layer = None 67 | self.proj_conv2d = (config.hidden_size % dout == 0) 68 | 69 | assert use_softmax == False, \ 70 | "Currently only works without softmax!" 71 | 72 | head_dim = config.num_prefixes 73 | assert config.hidden_size % head_dim == 0,\ 74 | "Dimension should perfectly distribute over the prefixes" 75 | 76 | num_attention_heads = config.hidden_size // head_dim 77 | 78 | 79 | 80 | assert config.hidden_size >= din * (dout // config.num_prefixes), \ 81 | "Total embedding size must be greater than the dimension necessary to store weights in the prefixes" 82 | 83 | assert dout % config.num_prefixes == 0, \ 84 | "I assume uniform distribution of the weights over the prefixes" 85 | 86 | assert din % head_dim == 0,\ 87 | "Currently this is a bug! I assume that the input dimension is easily divisible across the heads we want to distribute to" 88 | assert dout % head_dim == 0,\ 89 | "Currently this is a bug! I assume that the output dimension is easily divisible across the heads we want to distribute to" 90 | 91 | num_wts_per_blank = dout // config.num_prefixes 92 | #initialize attention module 93 | self.attnt_module = Attention (config, \ 94 | num_attention_heads=num_attention_heads, \ 95 | normalize=use_softmax, \ 96 | proj_conv_dim=config.num_prefixes, \ 97 | proj_transpose=True, \ 98 | proj_conv2d=self.proj_conv2d\ 99 | ) 100 | 101 | attnt_head_per_wt = din // head_dim 102 | useful_attnt_heads=attnt_head_per_wt * (dout // config.num_prefixes) 103 | extra_heads=dout // head_dim 104 | 105 | #print (config.num_attention_heads, extra_heads, useful_attnt_heads, attnt_head_per_wt) 106 | assert num_attention_heads >= extra_heads + useful_attnt_heads, \ 107 | "Number of attention heads should be atleast the number of weights + biases present in each blank" 108 | 109 | assert config.num_prefixes <= head_dim ,\ 110 | "Currently I assume the head dimension is atleast the number of prefixes in the original model" 111 | 112 | #--------------------------------#--------------------------------# 113 | #For all Attention heads on the embeddings 114 | #Query repeats the first din dimensions dout times, so that we can split them among the different attention heads 115 | #Key is Identity 116 | #Value is all zeros 117 | #Key and Query of embeddings ignore the position dependence. 
118 | 119 | #Final attention head simply copies the bias present in the first blank 120 | #--------------------------------#--------------------------------# 121 | key_attn = torch.zeros((num_attention_heads, head_dim, head_dim)) 122 | 123 | din_partition = din // attnt_head_per_wt 124 | 125 | key_attn[:useful_attnt_heads] = torch.eye(head_dim) 126 | 127 | 128 | 129 | query_attn_head = torch.zeros((head_dim, num_attention_heads, num_attention_heads)) 130 | for i in range (dout // config.num_prefixes): 131 | query_attn_head[:, i*attnt_head_per_wt: (i+1)*attnt_head_per_wt, :attnt_head_per_wt] = torch.eye(attnt_head_per_wt) 132 | 133 | 134 | value_attn = torch.zeros((num_attention_heads, head_dim, head_dim)) 135 | value_attn[useful_attnt_heads: useful_attnt_heads+extra_heads] = torch.eye(head_dim) 136 | 137 | 138 | #--------------------------------#--------------------------------# 139 | #For all Attention heads on the positions 140 | #Query is Identity (on the component corresponding to one-hot encodings 141 | #of the input sequence to the smaller model) 142 | #Key is Identity (on the component corresponding to one-hot encodings 143 | #of the input sequence to the smaller model) + all-ones on the the blank identifiers 144 | #Value moves the blank identifiers to the fore-front 145 | #Key and Query ignore dependence on the signal. 146 | 147 | #Final attention head simply copies the bias present in the first blank 148 | #--------------------------------#--------------------------------# 149 | 150 | 151 | query = torch.zeros((head_dim, config.position_dim)) 152 | query[0, :config.seq_length] = 1. 153 | 154 | key = torch.zeros((head_dim, config.position_dim)) 155 | key[:config.num_prefixes, config.seq_length: config.position_dim] = torch.eye(config.num_prefixes) 156 | 157 | value = torch.zeros((head_dim, config.position_dim)) 158 | value[:config.num_prefixes, config.seq_length: config.position_dim] = torch.eye(config.num_prefixes) 159 | 160 | expand_ = torch.zeros((3, 1, num_attention_heads)) 161 | #for query, we use position embedding only at heads useful_attnt_heads: useful_attnt_heads+extra_heads 162 | expand_[0, 0, useful_attnt_heads: useful_attnt_heads+extra_heads] = 1. 163 | #for key, we use position embedding only at heads useful_attnt_heads: useful_attnt_heads+extra_heads 164 | expand_[1, 0, useful_attnt_heads: useful_attnt_heads+extra_heads] = 1. 165 | #for value, we use position embedding only at heads :useful_attnt_heads 166 | expand_[2, 0, :useful_attnt_heads] = 1. 167 | 168 | 169 | p_attn_init = torch.cat([query, key, value], axis=0) 170 | 171 | #--------------------------------#--------------------------------# 172 | #The projection matrix after the attention module reorders such that 173 | #, , ..., appear in a sequential order. 174 | #--------------------------------#--------------------------------# 175 | 176 | if not self.proj_conv2d: 177 | c_proj_init = torch.zeros(( config.hidden_size, config.hidden_size )) 178 | 179 | for i in range(dout): 180 | num_useful_heads = dout // config.num_prefixes 181 | desd_loc = head_dim * attnt_head_per_wt * (i % num_useful_heads) + i // num_useful_heads 182 | 183 | for sub_head in range(attnt_head_per_wt): 184 | if use_softmax: 185 | c_proj_init[shift_top+i, desd_loc + sub_head * head_dim] = config.scale_embeddings 186 | else: 187 | c_proj_init[shift_top+i, desd_loc + sub_head * head_dim] = 1. 
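#Each output coordinate shift_top + i above therefore sums the attnt_head_per_wt partial
#dot products computed for weight row w_i (they sit head_dim apart in the concatenated
#head outputs), so the inner products <w_1, x>, <w_2, x>, ... appear sequentially from
#row shift_top onwards; the loop below routes the bias carried by the extra attention
#heads to the same output coordinates.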
188 | 189 | 190 | for i in range(dout): 191 | desd_loc = head_dim * num_useful_heads * attnt_head_per_wt + i 192 | c_proj_init[shift_top+i, desd_loc] = 1. 193 | 194 | if projection_matrix is not None: 195 | projection_tensor = torch.zeros((config.hidden_size, config.hidden_size)) 196 | projection_tensor[shift_top:shift_top+projection_matrix.shape[0], :projection_matrix.shape[1]] = torch.tensor(projection_matrix, dtype=c_proj_init.dtype) 197 | c_proj_init = projection_tensor @ c_proj_init 198 | 199 | else: 200 | assert head_dim % config.num_prefixes == 0, \ 201 | "This is a bug! For simpler operation, I assume head_dim to be divisible by config.num_prefixes" 202 | 203 | num_partitions_head = head_dim // config.num_prefixes 204 | num_channels = config.hidden_size // config.num_prefixes 205 | num_wt_channels = dout // config.num_prefixes 206 | c_proj_init = torch.zeros(( config.num_prefixes, num_channels, num_channels )) 207 | for i in range(num_wt_channels): 208 | for j in range(attnt_head_per_wt * num_partitions_head): 209 | c_proj_init[:, i, i * attnt_head_per_wt * num_partitions_head + j ] = 1. 210 | 211 | c_proj_init[:, num_wt_channels + i, num_wt_channels * attnt_head_per_wt * num_partitions_head + i] = 1. 212 | 213 | 214 | num_abs_heads = config.hidden_size // dout 215 | shift_top_head = shift_top // dout 216 | 217 | #permute the final computation 218 | self.permutation_conv = Conv2D(nf=num_abs_heads, nx=dout, transpose=False) 219 | permutation_wt = torch.zeros((num_abs_heads, dout, dout)) 220 | for i in range(dout): 221 | desd_loc = ( i % num_wt_channels ) * config.num_prefixes + i // num_wt_channels 222 | permutation_wt[0, i, desd_loc] = 1. 223 | permutation_wt[1] = torch.eye(dout) 224 | with torch.no_grad(): 225 | self.permutation_conv.weight.copy_(permutation_wt.transpose(-1, -2)) 226 | 227 | 228 | #add bias to the 229 | self.bias_add_conv = Conv2D(nf=num_abs_heads, nx=dout, transpose=True, use_einsum=self.config.use_einsum) 230 | bias_add_wt = torch.zeros((dout, num_abs_heads, num_abs_heads)) 231 | bias_add_wt[:, shift_top_head, 0: 2 ] = 1. 232 | with torch.no_grad(): 233 | self.bias_add_conv.weight.copy_(bias_add_wt.transpose(-1, -2)) 234 | 235 | if projection_matrix is not None: 236 | start_index=shift_top 237 | if projection_matrix.shape[0] >= projection_matrix.shape[1]: 238 | self.projection_layer = up_projection (config, projection_matrix, signal_index=start_index, store_index=start_index) 239 | else: 240 | self.projection_layer = down_projection (config, projection_matrix, signal_index=start_index, store_index=start_index) 241 | 242 | 243 | self.attnt_module.initialize_weights(q_attn_init_head=query_attn_head, \ 244 | k_attn_init=key_attn, \ 245 | v_attn_init=value_attn, 246 | p_attn_init=p_attn_init, \ 247 | p_expand_init=expand_,\ 248 | c_proj_init=c_proj_init, \ 249 | ) 250 | 251 | 252 | 253 | #Initialize Gates 254 | #Ignore the changes for the prefixes! 
255 | #w, u, v, w_bias, u_bias, v_bias 256 | w = torch.zeros((1, 2*config.hidden_size)) 257 | u = torch.zeros((1, 2*config.hidden_size)) 258 | v = torch.zeros((1, 2*config.position_dim)) 259 | w_bias = torch.zeros(2) 260 | u_bias = torch.zeros(2) 261 | v_bias = torch.zeros(2) 262 | 263 | #Input Gate is 1 on prefixes and 0 for non-prefixes 264 | v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.num_prefixes) 265 | 266 | 267 | #Change Gate is 0 on prefixes and 1 for non-prefixes 268 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.num_prefixes) 269 | v_bias [1] += config.gate_scale 270 | 271 | self.gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias) 272 | 273 | 274 | def forward(self, hidden_states, position_embeddings): 275 | output = self.attnt_module.forward(hidden_states=hidden_states, positions=position_embeddings, restrict_prefixes=self.config.restrict_prefixes)[0] 276 | 277 | if self.permutation_conv is not None: 278 | output = self.permutation_conv(output) 279 | output = self.bias_add_conv( output ) 280 | if self.projection_layer is not None: 281 | output = self.projection_layer(output) 282 | #store the input in memory for backward pass later on 283 | if self.memory_index != -1: 284 | assert torch.sum(output[:, self.config.num_prefixes:, self.memory_index: ]).item() < 1e-10,\ 285 | "Memory portion not empty!" 286 | 287 | output[:, self.config.num_prefixes:, self.memory_index: self.memory_index + self.din ] += hidden_states[:, self.config.num_prefixes:, :self.din] 288 | return self.gates.forward(hidden_states=hidden_states, \ 289 | output_states=output, \ 290 | position=position_embeddings\ 291 | ) 292 | -------------------------------------------------------------------------------- /icl_eval/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import contextlib 4 | from typing import Optional, Union 5 | import numpy as np 6 | from dataclasses import dataclass, is_dataclass, asdict 7 | import logging 8 | import time 9 | from torch.nn import CrossEntropyLoss 10 | import torch.nn.functional as F 11 | from transformers.modeling_outputs import CausalLMOutputWithPast 12 | import torch 13 | from transformers.utils import PaddingStrategy 14 | from transformers import PreTrainedTokenizerBase 15 | from typing import Optional, Union, List, Dict, Any 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | @dataclass 20 | class Prediction: 21 | correct_candidate: Union[int, str] 22 | predicted_candidate: Union[int, str] 23 | scores: List[float] = None 24 | 25 | @contextlib.contextmanager 26 | def count_time(name): 27 | logger.info("%s..." 
% name) 28 | start_time = time.time() 29 | try: 30 | yield 31 | finally: 32 | logger.info("Done with %.2fs" % (time.time() - start_time)) 33 | 34 | @contextlib.contextmanager 35 | def temp_seed(seed): 36 | state = np.random.get_state() 37 | np.random.seed(seed) 38 | try: 39 | yield 40 | finally: 41 | np.random.set_state(state) 42 | 43 | def forward_wrap_with_option_len(self, input_ids=None, labels=None, option_len=None, num_options=None, return_dict=None, **kwargs): 44 | outputs = self.original_forward(input_ids=input_ids, **kwargs) 45 | if labels is None: 46 | return outputs 47 | logits = outputs.logits 48 | 49 | loss = None 50 | # shift so that tokens < n predict n 51 | shift_logits = logits[..., :-1, :].contiguous() 52 | # here we use input_ids (which should always = labels) bc sometimes labels are correct candidate IDs 53 | shift_labels = torch.clone(input_ids)[..., 1:].contiguous() 54 | shift_labels[shift_labels == self.config.pad_token_id] = -100 55 | 56 | # apply option len 57 | for _i, _len in enumerate(option_len): 58 | shift_labels[_i, :-_len] = -100 59 | 60 | # calculate the loss 61 | loss_fct = CrossEntropyLoss(ignore_index=-100) 62 | if num_options is not None: 63 | # Train as classificaiton tasks 64 | log_probs = F.log_softmax(shift_logits, dim=-1) 65 | mask = shift_labels != -100 # option part 66 | shift_labels[~mask] = 0 # so that it doesn't mess up with indexing 67 | 68 | selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1) # (bsz x num_options, len) 69 | selected_log_probs = (selected_log_probs * mask).sum(-1) # (bsz x num_options) 70 | 71 | num_options = num_options[0] 72 | selected_log_probs = selected_log_probs.view(-1, num_options) # (bsz, num_options) 73 | labels = labels.view(-1, num_options)[:, 0] # labels repeat so we only take the first one 74 | loss = loss_fct(selected_log_probs, labels) 75 | else: 76 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) 77 | if not return_dict: 78 | output = (logits,) + outputs[1:] 79 | return (loss,) + output if loss is not None else output 80 | 81 | return CausalLMOutputWithPast( 82 | loss=loss, 83 | logits=logits, 84 | past_key_values=outputs.past_key_values, 85 | hidden_states=outputs.hidden_states, 86 | attentions=outputs.attentions, 87 | ) 88 | 89 | @dataclass 90 | class DataCollatorWithPaddingAndNesting: 91 | 92 | tokenizer: PreTrainedTokenizerBase 93 | padding: Union[bool, str, PaddingStrategy] = True 94 | max_length: Optional[int] = None 95 | pad_to_multiple_of: Optional[int] = None 96 | return_tensors: str = "pt" 97 | 98 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 99 | features = [ff for f in features for ff in f] 100 | batch = self.tokenizer.pad( 101 | features, 102 | padding=self.padding, 103 | max_length=self.max_length, 104 | pad_to_multiple_of=self.pad_to_multiple_of, 105 | return_tensors=self.return_tensors, 106 | ) 107 | if "label" in batch: 108 | batch["labels"] = batch["label"] 109 | del batch["label"] 110 | if "label_ids" in batch: 111 | batch["labels"] = batch["label_ids"] 112 | del batch["label_ids"] 113 | return batch 114 | 115 | 116 | def encode_prompt(task, template, train_samples, eval_sample, tokenizer, max_length, sfc=False, icl_sfc=False, generation=False, generation_with_gold=False, **kwargs): 117 | """ 118 | sfc: calibration (surface form competition) 119 | icl_sfc: calibration (surface form competition) with in-context demonstrations 120 | """ 121 | train_prompts = 
[template.verbalize(sample, sample.correct_candidate).strip() for sample in train_samples] 122 | 123 | 124 | ############### New code to include a label mask ################## 125 | train_label_positions = [] 126 | correct_label_lengths = [] 127 | prompt = '' 128 | 129 | for (sample, sample_cand) in zip( train_samples, train_prompts ): 130 | 131 | example_prompt = (prompt + sample_cand).strip() 132 | example_sent = (prompt + template.encode(sample)).strip() 133 | 134 | label_len = len(tokenizer.encode(example_prompt)) - len(tokenizer.encode(example_sent)) 135 | train_label_positions += [ len(tokenizer.encode(example_sent)) ] 136 | correct_label_lengths += [ label_len ] 137 | prompt = example_prompt + task.train_sep 138 | label_mask = None 139 | ############### New code to include a label mask ################## 140 | 141 | train_prompts = task.train_sep.join(train_prompts).strip() 142 | 143 | if sfc or icl_sfc: 144 | encode_fn = template.encode_sfc; verbalize_fn = template.verbalize_sfc 145 | else: 146 | encode_fn = template.encode; verbalize_fn = template.verbalize 147 | 148 | unverbalized_eval_prompt = encode_fn(eval_sample).strip(' ') 149 | if not generation: 150 | verbalized_eval_prompts = [verbalize_fn(eval_sample, cand).strip(' ') for cand in eval_sample.candidates] 151 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt)) 152 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts] 153 | 154 | if sfc or train_prompts == '': 155 | # without demonstrations 156 | final_prompts = verbalized_eval_prompts 157 | else: 158 | # with demonstrations 159 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts] 160 | 161 | else: 162 | assert not sfc and not icl_sfc, "Generation tasks do not support SFC" 163 | if generation_with_gold: 164 | verbalized_eval_prompts = [verbalize_fn(eval_sample, eval_sample.correct_candidate)] 165 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt)) 166 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts] 167 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts] 168 | else: 169 | option_lens = [0] 170 | final_prompts = [(train_prompts + task.train_sep + unverbalized_eval_prompt).lstrip().strip(' ')] 171 | 172 | # tokenize 173 | encodings = [tokenizer.encode(final_prompt) for final_prompt in final_prompts] 174 | 175 | ############### New code to include a label mask ################## 176 | if not generation: 177 | if sfc: 178 | label_mask = [np.zeros( len(encoding) ) for encoding in encodings] 179 | else: 180 | label_mask = [] 181 | 182 | for enc_order in range(len(encodings)): 183 | mask = np.zeros( len(encodings[enc_order]) ) 184 | for (pos, length) in zip( train_label_positions, correct_label_lengths ): 185 | mask [pos: pos+length] = 1. 
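# one label mask per candidate encoding: the token positions of the demonstrations' gold labels are marked with 1, everything else stays 0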
186 | label_mask += [mask] 187 | # label_mask = np.stack(label_mask) 188 | ############### New code to include a label mask ################## 189 | # print (label_mask) 190 | # truncate 191 | if any([len(encoding) > max_length for encoding in encodings]): 192 | logger.warn("Exceed max length") 193 | if tokenizer.add_bos_token: 194 | encodings = [encoding[0:1] + encoding[1:][-(max_length-1):] for encoding in encodings] 195 | # label_mask = [lmask[0:1] + lmask[1:][-(max_length-1):] for lmask in label_mask] 196 | else: 197 | encodings = [encoding[-max_length:] for encoding in encodings] 198 | # label_mask = [lmask[-max_length:] for lmask in label_mask] 199 | 200 | return encodings, option_lens, label_mask 201 | 202 | 203 | def encode_prompt_with_construction(task, template, train_samples, eval_sample, tokenizer, max_length, sfc=False, icl_sfc=False, generation=False, generation_with_gold=False, **kwargs): 204 | 205 | """ 206 | sfc: calibration (surface form competition) 207 | icl_sfc: calibration (surface form competition) with in-context demonstrations 208 | """ 209 | 210 | restrict_attention_demonstration=kwargs['restrict_attention_demonstration'] 211 | position_modify_demonstration=kwargs['position_modify_demonstration'] 212 | mask_demonstration_eval=kwargs['mask_demonstration_eval'] 213 | position_modify_eval=kwargs['position_modify_eval'] 214 | 215 | 216 | assert restrict_attention_demonstration or not position_modify_demonstration, \ 217 | "Can't change position if we are not restricting attention of each demonstration to itself" 218 | 219 | 220 | assert mask_demonstration_eval or not position_modify_eval, \ 221 | "Can't change position if we are not masking demonstrations" 222 | 223 | 224 | 225 | train_prompts = [template.verbalize(sample, sample.correct_candidate).strip() for sample in train_samples] 226 | 227 | ############### New code to include a label mask, icl mask, and modify positions ################## 228 | train_example_start = [] 229 | train_example_end = [] 230 | 231 | train_label_positions = [] 232 | correct_label_lengths = [] 233 | 234 | prompt = '' 235 | if tokenizer.add_bos_token: shift_right = 1 236 | else: shift_right = 0 237 | for (sample, sample_cand) in zip( train_samples, train_prompts ): 238 | 239 | example_prompt = (prompt + sample_cand).lstrip() 240 | example_sent = (prompt + template.encode(sample)).lstrip() 241 | 242 | example_prompt_enc = tokenizer.encode(example_prompt) 243 | example_sent_enc = tokenizer.encode(example_sent) 244 | example_sample_enc = tokenizer.encode(sample_cand.strip()) 245 | 246 | label_len = len(example_prompt_enc) - len(example_sent_enc) 247 | 248 | train_example_start += [len(example_prompt_enc) - len(example_sample_enc) + shift_right] 249 | train_example_end += [len(example_prompt_enc)] 250 | 251 | train_label_positions += [ len(example_sent_enc) ] 252 | correct_label_lengths += [ label_len ] 253 | 254 | prompt = example_prompt + task.train_sep 255 | label_mask = None 256 | test_example_start = len(tokenizer.encode(prompt + 'end')) - len(tokenizer.encode('end')) + shift_right 257 | 258 | 259 | ############### New code to include a label mask ################## 260 | 261 | train_prompts = task.train_sep.join(train_prompts).strip() 262 | 263 | if sfc or icl_sfc: 264 | encode_fn = template.encode_sfc; verbalize_fn = template.verbalize_sfc 265 | else: 266 | encode_fn = template.encode; verbalize_fn = template.verbalize 267 | 268 | unverbalized_eval_prompt = encode_fn(eval_sample).strip(' ') 269 | if not generation: 270 | 
verbalized_eval_prompts = [verbalize_fn(eval_sample, cand).strip(' ') for cand in eval_sample.candidates] 271 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt)) 272 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts] 273 | 274 | if sfc or train_prompts == '': 275 | # without demonstrations 276 | final_prompts = verbalized_eval_prompts 277 | else: 278 | # with demonstrations 279 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts] 280 | 281 | else: 282 | assert not sfc and not icl_sfc, "Generation tasks do not support SFC" 283 | if generation_with_gold: 284 | verbalized_eval_prompts = [verbalize_fn(eval_sample, eval_sample.correct_candidate)] 285 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt)) 286 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts] 287 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts] 288 | else: 289 | option_lens = [0] 290 | final_prompts = [(train_prompts + task.train_sep + unverbalized_eval_prompt).lstrip().strip(' ')] 291 | 292 | # tokenize 293 | encodings = [tokenizer.encode(final_prompt) for final_prompt in final_prompts] 294 | 295 | ############### New code to include a label mask ################## 296 | if not generation: 297 | if sfc: 298 | label_mask = [np.zeros( len(encoding) ) for encoding in encodings] 299 | else: 300 | label_mask = [] 301 | 302 | for enc_order in range(len(encodings)): 303 | mask = np.zeros( len(encodings[enc_order]) ) 304 | for (pos, length) in zip( train_label_positions, correct_label_lengths ): 305 | mask [pos: pos+length] = 1. 306 | label_mask += [mask] 307 | label_mask = np.stack(label_mask) 308 | ############### New code to include a label mask ################## 309 | 310 | ############### New code to include a icl mask and position ids################## 311 | icl_mask = [ np.ones( (len(encoding), len(encoding)) ) for encoding in encodings ] 312 | if restrict_attention_demonstration: 313 | for (start, end) in zip(train_example_start, train_example_end): 314 | for i in range(len(encodings)): 315 | icl_mask[i][start:end, start:end] = 1. 316 | icl_mask[i][start:end, :start] = 0. 317 | icl_mask[i][start:end, end:] = 0. 318 | 319 | if mask_demonstration_eval: 320 | for i in range(len(encodings)): 321 | icl_mask[i][test_example_start:, test_example_start:] = 1. 322 | icl_mask[i][test_example_start:, :test_example_start] = 0. 323 | 324 | if tokenizer.add_bos_token: 325 | for i in range(len(encodings)): icl_mask[i][:, 0] = 1. 
326 | 327 | position_ids = [ np.arange( len(encoding) ) for encoding in encodings ] 328 | if position_modify_demonstration: 329 | if tokenizer.add_bos_token: start_id = 1 330 | else: start_id = 0 331 | 332 | for (start, end) in zip(train_example_start, train_example_end): 333 | example_length = end - start 334 | for i in range(len(encodings)): 335 | position_ids [i][start:end] = np.arange( start_id, start_id + example_length ) 336 | 337 | if position_modify_eval: 338 | if tokenizer.add_bos_token: start_id = 1 339 | else: start_id = 0 340 | 341 | for i in range( len(encodings) ): 342 | start = test_example_start 343 | end = len(encodings[i]) 344 | 345 | example_length = end - start 346 | position_ids [i][start:end] = np.arange( start_id, start_id + example_length ) 347 | 348 | ############### New code to include a icl mask and position ids ################## 349 | 350 | 351 | # truncate 352 | if any([len(encoding) > max_length for encoding in encodings]): 353 | logger.warn("Exceed max length") 354 | if tokenizer.add_bos_token: 355 | 356 | new_position_ids = [] 357 | new_label_mask = [] 358 | new_icl_mask = [] 359 | for i in range( len(encodings) ): 360 | max_len = min(max_length, len(encodings[i])) 361 | 362 | nicm = np.zeros((max_len, max_len)) 363 | nicm [ -(max_len-1):, -(max_len-1): ] = icl_mask[i] [-(max_len-1):, -(max_len-1):] 364 | nicm [ :, 0 ] = 1. 365 | 366 | new_icl_mask += [nicm] 367 | new_position_ids += [ np.concatenate( [np.asarray([0]), position_ids[i][ -(max_len-1): ] ], axis=0 ) ] 368 | new_label_mask += [ np.concatenate( [np.asarray([label_mask[i][0]]), label_mask[i][ -(max_len-1): ] ], axis=0 ) ] 369 | icl_mask = new_icl_mask 370 | position_ids = new_position_ids 371 | label_mask = new_label_mask 372 | 373 | encodings = [encoding[0:1] + encoding[1:][-(max_length-1):] for encoding in encodings] 374 | else: 375 | encodings = [encoding[-max_length:] for encoding in encodings] 376 | label_mask = [lmask[-max_length:] for lmask in label_mask] 377 | position_ids = [ position_id[ -max_length: ] for position_id in position_ids] 378 | icl_mask = [ icm[-max_length:, -max_length:] for icm in icl_mask ] 379 | 380 | return encodings, option_lens, label_mask, icl_mask, position_ids 381 | 382 | 383 | def load_generation(): 384 | out_dir = "/scratch/gpfs/mengzhou/space6/out/the_pile_corrected" 385 | json_files = [f"{out_dir}/ft/ft_opt-125m-lr1e-4/generation_downstream/ft_opt-125m-lr1e-4-hf-sample-0.90-len20-num5-copa.json"] 386 | for model in ["350m", "1.3b", "2.7b"]: 387 | json_files.append(f"{out_dir}/kd_pretrained_ce1_layer0_bs512/kd_pretrained_temp1_tmodelft{model}_lr1e-4/generation_downstream/kd_pretrained_temp1_tmodelft{model}_lr1e-4-hf-sample-0.90-len20-num5-copa.json") 388 | 389 | i = 0 390 | for json_file in json_files[3:4]: 391 | name = os.path.basename(os.path.dirname(os.path.dirname(json_file))) 392 | generations = json.load(open(json_file)) 393 | gen = generations[i] 394 | print(name) 395 | print("\033[1mprefix\033[0m:", gen["prefix"]) 396 | print("\033[1mcorrect_options\033[0m:", gen["correct_options"]) 397 | for i in range(len(gen["incorrect_options"])): 398 | print("\033[1mincorrect_options\033[0m" + f" {i}:", end=" ") 399 | print(gen["incorrect_options"][i]) 400 | for i in range(len(gen["generated"])): 401 | print("\033[1mgenerated_option\033[0m" + f" {i}:", end=" ") 402 | print(f"[{round(gen['scores'][i], 2)}]:", end=" ") 403 | print(gen["generated"][i]) 404 | 405 | def read_jsonl(file): 406 | ds = [] 407 | try: 408 | with open(file) as f: 409 | for i, line in 
enumerate(f): 410 | d = json.loads(line.strip()) 411 | ds.append(d) 412 | except: 413 | import pdb 414 | pdb.set_trace() 415 | return ds 416 | 417 | 418 | from collections.abc import Mapping 419 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 420 | InputDataClass = NewType("InputDataClass", Any) 421 | from dataclasses import dataclass 422 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 423 | import torch 424 | @dataclass 425 | class ICLCollator: 426 | tokenizer: PreTrainedTokenizerBase 427 | 428 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 429 | if not isinstance(features[0], Mapping): 430 | features = [vars(f) for f in features] 431 | first = features[0] 432 | batch = {} 433 | 434 | pad_id = self.tokenizer.pad_token_id 435 | 436 | pad_ids = {"input_ids": pad_id, "attention_mask": 0, "sfc_input_ids": pad_id, "sfc_attention_mask": 0, "labels": pad_id} 437 | for key in first: 438 | pp = pad_ids[key] 439 | lens = [len(f[key]) for f in features] 440 | max_len = max(lens) 441 | feature = np.stack([np.pad(f[key], (0, max_len - lens[i]), "constant", constant_values=(0, pp)) for i, f in enumerate(features)]) 442 | padded_feature = torch.from_numpy(feature).long() 443 | batch[key] = padded_feature 444 | 445 | return batch 446 | 447 | 448 | import json 449 | class EnhancedJSONEncoder(json.JSONEncoder): 450 | def default(self, o): 451 | if is_dataclass(o): 452 | return asdict(o) 453 | return super().default(o) 454 | 455 | def write_predictions_to_file(final_preds, output): 456 | with open(output, "w") as f: 457 | for pred in final_preds: 458 | f.write(json.dumps(pred, cls=EnhancedJSONEncoder) + "\n") 459 | 460 | def write_metrics_to_file(metrics, output): 461 | with open(output, "w") as f: 462 | if type(metrics) == list: 463 | for metric in metrics: 464 | json.dump(metric, f, cls=EnhancedJSONEncoder) 465 | f.write("\n") 466 | else: 467 | json.dump(metrics, f, cls=EnhancedJSONEncoder, indent=4) 468 | 469 | if __name__ == "__main__": 470 | load_generation() 471 | --------------------------------------------------------------------------------
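forward_wrap_with_option_len above turns a language-modeling forward pass into a classifier: it shifts logits and labels by one position, masks out everything except the trailing option_len tokens of each candidate with the -100 ignore index, sums the candidate log-probabilities, and treats the per-candidate totals as classification scores. A minimal sketch of that scoring path with synthetic logits and token ids in place of a real model output (shapes and the -100 convention mirror the wrapper; nothing here touches the repository's model code):

import torch
import torch.nn.functional as F

# Synthetic stand-in for a model output: 2 prompts, 3 candidate options each,
# sequence length 6, vocabulary of 11 tokens.
bsz, num_options, seq_len, vocab = 2, 3, 6, 11
logits = torch.randn(bsz * num_options, seq_len, vocab)
input_ids = torch.randint(0, vocab, (bsz * num_options, seq_len))
option_len = [2, 3, 2, 2, 3, 2]  # number of trailing tokens belonging to each candidate

# Shift so that the logits at position t predict the token at position t+1.
shift_logits = logits[..., :-1, :]
shift_labels = input_ids[..., 1:].clone()

# Ignore every token that is not part of the candidate (the option suffix).
for i, length in enumerate(option_len):
    shift_labels[i, :-length] = -100

log_probs = F.log_softmax(shift_logits, dim=-1)
mask = shift_labels != -100
shift_labels = shift_labels.masked_fill(~mask, 0)  # -100 is not a valid gather index

# Sum the log-probabilities of the candidate tokens only.
selected = torch.gather(log_probs, -1, shift_labels.unsqueeze(-1)).squeeze(-1)
scores = (selected * mask).sum(-1).view(bsz, num_options)  # one score per candidate
print(scores.argmax(dim=-1))  # predicted candidate index for each prompt

The wrapper then feeds these per-candidate totals into CrossEntropyLoss against the gold candidate index, so option-based classification and plain language-modeling loss share one code path.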
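encode_prompt above locates the gold label tokens of every in-context demonstration by encoding each demonstration twice, once without and once with the verbalized answer: the unverbalized length gives the label's start position and the length difference gives the label's length. A sketch of that bookkeeping with a toy whitespace tokenizer standing in for the Hugging Face tokenizer (the demonstration strings and train_sep are made up for illustration):

import numpy as np

def encode(text):
    # Toy whitespace tokenizer; stands in for the real tokenizer, whose only
    # property used here is that encode() returns a list of tokens.
    return text.split()

train_sep = " \n "
demonstrations = [("Review: great movie Sentiment:", " positive"),
                  ("Review: dull plot Sentiment:", " negative")]

prompt, label_positions, label_lengths = "", [], []
for sentence, answer in demonstrations:
    example_sent = (prompt + sentence).strip()             # demonstration without its label
    example_prompt = (prompt + sentence + answer).strip()  # demonstration with its label
    # The label starts where the unlabeled encoding ends and spans the extra tokens.
    label_positions.append(len(encode(example_sent)))
    label_lengths.append(len(encode(example_prompt)) - len(encode(example_sent)))
    prompt = example_prompt + train_sep

full_encoding = encode(prompt + "Review: fine acting Sentiment:")
label_mask = np.zeros(len(full_encoding))
for pos, length in zip(label_positions, label_lengths):
    label_mask[pos:pos + length] = 1.0
print(label_mask)  # ones exactly at the "positive" and "negative" tokens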
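Both encoding functions truncate from the left when a prompt exceeds max_length, but keep the first token when the tokenizer prepends a BOS token, so the sequence still begins with BOS after the oldest demonstration tokens are dropped. A small sketch of that branch (the token ids are arbitrary; add_bos_token stands in for tokenizer.add_bos_token):

# Left truncation that preserves the BOS token, mirroring the truncation branch above.
max_length = 8
add_bos_token = True  # stands in for tokenizer.add_bos_token
encoding = [0, 11, 12, 13, 14, 15, 16, 17, 18, 19]  # token 0 plays the role of BOS

if len(encoding) > max_length:
    if add_bos_token:
        # Keep position 0 (BOS) and the most recent max_length - 1 tokens.
        encoding = encoding[0:1] + encoding[1:][-(max_length - 1):]
    else:
        encoding = encoding[-max_length:]
print(encoding)  # [0, 13, 14, 15, 16, 17, 18, 19]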
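encode_prompt_with_construction additionally builds a 2-D icl_mask per encoding: with restrict_attention_demonstration, each demonstration attends only within its own span; with mask_demonstration_eval, the eval segment attends only to itself; and the BOS column is always left attendable. A sketch with a made-up layout of two demonstrations and one eval segment (the spans and sequence length are illustrative, not derived from a real tokenization):

import numpy as np

# Toy layout: BOS at index 0, demonstrations spanning [1, 4) and [4, 7),
# eval segment starting at 7, total length 10.
seq_len, demo_spans, eval_start, has_bos = 10, [(1, 4), (4, 7)], 7, True

icl_mask = np.ones((seq_len, seq_len))
# restrict_attention_demonstration: a demonstration attends only within its own span.
for start, end in demo_spans:
    icl_mask[start:end, :start] = 0.0
    icl_mask[start:end, end:] = 0.0
# mask_demonstration_eval: the eval segment cannot attend back to the demonstrations.
icl_mask[eval_start:, :eval_start] = 0.0
# The BOS column stays attendable from every position, as in the code above.
if has_bos:
    icl_mask[:, 0] = 1.0
print(icl_mask)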
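When position_modify_demonstration or position_modify_eval is set, the same function resets position ids so that every demonstration, and optionally the eval segment, restarts at start_id (1 when a BOS token is present, else 0), making each segment look like the start of a fresh sequence. Continuing the toy layout from the previous sketch:

import numpy as np

# Same toy layout as the attention-mask sketch; start_id is 1 because BOS is assumed present.
seq_len, demo_spans, eval_start, start_id = 10, [(1, 4), (4, 7)], 7, 1

position_ids = np.arange(seq_len)
# Each demonstration, and the eval segment, restarts its positions at start_id.
for start, end in demo_spans + [(eval_start, seq_len)]:
    position_ids[start:end] = np.arange(start_id, start_id + (end - start))
print(position_ids)  # [0 1 2 3 1 2 3 1 2 3]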
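ICLCollator above right-pads variable-length numpy features to the batch maximum with np.pad and converts them to long tensors; input-id style keys are padded with the tokenizer's pad id and attention masks with 0. A sketch of that padding with hand-written features (pad id 1 is an assumption matching OPT's default; any id works for the illustration):

import numpy as np
import torch

pad_id = 1  # assumption: OPT's default pad token id
pad_values = {"input_ids": pad_id, "attention_mask": 0}
features = [{"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1]},
            {"input_ids": [8, 9], "attention_mask": [1, 1]}]

batch = {}
for key in features[0]:
    lengths = [len(f[key]) for f in features]
    max_len = max(lengths)
    # Right-pad every feature to the batch maximum, then stack into one tensor.
    padded = np.stack([np.pad(f[key], (0, max_len - n), constant_values=pad_values[key])
                       for f, n in zip(features, lengths)])
    batch[key] = torch.from_numpy(padded).long()
print(batch["input_ids"])       # tensor([[5, 6, 7], [8, 9, 1]])
print(batch["attention_mask"])  # tensor([[1, 1, 1], [1, 1, 0]])

DataCollatorWithPaddingAndNesting above differs in that it first flattens the nested per-example candidate lists ([ff for f in features for ff in f]) and then delegates the padding itself to tokenizer.pad.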
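EnhancedJSONEncoder lets write_predictions_to_file and write_metrics_to_file serialize dataclass instances directly by converting them with asdict inside default. A usage sketch; the Prediction dataclass is a hypothetical stand-in for whatever prediction objects the evaluation loop produces:

import json
from dataclasses import asdict, dataclass, is_dataclass

class EnhancedJSONEncoder(json.JSONEncoder):
    # Same encoder as above, reproduced so the snippet runs on its own.
    def default(self, o):
        if is_dataclass(o):
            return asdict(o)
        return super().default(o)

@dataclass
class Prediction:  # hypothetical stand-in for the evaluation loop's prediction objects
    correct_candidate: str
    predicted_candidate: str

print(json.dumps(Prediction("yes", "no"), cls=EnhancedJSONEncoder))
# {"correct_candidate": "yes", "predicted_candidate": "no"}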