├── __init__.py
├── tests
│   ├── __init__.py
│   ├── create_model.py
│   ├── perplexity_eval.py
│   ├── data_utils
│   │   └── data_utils.py
│   └── perplexity_dynamiceval.py
├── tint_main
│   ├── utils
│   │   ├── __init__.py
│   │   ├── activation
│   │   │   ├── __init__.py
│   │   │   ├── backward.py
│   │   │   └── forward.py
│   │   ├── linear
│   │   │   ├── __init__.py
│   │   │   └── forward.py
│   │   ├── layernorm
│   │   │   ├── __init__.py
│   │   │   ├── forward.py
│   │   │   └── backward.py
│   │   ├── self_attention
│   │   │   ├── __init__.py
│   │   │   ├── forward.py
│   │   │   └── backward.py
│   │   ├── config.py
│   │   ├── activations.py
│   │   └── all_arguments.py
│   └── tint_creator.py
├── figures
│   └── icl_overview.png
├── icl_eval
│   ├── images
│   │   └── single_FT_or_multiple_FT.png
│   ├── metrics.py
│   ├── scripts
│   │   ├── others
│   │   │   ├── print_result.py
│   │   │   ├── print_result_new.py
│   │   │   ├── print_dynamic_eval_results.py
│   │   │   └── print_result_May10.py
│   │   ├── zero_shot
│   │   │   └── print_result_May10.py
│   │   ├── few_shot
│   │   │   └── print_result_May10.py
│   │   ├── TINT
│   │   │   ├── print_paper_result_zero_shot.py
│   │   │   └── print_paper_result.py
│   │   ├── dynamic_eval
│   │   │   └── print_result.py
│   │   └── run_eval.sh
│   ├── README.md
│   ├── models
│   │   └── modeling_opt.py
│   └── utils.py
├── .gitignore
├── env.yml
└── README.md
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tint_main/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tint_main/utils/activation/__init__.py:
--------------------------------------------------------------------------------
1 | from .forward import ActivationForward
2 | from .backward import ActivationBackward
3 |
--------------------------------------------------------------------------------
/figures/icl_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekpanigrahi1996/transformer_in_transformer/HEAD/figures/icl_overview.png
--------------------------------------------------------------------------------
/tint_main/utils/linear/__init__.py:
--------------------------------------------------------------------------------
1 | from .forward import LinearForward
2 | from .backward import LinearBackward, LinearDescent, Linear_Descent_Backward
--------------------------------------------------------------------------------
/tint_main/utils/layernorm/__init__.py:
--------------------------------------------------------------------------------
1 | from .forward import LayerNormForward
2 | from .backward import LayerNormBackward, LayerNormDescent, LayerNormDescent_Backward
3 |
--------------------------------------------------------------------------------
/icl_eval/images/single_FT_or_multiple_FT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abhishekpanigrahi1996/transformer_in_transformer/HEAD/icl_eval/images/single_FT_or_multiple_FT.png
--------------------------------------------------------------------------------
/tint_main/utils/self_attention/__init__.py:
--------------------------------------------------------------------------------
1 | from .forward import AttentionForward
2 | from .backward import AttentionBackward, AttentionDescent, AttentionBackward_Descent
3 |
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | not_firstversion/
7 | *.sh
8 | Constructed_model/*
9 | log*
10 |
--------------------------------------------------------------------------------
/tests/create_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from tint_main.tint_creator import *
4 | from transformers import AutoTokenizer
5 | from datasets import load_dataset
6 | from transformers import HfArgumentParser
7 | from filelock import Timeout, FileLock
8 | from .data_utils.data_utils import *
9 | from tint_main.utils.all_arguments import *
10 |
11 | parser = HfArgumentParser((ModelArguments, ConstructionArguments, DynamicArguments,))
12 | model_args, construction_args, data_args, = parser.parse_args_into_dataclasses()
13 |
14 | constructed_model, _, _, _ = TinT_Creator(model_args, construction_args)
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/icl_eval/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # For question answering
4 | import re
5 | import string
6 |
7 | def normalize_answer(s):
8 | """Lower text and remove punctuation, articles and extra whitespace."""
9 |
10 | def remove_articles(text):
11 | return re.sub(r'\b(a|an|the)\b', ' ', text)
12 |
13 | def white_space_fix(text):
14 | return ' '.join(text.split())
15 |
16 | def remove_punc(text):
17 | exclude = set(string.punctuation)
18 | return ''.join(ch for ch in text if ch not in exclude)
19 |
20 | def lower(text):
21 | return text.lower()
22 |
23 | return white_space_fix(remove_articles(remove_punc(lower(s))))
24 |
25 | def calculate_metric(predictions, metric_name):
26 | if metric_name == "accuracy":
27 | if isinstance(predictions[0].correct_candidate, list):
28 | return np.mean([pred.predicted_candidate in pred.correct_candidate for pred in predictions])
29 | else:
30 | return np.mean([pred.correct_candidate == pred.predicted_candidate for pred in predictions])
31 | elif metric_name == "em":
32 | return np.mean([any([normalize_answer(ans) == normalize_answer(pred.predicted_candidate) for ans in pred.correct_candidate]) for pred in predictions])
33 | elif metric_name == "substring_em":
34 | return np.mean([any([normalize_answer(ans) in normalize_answer(pred.predicted_candidate) for ans in pred.correct_candidate]) for pred in predictions])
35 |
36 |
--------------------------------------------------------------------------------
/icl_eval/scripts/others/print_result.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import json
3 | import numpy as np
4 | import os
5 |
6 |
7 | def print_numpy_result(result, models):
8 | for i in range(len(result)):
9 | # print(models[i], end="\t")
10 | for j in range(result.shape[1]):
11 | print(f"{result[i, j]:.4f}", end="\t")
12 | print()
13 | print()
14 |
15 | out_dir = Path("/n/fs/nlp-mengzhou/space9/out/llm_eval/constructed-lm-results")
16 |
17 | # rows
18 | models = ["gpt2", "constructed-gpt2-l12-ns1-lr1e-06", "constructed-gpt2-l12-ns1-lr1e-05", "constructed-gpt2-l12-ns1-lr1e-04", "opt-125m",
19 | "constructed-opt-l12-ns1-lr1e-06", "constructed-opt-l12-ns1-lr2e-06", "constructed-opt-l12-ns1-lr5e-07", "gpt2-large", "opt-1.3b"]
20 |
21 | shots = [0, 2, 4, 8]
22 |
23 | task = "AGNews"
24 | icl_sfc = False
25 | sfc = False
26 | metric = "accuracy"
27 |
28 | result = np.zeros((len(models), len(shots)))
29 | for i, model in enumerate(models):
30 | for j, shot in enumerate(shots):
31 | if sfc: tag = "-sfc"
32 | elif icl_sfc: tag = "-icl_sfc"
33 | else: tag = ""
34 | if shot == 0: file = out_dir / model / f"{task}-{model}{tag}-sampleeval200-onetrainpereval.json"
35 | else: file = out_dir / model / f"{task}-{model}{tag}-sampleeval200-ntrain{shot}-onetrainpereval.json"
36 | if os.path.exists(file):
37 | re = json.load(open(file, "r"))
38 | result[i, j] = re[metric]
39 |
40 | print(f"{task} {metric}")
41 | print_numpy_result(result, models)
42 |
43 |
--------------------------------------------------------------------------------
/icl_eval/scripts/zero_shot/print_result_May10.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import json
3 | import numpy as np
4 | import os
5 |
6 | def read_jsonl(file):
7 | ds = []
8 | try:
9 | with open(file) as f:
10 | for i, line in enumerate(f):
11 | if line.strip() == "":
12 | continue
13 | d = json.loads(line.strip())
14 | ds.append(d)
15 | except:
16 | import pdb
17 | pdb.set_trace()
18 | return ds
19 |
20 | def print_numpy_result(result, models):
21 | for i in range(len(result)):
22 | # print(models[i], end="\t")
23 | for j in range(result.shape[1]):
24 | print(f"{result[i, j]:.4f}", end="\t")
25 | print()
26 | print()
27 |
28 | out_dir = Path("/n/fs/nlp-mengzhou/space9/out/llm_eval/zeroshot_eval_May15")
29 |
30 |
31 | shots = 0
32 | task = "MPQA"
33 | metric = "accuracy"
34 | eval_type = "eval"
35 | seed = 1
36 | # result = np.zeros((len(models), len(shots)))
37 |
38 | print_model_name = False
39 | for tmp in [0]:
40 | for icl in ["plain", "icl_sfc"]:
41 | count = 0
42 | model = f"opt-1.3b-ntrain{shots}-{task}-seed{seed}-tmp{tmp}-{icl}-eval"
43 | task_dir = out_dir / model
44 | file = task_dir / "metrics.jsonl"
45 | if os.path.exists(file):
46 | re = read_jsonl(file)
47 | acc = re[0]["accuracy"]
48 | count += 1
49 | else:
50 | print(file)
51 | print(acc, end=" ")
52 | # print(seed_accs)
53 | print()
54 |
55 |
56 |
57 | # print(f"{task} {metric}")
58 | # print_numpy_result(result, models)
59 |
60 |
--------------------------------------------------------------------------------
/icl_eval/scripts/few_shot/print_result_May10.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import json
3 | import numpy as np
4 | import os
5 |
6 | def read_jsonl(file):
7 | ds = []
8 | try:
9 | with open(file) as f:
10 | for i, line in enumerate(f):
11 | if line.strip() == "":
12 | continue
13 | d = json.loads(line.strip())
14 | ds.append(d)
15 | except:
16 | import pdb
17 | pdb.set_trace()
18 | return ds
19 |
20 | def print_numpy_result(result, models):
21 | for i in range(len(result)):
22 | # print(models[i], end="\t")
23 | for j in range(result.shape[1]):
24 | print(f"{result[i, j]:.4f}", end="\t")
25 | print()
26 | print()
27 |
28 | out_dir = Path("/n/fs/nlp-mengzhou/space9/out/llm_eval/fewshot_eval_May15")
29 |
30 | shots = 32
31 | # task = sys.argv[1]
32 | metric = "accuracy"
33 | eval_type = "eval"
34 | # result = np.zeros((len(models), len(shots)))
35 |
36 | all_seed_accs = []
37 | for task in ["Subj", "AGNews", "SST2", "CR", "MR", "MPQA", "AmazonPolarity"]:
38 | # print(task)
39 | print_model_name = False
40 | for tmp in [0]:
41 | for icl in ["icl_sfc"]:
42 | seed_accs = []
43 | seed_models = []
44 | for seed in range(1, 4):
45 | count = 0
46 | model = f"opt-1.3b-ntrain{shots}-{task}-seed{seed}-tmp{tmp}-{icl}-eval"
47 | task_dir = out_dir / model
48 | file = task_dir / "metrics.jsonl"
49 | if os.path.exists(file):
50 | re = read_jsonl(file)
51 | acc = re[0]["accuracy"]
52 | count += 1
53 | seed_accs.append(acc * 100)
54 | else:
55 | print(file)
56 | all_seed_accs.append(np.array(seed_accs))
57 | acc = np.std(seed_accs)
58 | if print_model_name:
59 | print(" ".join(seed_models), end=" ")
60 | else:
61 | print(acc, end=" ")
62 | # print(seed_accs)
63 | print()
64 |
65 | all_seed_accs = np.stack(all_seed_accs, axis=1)
66 | print("sigma:", all_seed_accs.mean(1).std())
67 |
68 | # print(f"{task} {metric}")
69 | # print_numpy_result(result, models)
70 |
71 |
--------------------------------------------------------------------------------
/icl_eval/scripts/others/print_result_new.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import json
3 | import numpy as np
4 | import os
5 | from cus_data.dataloader import read_jsonl
6 |
7 |
8 | def print_numpy_result(result, models):
9 | for i in range(len(result)):
10 | # print(models[i], end="\t")
11 | for j in range(result.shape[1]):
12 | print(f"{result[i, j]:.4f}", end="\t")
13 | print()
14 | print()
15 |
16 | out_dir = Path("/scratch/gpfs/mengzhou/space9/out/llm_eval/dynamic_icl_eval_May8_test")
17 |
18 | # rows
19 | models = ["constructed-opt-backprop-6-nstep-2-lr-1e-04", "constructed-opt-backprop-6-nstep-2-lr-1e-05", "constructed-opt-backprop-6-nstep-2-lr-1e-06",
20 | "constructed-opt-backprop-12-nstep-1-lr-1e-04", "constructed-opt-backprop-12-nstep-1-lr-1e-05", "constructed-opt-backprop-12-nstep-1-lr-1e-06",
21 | "constructed-opt-backprop-3-nstep-4-lr-1e-04", "constructed-opt-backprop-3-nstep-4-lr-1e-05", "constructed-opt-backprop-3-nstep-4-lr-1e-06"]
22 |
23 | shots = 16
24 | task = "MR"
25 | metric = "accuracy"
26 | eval_type = "eval"
27 |
28 | # result = np.zeros((len(models), len(shots)))
29 |
30 | for single_FT in [True, False]:
31 | for label_type in ["context", "label_only"]:
32 | for icl in ["plain", "icl_sfc"]:
33 | seed_accs = []
34 | for seed in range(1, 3):
35 | acc = np.zeros((len(models), 2))
36 | for i, model in enumerate(models):
37 | model_dir = out_dir / model
38 | lr = "-".join(model.split("-")[-2:])
39 | layer = model.split("-")[3]
40 | task_dir = model_dir / f"ntrain{shots}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{layer}-{label_type}-{icl}-{eval_type}"
41 | file = task_dir / "metrics.jsonl"
42 | if os.path.exists(file):
43 | re = read_jsonl(file)
44 | acc[i, 0] = re[0]["accuracy"]
45 | # acc[i, 1] = re[1]["accuracy"]
46 | else:
47 | print(file)
48 | seed_acc = acc.max(axis=0)
49 | seed_accs.append(seed_acc)
50 | acc = np.stack(seed_accs).mean(axis=0)
51 | print(" ".join([str(a) for a in acc]), end = " ")
52 | # print(seed_accs)
53 | print()
54 |
55 |
56 |
57 | # print(f"{task} {metric}")
58 | # print_numpy_result(result, models)
59 |
60 |
--------------------------------------------------------------------------------
/icl_eval/README.md:
--------------------------------------------------------------------------------
1 | # In-context Learning Evaluation
2 |
3 | We demonstrate how to conduct in-context learning experiments with our TINT model or with other models for comparison. We support three modes: `standard`, `dynamic evaluation`, and `TINT`. `TINT` effectively approximates `dynamic evaluation` using only forward passes, with no backward pass. `standard` is simply the standard in-context learning setting.
4 |
5 | We use `scripts/run_eval.sh` as our main script; you can pass arguments to it to enable different types of evaluation.
6 |
7 | ```
8 | bash scripts/run_eval.sh [ARGUMENTS]
9 | ```
10 |
11 | We explain the arguments below; a hypothetical example invocation is given after the list.
12 | 
13 | - `model_path`: path to the model; it can be a Hugging Face model path, e.g., `facebook/opt-125m`, or a TINT model.
14 | - `model_name`: name of the model, freely chosen by the user; used for logging.
15 | - `output_dir`: output directory.
16 | - `train_num`: number of training examples for in-context learning, dynamic evaluation, or internal training for TINT. Supports both zero-shot (=0) and few-shot (>0) settings.
17 | - `task`: task name; currently supports `AmazonPolarity`, `YelpPolarity`, `AGNews`, `SST2`, `MR`, `CR`, `MPQA`, and `Subj`.
18 | - `lr`: learning rate for training. We grid-search over 1e-3, 1e-4, 1e-5, and 1e-6 in our experiments.
19 | - `train_set_seed`: seed for selecting training examples.
20 | - `single_FT`: whether to train on a concatenation of the in-context learning examples or on separate examples in each input. Please see the image below for an illustration.
21 | - `num_train_layers`: number of (top) layers to train. We experimented with 3, 6, and 12 training layers.
22 | - `label_type`: whether the loss is computed over all tokens in the context (`context`) or only over the label words (`label_only`).
23 | - `sfc_type`: whether or not to use the [surface form competition algorithm](https://arxiv.org/abs/2104.08315) for calibration.
24 |   - `plain`: no calibration method is used.
25 |   - `icl_sfc`: the sfc algorithm is used for calibration.
26 | - `eval_type`: `eval` or `test`; we use `eval` as the default setting.
27 | - `template`: we use the default template `0`; users can define templates in `templates.py` and add them to the specific task in `tasks.py`.
28 | - `test`: if `True`, the evaluation enters a test mode for debugging.
29 | - `dynamic_eval`: `True` for the `dynamic evaluation` and `TINT` modes, and `False` for the `standard` mode.
30 |
31 |
32 |
33 |
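34 | ![Single FT vs. multiple FT](images/single_FT_or_multiple_FT.png)
35 | 
36 | As an example, a hypothetical invocation (the values below are purely illustrative; the positional argument order follows `scripts/run_eval.sh`) would look like:
37 | 
38 | ```
39 | bash scripts/run_eval.sh facebook/opt-125m opt-125m out 32 SST2 1e-04 1 True 12 context icl_sfc eval 0 False True
40 | ```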
--------------------------------------------------------------------------------
/icl_eval/scripts/others/print_dynamic_eval_results.py:
--------------------------------------------------------------------------------
1 | model_name = "opt-125m"
2 | task = "MRPC"
3 | eval_type = "valid"
4 |
5 | # MRPC-train8-eval200-lr1e-03-labelcontext-dynamicFalse
6 |
7 | import json
8 | from pathlib import Path
9 | import numpy as np
10 | import os
11 | def read_jsonl(file):
12 | lines = open(file, "r").readlines()
13 | d = []
14 | for line in lines:
15 | d.append(json.loads(line.strip()))
16 | return d
17 |
18 | output_dir = Path(f"/n/fs/nlp-mengzhou/space9/out/llm_eval/dynamic_icl_eval_May3")
19 | # for sfc_type in ["plain", "icl", "icl_sfc"]:
20 | # for shot in [16]:
21 | # for single_FT in [True, False]:
22 | # for label_type in ["label_only", "context"]:
23 | # for num_train_layers in [3, 6, 12]:
24 | # accuracies_all_seeds = []
25 | # for seed in [1, 2, 3]:
26 | # accuracies = []
27 | # for lr in ["1e-05", "1e-04", "1e-03", "1e-02"]:
28 | # file_name=f"{model_name}-ntrain{shot}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{num_train_layers}-{label_type}-{sfc_type}-{eval_type}.json"
29 | # # file_name=f"{model_name}-ntrain{shot}-{task}-seed{seed}-{sfc_type}-{eval_type}.json"
30 | # file = output_dir / file_name
31 | # if not os.path.exists(file):
32 | # continue
33 | # d = read_jsonl(file)
34 | # accuracies.append(np.array([dd["accuracy"] for dd in d]))
35 | # accuracies = np.stack(accuracies).max(axis=0).tolist()
36 | # accuracies_all_seeds.append(accuracies)
37 | # accuracies_all_seeds = np.stack(accuracies_all_seeds)
38 | # accuracies = np.mean(accuracies_all_seeds, axis=0)
39 | # re = " ".join(str(acc) for acc in accuracies)
40 | # print(re)
41 |
42 | accuracies = []
43 | shot=16
44 | for sfc_type in ["plain", "icl", "icl_sfc"]:
45 | for seed in [1, 2, 3]:
46 | file = output_dir / f"{model_name}-ntrain{shot}-{task}-seed{seed}-{sfc_type}-{eval_type}.json"
47 | if not os.path.exists(file):
48 | continue
49 | d = read_jsonl(file)
50 | accuracies.append(np.array([dd["accuracy"] for dd in d]))
51 | accuracies = np.stack(accuracies)
52 | accuracies = np.mean(accuracies, axis=0).tolist()
53 | re = " ".join(str(acc) for acc in accuracies)
54 | print(re)
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
/icl_eval/scripts/TINT/print_paper_result_zero_shot.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import json
3 | import numpy as np
4 | import os
5 |
6 | def read_jsonl(file):
7 | ds = []
8 | try:
9 | with open(file) as f:
10 | for i, line in enumerate(f):
11 | d = json.loads(line.strip())
12 | ds.append(d)
13 | except:
14 | import pdb
15 | pdb.set_trace()
16 | return ds
17 |
18 | def print_numpy_result(result, models):
19 | for i in range(len(result)):
20 | # print(models[i], end="\t")
21 | for j in range(result.shape[1]):
22 | print(f"{result[i, j]:.4f}", end="\t")
23 | print()
24 | print()
25 |
26 | out_dir = Path("/scratch/gpfs/smalladi/mengzhou/out/llm_eval/dynamic_icl_eval_zero-shot")
27 |
28 | # rows
29 | def get_models(label_type):
30 | models = []
31 | for layer in [12, 6, 3]:
32 | steps = 12 // layer
33 | if label_type == "context": lrs = ["1e-05", "1e-06", "1e-07"]
34 | else: lrs = ["1e-03", "1e-04", "1e-05"]
35 | for lr in lrs:
36 | models.append(f"constructed-opt-backprop-{layer}-nstep-{steps}-lr-{lr}")
37 | return models
38 |
39 | shots = 0
40 | task = "AmazonPolarity"
41 | metric = "accuracy"
42 | eval_type = "eval"
43 | # result = np.zeros((len(models), len(shots)))
44 | icl = "plain"
45 | tmp = 0
46 | seed = 1
47 |
48 | print_model_name = False
49 | for label_type in ["context"]:
50 | for single_FT in [True]:
51 | seed_accs = []
52 | seed_models = []
53 | sub_models = get_models(label_type)
54 | acc = np.zeros((len(sub_models), 2))
55 | for i, model in enumerate(sub_models):
56 | model_dir = out_dir / model
57 | lr = "-".join(model.split("-")[-2:])
58 | layer = model.split("-")[3]
59 | task_dir = model_dir / f"ntrain{shots}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{layer}-{label_type}-tmp{tmp}-{icl}-{eval_type}"
60 | file = task_dir / "metrics.jsonl"
61 | if os.path.exists(file):
62 | re = read_jsonl(file)
63 | acc[i, 0] = re[0]["accuracy"]
64 | # acc[i, 1] = re[1]["accuracy"]
65 | else:
66 | pass
67 | # print(file)
68 | model_name = sub_models[acc[:, 0].argmax(axis=0)]
69 | acc = acc[:, 0].max(axis=0)
70 | if print_model_name:
71 | print(" ".join(seed_models))
72 | else:
73 | print(acc)
74 | # print(seed_accs)
75 |
76 |
77 |
78 | # print(f"{task} {metric}")
79 | # print_numpy_result(result, models)
80 |
81 |
--------------------------------------------------------------------------------
/icl_eval/scripts/dynamic_eval/print_result.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import json
3 | import numpy as np
4 | import os
5 |
6 | def read_jsonl(file):
7 | ds = []
8 | try:
9 | with open(file) as f:
10 | for i, line in enumerate(f):
11 | if line.strip() == "":
12 | continue
13 | d = json.loads(line.strip())
14 | ds.append(d)
15 | except:
16 | import pdb
17 | pdb.set_trace()
18 | return ds
19 |
20 | def read_json(file):
21 | return json.load(open(file))
22 |
23 | def print_numpy_result(result, models):
24 | for i in range(len(result)):
25 | # print(models[i], end="\t")
26 | for j in range(result.shape[1]):
27 | print(f"{result[i, j]:.4f}", end="\t")
28 | print()
29 | print()
30 |
31 | out_dir = Path("/n/fs/nlp-mengzhou/space9/out/llm_eval/dynamic_icl_eval_May10")
32 |
33 |
34 | shots = 32
35 | task = "Subj"
36 | metric = "accuracy"
37 | eval_type = "eval"
38 | # result = np.zeros((len(models), len(shots)))
39 |
40 | print_model_name = False
41 | all_accs = []
42 | for task in ["Subj", "AGNews", "SST2", "CR", "MR", "MPQA", "AmazonPolarity"]:
43 | print(task)
44 | for tmp in [0, 1]:
45 | for single_FT in [True, False]:
46 | for label_type in ["context", "label_only"]:
47 | for icl in ["plain"]:
48 | seed_accs = []
49 | seed_models = []
50 | for seed in range(1, 4):
51 | acc = np.zeros((9, 2))
52 | count = 0
53 | for lr in ["1e-03", "1e-04", "1e-05"]:
54 | for layer in [3, 6, 12]:
55 | model = f"opt-125m-ntrain{shots}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{layer}-{label_type}-tmp{tmp}-{icl}-{eval_type}"
56 | task_dir = out_dir / model
57 | file = task_dir / "metrics.jsonl"
58 | if os.path.exists(file):
59 | re = read_json(file)
60 | acc[count, 0] = re["accuracy"]
61 | else:
62 | print(file)
63 | count += 1
64 | model_name = f"lr{lr}-seed{seed}"
65 | seed_models.append(model_name)
66 | seed_acc = acc.max(axis=0)
67 | seed_accs.append(seed_acc * 100)
68 | acc = np.stack(seed_accs).std(axis=0)
69 | if print_model_name:
70 | print(" ".join(seed_models), end=" ")
71 | else:
72 | print(" ".join([str(a) for a in acc[0:1]]), end = " ")
73 | # print(seed_accs)
74 | print()
75 | all_accs.append(seed_accs)
76 |
77 | all_accs = np.stack(all_accs)
78 | print("std:", all_accs[:, :, 0].mean(0).std())
79 |
80 |
81 |
82 | # print(f"{task} {metric}")
83 | # print_numpy_result(result, models)
84 |
85 |
--------------------------------------------------------------------------------
/tint_main/utils/layernorm/forward.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import math
4 | import os
5 | from dataclasses import dataclass
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 | import torch.utils.checkpoint
10 | from torch import nn
11 | from torch.cuda.amp import autocast
12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13 | from ..modules import *
14 | from ..linear import *
15 |
16 |
17 |
18 |
19 | class LayerNormForward(nn.Module):
20 | def __init__(self, config, din, use_softmax, memory_index=-1):
21 | super(LayerNormForward, self).__init__()
22 | assert use_softmax==False ,\
23 | "Currently I only use linear attention in this module"
24 |
25 |
26 |
27 | self.linear=LinearForward ( config, din=din, dout=din, use_softmax=use_softmax, memory_index=memory_index )
28 | self.din=din
29 | self.epsilon = config.epsilon
30 | self.config=config
31 | self.memory_index = memory_index
32 |
33 | #w acts like a gate to decide what portion of the embedding we apply layernorm on
34 | self.w = torch.zeros (( 1, 1, config.hidden_size ))
35 | self.w [:, :, :din] += config.gate_scale
36 | self.gate = torch.nn.Tanh()
37 |
38 |
39 | #mask out normalization on prefixes
40 | self.normalization_gates = Gates (config)
41 | #Initialize Gates
42 | #Ignore the changes for the prefixes!
43 | #w, u, v, w_bias, u_bias, v_bias
44 | w = torch.zeros((1, 2*config.hidden_size))
45 | u = torch.zeros((1, 2*config.hidden_size))
46 | v = torch.zeros((1, 2*config.position_dim))
47 | w_bias = torch.zeros(2)
48 | u_bias = torch.zeros(2)
49 | v_bias = torch.zeros(2)
50 |
51 | #Input Gate is 1 on prefixes and 0 for non-prefixes
52 | v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.num_prefixes)
53 |
54 |
55 | #Change Gate is 0 on prefixes and 1 for non-prefixes
56 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.num_prefixes)
57 | v_bias [1] += config.gate_scale
58 |
59 | self.normalization_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias)
60 |
61 |
62 |
63 |
64 | def forward(self, hidden_states, position_embeddings):
65 |
66 |
67 | weights = self.gate ( self.w ).to(hidden_states.device)
68 | mean = torch.sum(hidden_states * weights, dim=-1, keepdim=True) / torch.sum(weights, dim=-1, keepdim=True)
69 |
70 | var = ( self.epsilon + torch.sum( (weights * (hidden_states - mean)) ** 2, dim=-1, keepdim=True) / torch.sum(weights, dim=-1, keepdim=True) ) ** 0.5
71 |
72 | normalized_states = (hidden_states - mean) / var
73 | normalized_states = weights * normalized_states + (1. - weights) * hidden_states
74 |
75 | gated_output = self.normalization_gates.forward (hidden_states, normalized_states, position_embeddings)
76 |
77 | output = self.linear.forward ( gated_output, position_embeddings )
78 |
79 |
80 | #store [(x-\mu)/\sigma, x] for memory in backward pass
81 | assert torch.sum( torch.absolute( output[:, self.config.num_prefixes:, self.memory_index+self.din:]) ).item() < 1e-10,\
82 | "Memory portion not empty!"
83 | output[:, self.config.num_prefixes:, self.memory_index+self.din: self.memory_index+2*self.din] += hidden_states[:, self.config.num_prefixes:, :self.din]
84 |
85 | return output
86 |
87 |
88 |
--------------------------------------------------------------------------------
/icl_eval/scripts/TINT/print_paper_result.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import json
3 | import numpy as np
4 | import os
5 |
6 | def read_jsonl(file):
7 | ds = []
8 | try:
9 | with open(file) as f:
10 | for i, line in enumerate(f):
11 | d = json.loads(line.strip())
12 | ds.append(d)
13 | except:
14 | import pdb
15 | pdb.set_trace()
16 | return ds
17 |
18 | def print_numpy_result(result, models):
19 | for i in range(len(result)):
20 | # print(models[i], end="\t")
21 | for j in range(result.shape[1]):
22 | print(f"{result[i, j]:.4f}", end="\t")
23 | print()
24 | print()
25 |
26 | out_dir = Path("/scratch/gpfs/smalladi/mengzhou/out/llm_eval/dynamic_icl_eval_May10")
27 |
28 | # rows
29 | def get_models(label_type, layer, steps):
30 | models = []
31 | if label_type == "context": lrs = ["1e-05", "1e-06", "1e-07"]
32 | else: lrs = ["1e-03", "1e-04", "1e-05"]
33 | for lr in lrs:
34 | models.append(f"constructed-opt-backprop-{layer}-nstep-{steps}-lr-{lr}")
35 | return models
36 |
37 | shots = 32
38 | metric = "accuracy"
39 | eval_type = "eval"
40 | # result = np.zeros((len(models), len(shots)))
41 | icl = "icl_sfc"
42 | tmp = 0
43 |
44 | print_model_name = False
45 |
46 | all_res = []
47 | print_res = []
48 | for task in ["Subj", "AGNews", "SST2", "CR", "MR", "MPQA", "AmazonPolarity"]:
49 | print_re = []
50 | task_re = []
51 | for label_type in ["context", "label_only"]:
52 | for single_FT in [True, False]:
53 | for layer in [12, 6, 3]:
54 | step = 12 // layer
55 | sub_models = get_models(label_type, layer, step)
56 | seed_accs = []
57 | seed_models = []
58 | for seed in range(1, 4):
59 | acc = np.zeros((len(sub_models), 2))
60 | for i, model in enumerate(sub_models):
61 | model_dir = out_dir / model
62 | lr = "-".join(model.split("-")[-2:])
63 | layer = model.split("-")[3]
64 | task_dir = model_dir / f"ntrain{shots}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{layer}-{label_type}-tmp{tmp}-{icl}-{eval_type}"
65 | file = task_dir / "metrics.jsonl"
66 | if os.path.exists(file):
67 | re = read_jsonl(file)
68 | acc[i, 0] = re[0]["accuracy"]
69 | # acc[i, 1] = re[1]["accuracy"]
70 | else:
71 | pass
72 | # print(file)
73 | model_name = sub_models[acc[:, 0].argmax(axis=0)]
74 | seed_models.append(model_name)
75 | seed_acc = acc.max(axis=0)
76 | seed_accs.append(seed_acc * 100)
77 | task_re.append(seed_acc * 100)
78 | acc = np.stack(seed_accs).std(axis=0)
79 | if print_model_name:
80 | print(" ".join(seed_models))
81 | else:
82 | print(" ".join([str(a) for a in acc[0:1]]))
83 | print_res.append(print_re)
84 | # print(seed_accs)
85 | all_res.append(task_re)
86 |
87 | print_res = np.array(print_res).transpose()
88 | for line in print_res:
89 | print(' '.join(map(str, line)))
90 |
91 | # print(f"{task} {metric}")
92 | # print_numpy_result(result, models)
93 |
94 | import pdb; pdb.set_trace()
95 | all_res = np.array(all_res)
96 | ree = all_res.reshape(7, 12, 3, -1)[:, :, :, 0]
97 |
98 | for n in ree.mean(0).std(-1): print(n)
--------------------------------------------------------------------------------
/tint_main/utils/activation/backward.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import math
4 | import os
5 | from dataclasses import dataclass
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 | import torch.utils.checkpoint
10 | from torch import nn
11 | from torch.cuda.amp import autocast
12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13 | from ..modules import *
14 | from ..linear import *
15 | from ..activations import *
16 |
17 |
18 |
19 | class ActivationBackward (nn.Module):
20 | def __init__ (self, config, din, input_projection=None, projection_matrix=None, memory_index=-1, retain_og_act=False):
21 | super(ActivationBackward, self).__init__()
22 |
23 |
24 |
25 | assert memory_index == -1 or memory_index >= din, \
26 | "memory crosses current signal"
27 |
28 |
29 |
30 | self.epsilon = config.epsilon
31 | self.memory_index = memory_index
32 | self.config = config
33 |
34 | head_dim = config.hidden_size // config.num_attention_heads
35 | self.c_fc = Conv2D(config.num_attention_heads, head_dim, transpose=True, use_einsum=self.config.use_einsum)
36 | self.proj_fc = Conv2D(config.num_attention_heads, head_dim, transpose=True, use_einsum=self.config.use_einsum)
37 |
38 | self.config = config
39 | self.din = din
40 | self.act = ACT2FN[config.activation_function]
41 |
42 |
43 | c_fc_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
44 | c_proj_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
45 |
46 |
47 |
48 | assert din % head_dim == 0, \
49 | " 'din' should be a multiple of head_dim! "
50 |
51 | num_partitions = din // head_dim
52 |
53 |
54 |
55 | assert self.memory_index % head_dim == 0, \
56 | "Memory should start at a multiple of head_dim!"
57 |
58 | mem_head_start = self.memory_index // head_dim
59 |
60 |
61 | start_shift = 0
62 | c_fc_init[:, start_shift: start_shift + num_partitions, start_shift: start_shift + num_partitions] = 1. / config.scale_embeddings * torch.eye(num_partitions)
63 | c_fc_init[:, start_shift: start_shift + num_partitions, mem_head_start: mem_head_start + num_partitions] = torch.eye(num_partitions)
64 |
65 | #pass x as well
66 | c_fc_init[:, mem_head_start: mem_head_start + num_partitions, mem_head_start: mem_head_start + num_partitions] = torch.eye(num_partitions)
67 |
68 |
69 |         #Compute N * ( GeLU(x + 1/N \nabla y) - GeLU(x) ) \approx GeLU'(x) \nabla y, with N = scale_embeddings
70 |
71 | c_proj_init[:, start_shift: start_shift + num_partitions, start_shift: start_shift + num_partitions] = config.scale_embeddings * torch.eye(num_partitions)
72 | c_proj_init[:, start_shift: start_shift + num_partitions, mem_head_start: mem_head_start + num_partitions] = -config.scale_embeddings * torch.eye(num_partitions)
73 |
74 | #retain Act (x) for future purposes?
75 | if retain_og_act:
76 | c_proj_init[:, mem_head_start: mem_head_start + num_partitions, mem_head_start: mem_head_start + num_partitions] = torch.eye(num_partitions)
77 |
78 |
79 | with torch.no_grad():
80 | self.c_fc.weight.copy_(torch.swapaxes(c_fc_init, axis0=-1, axis1=-2))
81 | self.proj_fc.weight.copy_(torch.swapaxes(c_proj_init, axis0=-1, axis1=-2))
82 |
83 |
84 |
85 | def forward(self, hidden_states, position_embeddings, attention_mask=None, activation_memory=None, icl_mask=None):
86 | output = self.proj_fc ( self.act( self.c_fc(hidden_states) ) )
87 | return output
88 |
89 |
--------------------------------------------------------------------------------
/icl_eval/scripts/run_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #SBATCH --job-name=icl
3 | #SBATCH --time=1:00:00
4 | #SBATCH --gres=gpu:1
5 | #SBATCH --mem=30gb
6 |
7 | # $model_path $model_name $output_dir $ntrain $task $lr $dynamic_eval $seed $sf $layer $sfc_type $eval_type $label_type
8 |
9 | script=$base_dir/llm_eval_construction/run.py
10 | model_path=$1
11 | model_name=$2
12 | output_dir=$3
13 | ntrain=$4
14 | task=$5
15 | lr=$6
16 | train_set_seed=$7
17 | single_FT=$8
18 | num_train_layers=${9}
19 | label_type=${10}
20 | sfc_type=${11}
21 | eval_type=${12}
22 | template=${13}
23 | test=${14}
24 | num_epochs=$(( 12 / $num_train_layers ))
25 | dynamic_eval=${15}
26 |
27 | # change the path of the constructed models
28 | model_path=/scratch/gpfs/ap34/icl-as-ft/Dynamic_initialization/Constructed_models_withlnupdate_fastlin/model_opt_backprop_${num_train_layers}_nstep_${num_epochs}_lr_${lr}
29 | model_name=constructed-opt-backprop-${num_train_layers}-nstep-${num_epochs}-lr-${lr}
30 |
31 | restrict_attention_demonstration=False # whether to restrict the attention of each demonstration to itself in icl experiments
32 | position_modify_demonstration=False # whether to change the position ids of each demonstration to top; only relevant if restrict_attention_demonstration is true
33 |
34 | if [[ $single_FT == True ]]; then
35 | restrict_attention_demonstration=True
36 | position_modify_demonstration=True
37 | else
38 | restrict_attention_demonstration=False
39 | position_modify_demonstration=False
40 | fi
41 |
42 | if [[ $label_type == "label_only" ]]; then loss_label_only=True;
43 | else loss_label_only=False; fi
44 |
45 | sub_dir=${model_name}/ntrain${ntrain}-${task}-lr${lr}-seed${train_set_seed}-single_FT${single_FT}-layers${num_train_layers}-${label_type}-tmp${template}-${sfc_type}-${eval_type}
46 | output_dir=$output_dir/$sub_dir
47 | mkdir -p $output_dir
48 |
49 | echo "********** inside single script **********"
50 | echo model_path=$model_path
51 | echo model_name=$model_name
52 | echo output_dir=$output_dir
53 | echo ntrain=$ntrain
54 | echo task=$task
55 | echo train_set_seed=$train_set_seed
56 | echo single_FT=$single_FT
57 | echo label_type=$label_type
58 | echo sfc_type=${sfc_type}
59 | echo eval_type=${eval_type}
60 | echo restrict_attention_demonstration=${restrict_attention_demonstration}
61 | echo position_modify_demonstration=${position_modify_demonstration}
62 | echo template=${template}
63 | echo "Outputting to $output_dir/${file_name}"
64 | echo "********** inside single script **********"
65 |
66 |
67 | if [[ -f $output_dir/metrics.jsonl ]]; then
68 | echo "File exists: $output_dir/metrics.jsonl"
69 | exit 0
70 | fi
71 |
72 | if [[ $test == True ]]; then num_eval=4; else num_eval=200; fi
73 |
74 | cmd="python3 $script \
75 | --task_name $task \
76 |     --num_train ${ntrain} \
77 | --num_eval ${num_eval} \
78 | --model_path $model_path \
79 | --model_name $model_name \
80 | --load_float16 True \
81 | --pruned False \
82 | --output_dir=$output_dir \
83 | --train_set_seed $train_set_seed \
84 | --loss_type $label_type \
85 | --result_file $output_dir/metrics.jsonl \
86 | --exclusive_training True \
87 | --single_FT ${single_FT} \
88 | --num_train_layers ${num_train_layers} \
89 | --num_epochs ${num_epochs} \
90 | --eval_set_seed 0 \
91 | --restrict_attention_demonstration ${restrict_attention_demonstration} \
92 | --position_modify_demonstration ${position_modify_demonstration} \
93 | --loss_label_only ${loss_label_only} \
94 | --test_mode False \
95 | --template_id $template"
96 |
97 | if [[ $eval_type == "test" ]]; then
98 | cmd="$cmd --test_set_seed 10"
99 | fi
100 | if [[ $sfc_type == "sfc" ]]; then
101 | cmd="$cmd --sfc True"
102 | fi
103 | if [[ $sfc_type == "icl_sfc" ]]; then
104 | cmd="$cmd --icl_sfc True"
105 | fi
106 | $cmd 2>&1 | tee $output_dir/log.txt
107 |
--------------------------------------------------------------------------------
/icl_eval/scripts/others/print_result_May10.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import json
3 | import numpy as np
4 | import os
5 |
6 | def read_jsonl(file):
7 | ds = []
8 | try:
9 | with open(file) as f:
10 | for i, line in enumerate(f):
11 | d = json.loads(line.strip())
12 | ds.append(d)
13 | except:
14 | import pdb
15 | pdb.set_trace()
16 | return ds
17 |
18 | def print_numpy_result(result, models):
19 | for i in range(len(result)):
20 | # print(models[i], end="\t")
21 | for j in range(result.shape[1]):
22 | print(f"{result[i, j]:.4f}", end="\t")
23 | print()
24 | print()
25 |
26 | out_dir = Path("/scratch/gpfs/smalladi/mengzhou/out/llm_eval/dynamic_icl_eval_May10")
27 |
28 | # rows
29 | def get_models(label_type):
30 | models = []
31 | for layer in [12, 6, 3]:
32 | steps = 12 // layer
33 | if label_type == "context": lrs = ["1e-05", "1e-06", "1e-07"]
34 | else: lrs = ["1e-03", "1e-04", "1e-05"]
35 | for lr in lrs:
36 | models.append(f"constructed-opt-backprop-{layer}-nstep-{steps}-lr-{lr}")
37 | return models
38 |
39 | shots = 32
40 | task = "AmazonPolarity"
41 | metric = "accuracy"
42 | eval_type = "eval"
43 | # result = np.zeros((len(models), len(shots)))
44 |
45 | print_model_name = False
46 |
47 | print_res = []
48 | all_accs = []
49 | for task in ["Subj", "AGNews", "SST2", "CR", "MR", "MPQA", "AmazonPolarity"]:
50 | task_acc = []
51 | print(task)
52 | print_re = []
53 | for tmp in [0]:
54 | for single_FT in [True, False]:
55 | for label_type in ["context", "label_only"]:
56 | sub_models = get_models(label_type)
57 | for icl in ["icl_sfc"]:
58 | seed_accs = []
59 | seed_models = []
60 | all_seed_accs = []
61 | for seed in range(1, 4):
62 | acc = np.zeros((len(sub_models), 2))
63 | for i, model in enumerate(sub_models):
64 | model_dir = out_dir / model
65 | lr = "-".join(model.split("-")[-2:])
66 | layer = model.split("-")[3]
67 | task_dir = model_dir / f"ntrain{shots}-{task}-lr{lr}-seed{seed}-single_FT{single_FT}-layers{layer}-{label_type}-tmp{tmp}-{icl}-{eval_type}"
68 | file = task_dir / "metrics.jsonl"
69 | if os.path.exists(file):
70 | re = read_jsonl(file)
71 | acc[i, 0] = re[0]["accuracy"]
72 | # acc[i, 1] = re[1]["accuracy"]
73 | else:
74 | pass
75 | # print(file)
76 | model_name = sub_models[acc[:, 0].argmax(axis=0)]
77 | seed_models.append(model_name)
78 | seed_acc = acc.max(axis=0)
79 | seed_accs.append(seed_acc * 100)
80 | all_seed_accs.append(seed_acc[0])
81 | task_acc.append(all_seed_accs)
82 | acc = np.stack(seed_accs).std(axis=0)
83 | print_re.append(acc[0])
84 | if print_model_name:
85 | print(" ".join(seed_models), end=" ")
86 | else:
87 | print(" ".join([str(a) for a in acc[0:1]]), end = " ")
88 | # print(seed_accs)
89 | print()
90 | print_res.append(print_re)
91 | all_accs.append(task_acc)
92 | # print(seed_accs)
93 |
94 | print_res = np.array(print_res).transpose()
95 | for line in print_res:
96 | print(' '.join(map(str, line)))
97 |
98 | all_accs = np.array(all_accs) # num_tasks * num_settings * num_seeds
99 | import pdb; pdb.set_trace()
100 | print((np.array(all_accs) * 100).max(1).mean(0).std())
101 |
102 |
103 | # print(f"{task} {metric}")
104 | # print_numpy_result(result, models)
105 |
106 |
--------------------------------------------------------------------------------
/tint_main/utils/config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Construct_config:
5 | def __init__(self, model_config, construction_args):
6 | #Sequence length of auxiliary model
7 | self.seq_length = construction_args.seq_length
8 | #Sequence length of auxiliary model + Number of prefix tokens
9 | self.position_dim = construction_args.seq_length + construction_args.num_prefixes
10 | #Number of prefix tokens
11 | self.num_prefixes = construction_args.num_prefixes
12 | #Number of attention heads in TinT
13 | self.num_attention_heads = construction_args.num_attention_heads
14 | #Scale_embeddings (for softmax attention, unnecessary argument for now)
15 | self.scale_embeddings = construction_args.scale_embeddings
16 | #inner_lr for dynamic evaluation
17 | self.inner_lr = construction_args.inner_lr
18 | #Scaling for sigmoid gates to behave as hard gates
19 | self.gate_scale = construction_args.gate_scale
20 | #Embedding dimension of TinT
21 | self.hidden_size = construction_args.hidden_size
22 | #A max bound on the final sequence length, for initialization purposes
23 | self.max_position_embeddings = construction_args.max_position_embeddings
24 |
25 | #Following three arguments will be useful only when we pre-train
26 | #Dropout rate of embedding
27 | self.embd_pdrop=construction_args.embd_pdrop
28 | #Dropout rate of attention
29 | self.attn_pdrop=construction_args.attn_pdrop
30 | #Dropout rate of residual connection
31 | self.resid_pdrop=construction_args.resid_pdrop
32 |
33 | #Auxiliary's activation function
34 | self.activation_function=construction_args.activation_function
35 | #Auxiliary's error term for layernorm
36 | self.epsilon=construction_args.epsilon
37 | #Whether Attention score scales before softmax, as determined by Auxiliary model
38 | self.scale_attn_weights=construction_args.scale_attn_weights
39 | #Attention score appropriate scaling, as determined by Auxiliary model
40 | self.initial_scale=np.sqrt( (construction_args.hidden_size / model_config.hidden_size) * (model_config.num_attention_heads / construction_args.num_attention_heads) )
41 |
42 | #Number of layers to involve in SGD
43 | self.n_simulation_layers=construction_args.n_simulation_layers
44 | #Number of SGD steps
45 | self.n_forward_backward=construction_args.n_forward_backward
46 | #Unnecessary argument, was used for debugging
47 | self.n_debug_layers=construction_args.n_debug_layers
48 | #Unnecessary argument, was used for projection
49 | self.projection_paths=construction_args.projection_paths
50 | #We never backprop through attention, hence unnecessary argument for now
51 | self.backprop_through_attention=construction_args.backprop_through_attention
52 | #Whether to restrict attention between prefix and non-prefix tokens for linear operations
53 | self.restrict_prefixes=construction_args.restrict_prefixes
54 | #Whether to use einsum, which speeds up inference
55 | self.use_einsum=construction_args.use_einsum
56 | #Whether to use classification loss with softmax, for computing gradients
57 | self.use_prediction_loss=construction_args.use_prediction_loss
58 | #Whether to use quad loss from Saunshi et al., for computing gradients
59 | self.use_quad=construction_args.use_quad
60 | #self.n_gpus = construction_args.n_gpus
61 | #For multiple gpus, we can further partition the model across multiple gpus
62 | self.n_layers_pergpu = construction_args.n_layers_pergpu
63 | #'cuda'/'cpu'
64 | self.device = construction_args.device
65 |
66 | #Whether to reuse forward blocks, when we do multiple forward passes
67 | self.reuse_forward_blocks = construction_args.reuse_forward_blocks
68 | #Whether to reuse backward blocks, when we do multiple forward passes
69 | self.reuse_backward_blocks = construction_args.reuse_backward_blocks
70 |
71 | #Whether to only update biases in layernorm.
72 | self.ln_update_bias_only = construction_args.ln_update_bias_only
73 |
--------------------------------------------------------------------------------
/tests/perplexity_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import os
4 |
5 | from tint_main.tint_creator import *
6 | from transformers import AutoTokenizer
7 | from datasets import load_dataset
8 | from transformers import HfArgumentParser
9 | from filelock import Timeout, FileLock
10 | from .data_utils.data_utils import *
11 | from tint_main.utils.all_arguments import *
12 |
13 | parser = HfArgumentParser((ModelArguments, ConstructionArguments, DynamicArguments,))
14 | model_args, construction_args, data_args, = parser.parse_args_into_dataclasses()
15 |
16 |
17 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
18 | if 'gpt' in model_args.model_name_or_path:
19 | tokenizer.pad_token = tokenizer.eos_token
20 | pad_token_id=tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
21 | elif 'opt' in model_args.model_name_or_path:
22 | tokenizer.bos_token_id = 0
23 |
24 | dataset = preprocess_dataset(data_args, tokenizer)
25 | batch_size=data_args.batch_size
26 | constructed_model = load_construction(model_args.model_name_or_path, model_args.cache_dir, model_args.construct_model_path)
27 |
28 |
29 |
30 | constructed_model.eval()
31 | if construction_args.device == 'cuda': device='cuda:0'
32 | else: device='cpu'
33 |
34 | num_valid_batches = len(dataset) // batch_size
35 | train_fraction = data_args.train_fraction
36 |
37 | if data_args.data_subset == 0:
38 | exit(0)
39 |
40 | if data_args.data_subset != -1:
41 | num_valid_batches = min(data_args.data_subset // batch_size, num_valid_batches)
42 |
43 |
44 | avg_model_test_perplexity = 0.
45 | avg_eval_test_perplexity = 0.
46 | total_test_words = 0.
47 |
48 |
49 |
50 | for batch_id in tqdm( range( num_valid_batches ), desc='Inference' ):
51 |
52 |
53 |
54 | data = dataset [ batch_id * batch_size : (batch_id + 1) * batch_size ]
55 | batch_sentences = torch.tensor( data ['input_ids'] )
56 | attention_mask = torch.tensor( data ['attention_mask'] )
57 | labels = torch.tensor( data ['labels'] )
58 |
59 | if len(batch_sentences.shape) == 1:
60 | batch_sentences = batch_sentences.view((1, -1))
61 | attention_mask = attention_mask.view((1, -1))
62 | labels = labels.view((1, -1))
63 |
64 |
65 |
66 | mask = torch.zeros_like(attention_mask)
67 | target = batch_sentences.detach().clone()
68 | target [ torch.where(attention_mask == 0.) ] = -100
69 |
70 | batch_seq_lengths = torch.sum(attention_mask, dim=-1)
71 | for i in range(len(batch_seq_lengths)):
72 | mask[i, :int(batch_seq_lengths[i] * train_fraction)] = 1.
73 | target[ i, :int(batch_seq_lengths[i] * train_fraction) ] = -100
74 |
75 | bidirection_mask = mask.float()
76 | gradient_mask = None
77 |
78 | with torch.no_grad():
79 | results = constructed_model.forward(batch_sentences.to(device), bidirection_mask.to(device), gradient_mask=gradient_mask, test_backward_pass=True, continue_from_first_forward_pass=False, labels=target.to(device))
80 | original_loss, final_loss = results.original_loss, results.final_loss
81 |
82 |
83 | total_terms = torch.ne(target, -100).float().sum()
84 |
85 | avg_model_test_perplexity += original_loss.item() * total_terms
86 | avg_eval_test_perplexity += final_loss.item() * total_terms
87 | total_test_words += total_terms
88 |
89 |
90 | final_result = {}
91 | final_result[ 'Validation Dynamic eval perplexity (on test)' ] = np.exp(avg_eval_test_perplexity / total_test_words)
92 | final_result[ 'Validation Model perplexity (on test)' ] = np.exp(avg_model_test_perplexity / total_test_words)
93 |
94 |
95 |
96 | with FileLock('log_exp_construct.lock'):
97 | with open('log_exp_construct', 'a') as f:
98 | final_result.update(vars(model_args))
99 | final_result.update(vars(data_args))
100 | final_result.update(vars(construction_args))
101 | f.write(str(final_result) + '\n')
102 |
103 | import torch.utils.benchmark as benchmark
104 | def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
105 | amp_dtype=torch.float16, **kwinputs):
106 | """ Use Pytorch Benchmark on the forward pass of an arbitrary function. """
107 | if verbose:
108 | print(desc, '- Forward pass')
109 | def fn_amp(*inputs, **kwinputs):
110 | with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
111 | fn(*inputs, **kwinputs)
112 | for _ in range(repeats): # warmup
113 | fn_amp(*inputs, **kwinputs)
114 | t = benchmark.Timer(
115 | stmt='fn_amp(*inputs, **kwinputs)',
116 | globals={'fn_amp': fn_amp, 'inputs': inputs, 'kwinputs': kwinputs},
117 | num_threads=torch.get_num_threads(),
118 | )
119 | m = t.timeit(repeats)
120 | if verbose:
121 | print(m)
122 | return t, m
--------------------------------------------------------------------------------
/tests/data_utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import copy
5 | from filelock import FileLock
6 | from itertools import chain
7 |
8 | from accelerate import Accelerator
9 | from datasets import load_dataset, concatenate_datasets
10 | logger = logging.getLogger(__name__)
11 |
12 | def get_raw_data(args):
13 | if args.dataset == 'wikitext-103':
14 | dataset = load_dataset('wikitext', 'wikitext-103-v1', cache_dir=args.data_cache_dir)
15 | if args.use_eval_set:
16 | dataset = dataset['validation']
17 | elif args.use_test_set:
18 | dataset = dataset['test']
19 | else:
20 | dataset = dataset['train']
21 | return dataset
22 | elif args.dataset == 'wikitext-2':
23 | print (args.data_cache_dir)
24 | dataset = load_dataset('wikitext', 'wikitext-2-v1', cache_dir=args.data_cache_dir)
25 | if args.use_eval_set:
26 | dataset = dataset['validation']
27 | elif args.use_test_set:
28 | dataset = dataset['test']
29 | else:
30 | dataset = dataset['train']
31 | return dataset
32 | elif args.dataset == 'c4':
33 | print (args.data_cache_dir)
34 | dataset = load_dataset('c4', 'realnewslike', cache_dir=args.data_cache_dir)
35 | if args.use_eval_set:
36 | dataset = dataset['validation']
37 | elif args.use_test_set:
38 | dataset = dataset['validation']
39 | else:
40 | dataset = dataset['train']
41 | return dataset
42 | raise NotImplementedError
43 |
44 | def preprocess_dataset(args, tokenizer):
45 | raw_datasets = get_raw_data(args)
46 |
47 | text_column_name = "text"
48 | column_names = raw_datasets.column_names
49 |
50 |
51 | accelerator = Accelerator()
52 |
53 | def tokenize_function(examples):
54 | return tokenizer([i for i in examples[text_column_name] if len(i)>0])
55 |
56 | with accelerator.main_process_first():
57 | datasets = raw_datasets.map(
58 | tokenize_function,
59 | batched=True,
60 | num_proc=args.num_workers,
61 | remove_columns=column_names,
62 | load_from_cache_file=True,
63 | desc="Running tokenizer on dataset",
64 | )
65 |
66 | block_size = args.block_size
67 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
68 | def group_texts(examples):
69 | # Concatenate all texts.
70 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
71 | total_length = len(concatenated_examples[list(examples.keys())[0]])
72 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
73 | # customize this part to your needs.
74 | if total_length >= block_size:
75 | total_length = (total_length // block_size) * block_size
76 | # Split by chunks of max_len.
77 | result = {
78 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
79 | for k, t in concatenated_examples.items()
80 | }
81 | result["labels"] = result["input_ids"].copy()
82 | return result
83 |
84 | def concat_texts(examples):
85 | # Concatenate all texts.
86 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
87 | concatenated_examples['labels'] = concatenated_examples['input_ids'].copy()
88 | return concatenated_examples
89 |
90 | if args.chunk_data:
91 | with accelerator.main_process_first():
92 | datasets = datasets.map(
93 | group_texts,
94 | batched=True,
95 | load_from_cache_file=True,
96 | num_proc=args.num_workers,
97 | desc=f"Grouping texts in chunks of {block_size}",
98 | )
99 | else:
100 | with accelerator.main_process_first():
101 | datasets = datasets.map(
102 | concat_texts,
103 | batched=True,
104 | load_from_cache_file=True,
105 | num_proc=args.num_workers,
106 | desc=f"Concatenating the texts",
107 | )
108 |
109 | return datasets
110 |
111 |
112 | def prepare_inputs(batch, padding=None):
113 | for k in batch:
114 | if len(batch[k].shape) == 1:
115 | batch[k] = batch[k].reshape(1, -1)
116 | if padding is not None:
117 | for k in batch:
118 | batch[k] = torch.cat((batch[k], padding))
119 | batch[k] = batch[k].cuda()
120 |
121 | return batch
122 |
123 | def init_model_and_optimizer(model_init, args):
124 | model = copy.deepcopy(model_init)
125 | if args.freeze_embs:
126 | if 'opt' in args.model_name:
127 |             model.model.decoder.embed_tokens.weight.requires_grad = False
128 |             model.model.decoder.embed_positions.weight.requires_grad = False
129 |         else:
130 |             model.transformer.wte.weight.requires_grad = False
131 |             model.transformer.wpe.weight.requires_grad = False
132 | optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
133 | return model, optimizer
--------------------------------------------------------------------------------
/tint_main/tint_creator.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | from .tint_gpt import *
5 | from .tint_opt import *
6 | from transformers import GPT2LMHeadModel, OPTForCausalLM
7 | from transformers import HfArgumentParser
8 | from transformers import AutoConfig
9 | from .utils.config import Construct_config
10 |
11 |
12 | #Creates the tree view of the model layers
13 | def nested_children(m: torch.nn.Module):
14 | children = dict(m.named_children())
15 | output = {}
16 | if children == {}:
17 | # if module has no children; m is last child! :O
18 | return m
19 | else:
20 | # look for children from children... to the last child!
21 | for name, child in children.items():
22 | try:
23 | output[name] = nested_children(child)
24 | except TypeError:
25 | output[name] = nested_children(child)
26 | return output
27 |
28 |
29 | #Creates TinT on a given auxiliary models
30 | #model_args contains the necessary config for auxiliary model
31 | #construction_args contains the necessary config for TinT
32 | #The module calls TinT_gpt for gpt models and TinT_opt for opt models
33 | def TinT_Creator(model_args, construction_args):
34 | model_config = AutoConfig.from_pretrained(
35 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
36 | cache_dir=model_args.cache_dir
37 | )
38 |
39 |
40 | construction_args.activation_function = model_config.activation_function
41 | construction_args.seq_length = model_config.max_position_embeddings
42 |
43 | config = Construct_config(model_config, construction_args)
44 |
45 | if 'gpt' in model_args.model_name_or_path:
46 | model_fn = GPT2LMHeadModel
47 | elif 'opt' in model_args.model_name_or_path:
48 | model_fn = OPTForCausalLM
49 | else:
50 |         raise NotImplementedError
51 |
52 | model = model_fn.from_pretrained(
53 | model_args.model_name_or_path,
54 | config=model_config,
55 | cache_dir=model_args.cache_dir,
56 | )
57 |
58 |
59 | print ("....Constructing the model....")
60 |
61 | if model_args.construct_load_model:
62 | print ("....Load constructed model....")
63 | checkpoint = torch.load(model_args.construct_model_path)
64 | model_config = checkpoint['model_config']
65 | config = checkpoint['construction_config']
66 |
67 | print (nested_children(model))
68 | if 'gpt' in model_args.model_name_or_path:
69 | constructed_model = TinT_gpt(config, model_config, nested_children(model))
70 | elif 'opt' in model_args.model_name_or_path:
71 | constructed_model = TinT_opt(config, model_config, nested_children(model))
72 | else:
73 |         raise NotImplementedError
74 |
75 | if model_args.construct_load_model:
76 | constructed_model.load_state_dict(checkpoint['model_state_dict'])
77 |
78 |
79 | if model_args.construct_save_model:
80 | print ("....Store constructed model....")
81 | torch.save({'model_state_dict': constructed_model.state_dict(),\
82 | 'model_config': model_config,\
83 | 'construction_config': config},\
84 | model_args.construct_model_path,\
85 | )
86 |
87 |
88 | tot = 0
89 | for parameter in constructed_model.parameters():
90 | tot += parameter.numel()
91 | print ("Total trainable parameters in constructed model", tot)
92 |
93 | return (constructed_model, model, model_config, config)
94 |
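#Illustrative usage sketch (the actual entry points are the scripts in tests/, e.g. tests.create_model):
#    parser = HfArgumentParser((ModelArguments, ConstructionArguments))
#    model_args, construction_args = parser.parse_args_into_dataclasses()
#    tint_model, aux_model, aux_config, tint_config = TinT_Creator(model_args, construction_args)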
95 |
96 |
97 | #Creates TinT from a config checkpoint
98 | #model_path: auxiliary model name gpt2 or facebook/opt-125m
99 | #cache_dir is where the auxiliary model has been saved from huggingface
100 | #construction_args contains the necessary config for TinT
101 | #The module calls TinT_gpt for gpt models and TinT_opt for opt models
102 | def load_construction(model_path, \
103 | cache_dir, \
104 | construction_path, \
105 | ):
106 |
107 | model_config = AutoConfig.from_pretrained(
108 | model_path,
109 | cache_dir=cache_dir
110 | )
111 |
112 |
113 | if 'gpt' in model_path:
114 | model_fn = GPT2LMHeadModel
115 | elif 'opt' in model_path:
116 | model_fn = OPTForCausalLM
117 | else:
118 |         raise NotImplementedError
119 |
120 | model = model_fn.from_pretrained(
121 | model_path,
122 | config=model_config,
123 | cache_dir=cache_dir,
124 | )
125 |
126 | import time
127 | start = time.time()
128 | print ("....Load constructed model....")
129 | checkpoint = torch.load(construction_path)
130 | model_config = checkpoint['model_config']
131 | config = checkpoint['construction_config']
132 |
133 | if 'gpt' in model_path:
134 | constructed_model = TinT_gpt(config, model_config, nested_children(model))
135 | elif 'opt' in model_path:
136 | constructed_model = TinT_opt(config, model_config, nested_children(model))
137 | else:
138 |         raise NotImplementedError
139 |
140 | constructed_model.load_state_dict(checkpoint['model_state_dict'])
141 | end = time.time()
142 | print("Time for loading constructed model: ", round(end - start, 2))
143 | return constructed_model
144 |
145 |
146 | if __name__ == "__main__":
147 | parser = HfArgumentParser((ModelArguments,))
148 | model_args, = parser.parse_args_into_dataclasses()
149 |
150 | constructed_model = load_construction(model_args.model_name_or_path, \
151 | model_args.cache_dir, \
152 | model_args.construct_model_path, \
153 | )
154 |
155 |
156 |
157 |
158 |
--------------------------------------------------------------------------------
/tint_main/utils/activation/forward.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import math
4 | import os
5 | from dataclasses import dataclass
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 | import torch.utils.checkpoint
10 | from torch import nn
11 | from torch.cuda.amp import autocast
12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13 | from ..modules import *
14 | from ..linear import *
15 | from ..activations import *
16 |
17 |
18 |
19 |
20 | class ActivationForward (nn.Module):
21 | def __init__ (self, config, din, projection_matrix=None, memory_index=-1):
22 | super(ActivationForward, self).__init__()
23 |
24 | self.din=din
25 | self.config=config
26 | self.projection_matrix = projection_matrix
27 | self.memory_index = memory_index
28 |
29 |
30 | if projection_matrix is not None:
31 | self.dout = projection_matrix.shape[0]
32 | else:
33 | self.dout = din
34 |
35 |
36 | assert memory_index == -1 or memory_index >= self.dout,\
37 | "Memory interacts with final signal"
38 |
39 | assert memory_index == -1 or memory_index <= config.hidden_size - self.din, \
40 | "not enough space to store memory"
41 |
42 | if projection_matrix is not None:
43 | head_dim = self.dout
44 | num_channels = config.hidden_size // head_dim
45 | else:
46 | num_channels = config.num_attention_heads
47 | head_dim = config.hidden_size // num_channels
48 |
49 | self.mlp_module = MLP (config.hidden_size, \
50 | config, \
51 | conv2d=True, \
52 | transpose_intermediate=True, \
53 | transpose_proj=False, \
54 | conv_proj_features=num_channels, \
55 | )
56 |
57 | self.mlp_gates = Gates (config)
58 | self.projection_ = None
59 |
60 |
61 | if projection_matrix is not None:
62 |
63 | assert projection_matrix.shape[1] == din,\
64 | "Projection matrix must have 'din' in second coordinate"
65 | assert projection_matrix.shape[1] >= head_dim, \
66 | "Currently, this projection only works when we project down to a lower dimension"
67 | assert projection_matrix.shape[1] % head_dim == 0, \
68 | "Perfect division into channels"
69 |
70 | c_proj_init = torch.zeros((num_channels, head_dim, head_dim), dtype=self.mlp_module.c_proj.weight.dtype)
71 | num_useful_channels = projection_matrix.shape[1] // head_dim
72 | for i in range (num_useful_channels):
73 | c_proj_init[i] = torch.tensor(projection_matrix[:, i*head_dim: (i+1)*head_dim], dtype=self.mlp_module.c_proj.weight.dtype)
74 | self.mlp_module.initialize_weights(c_proj_init=c_proj_init)
75 |
76 | self.projection_ = Conv2D( nf=num_channels, nx=head_dim, transpose=True, use_einsum=self.config.use_einsum )
77 | with torch.no_grad():
78 | self.projection_.weight.copy_(torch.zeros(head_dim, num_channels, num_channels))
79 |                 self.projection_.weight[:, :num_useful_channels, 0] = 1.
80 |
81 | else:
82 | c_proj_init = torch.zeros((num_channels, head_dim, head_dim), dtype=self.mlp_module.c_proj.weight.dtype)
83 |
84 | if self.memory_index != -1:
85 | assert memory_index % head_dim == 0, \
86 | "Memory should be divisible by the number of channels!"
87 |
88 | mem_head_start = memory_index // head_dim
89 |
90 | c_proj_init[:mem_head_start] = torch.eye(head_dim)
91 | self.mlp_module.initialize_weights(c_proj_init=c_proj_init)
92 | else:
93 | c_proj_init[:] = torch.eye(head_dim)
94 | self.mlp_module.initialize_weights(c_proj_init=c_proj_init)
95 |
96 | #Initialize Gates
97 | #Ignore the changes for the prefixes!
98 | #w, u, v, w_bias, u_bias, v_bias
99 | w = torch.zeros((1, 2*config.hidden_size))
100 | u = torch.zeros((1, 2*config.hidden_size))
101 | v = torch.zeros((1, 2*config.position_dim))
102 | w_bias = torch.zeros(2)
103 | u_bias = torch.zeros(2)
104 | v_bias = torch.zeros(2)
105 |
106 | #Input Gate is 1 on prefixes and 0 for non-prefixes
107 | v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.position_dim-config.seq_length)
108 |
109 |
110 | #Change Gate is 0 on prefixes and 1 for non-prefixes
111 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.position_dim-config.seq_length)
112 | v_bias [1] += config.gate_scale
113 |
114 | self.mlp_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias)
115 |
116 |
117 | def forward(self, hidden_states, position_embeddings):
118 | mlp_out = self.mlp_module.forward(hidden_states)
119 | if self.projection_ is not None:
120 | mlp_out = self.projection_(mlp_out)
121 |
122 | if self.memory_index != -1:
123 | assert torch.sum(torch.absolute(mlp_out[:, self.config.num_prefixes:, self.memory_index:])).item() < 1e-10,\
124 | "Memory portion not empty!"
125 |
126 | mlp_out[:, self.config.num_prefixes:, self.memory_index: self.memory_index+self.din] += hidden_states[:, self.config.num_prefixes:, :self.din]
127 |
128 | gate_out = self.mlp_gates.forward(hidden_states, mlp_out, position_embeddings)
129 |
130 | return gate_out
131 |
132 |
133 |
134 |
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: icl_as_ft
2 | channels:
3 | - pytorch
4 | - soumith
5 | - nvidia
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=5.1=1_gnu
10 | - blas=1.0=mkl
11 | - brotlipy=0.7.0=py310h7f8727e_1002
12 | - bzip2=1.0.8=h7b6447c_0
13 | - ca-certificates=2022.10.11=h06a4308_0
14 | - certifi=2022.9.24=py310h06a4308_0
15 | - cffi=1.15.1=py310h74dc2b5_0
16 | - cryptography=38.0.1=py310h9ce1e76_0
17 | - cuda=11.7.1=0
18 | - cuda-cccl=11.7.91=0
19 | - cuda-command-line-tools=11.7.1=0
20 | - cuda-compiler=11.7.1=0
21 | - cuda-cudart=11.7.99=0
22 | - cuda-cudart-dev=11.7.99=0
23 | - cuda-cuobjdump=11.7.91=0
24 | - cuda-cupti=11.7.101=0
25 | - cuda-cuxxfilt=11.7.91=0
26 | - cuda-demo-suite=11.8.86=0
27 | - cuda-documentation=11.8.86=0
28 | - cuda-driver-dev=11.7.99=0
29 | - cuda-gdb=11.8.86=0
30 | - cuda-libraries=11.7.1=0
31 | - cuda-libraries-dev=11.7.1=0
32 | - cuda-memcheck=11.8.86=0
33 | - cuda-nsight=11.8.86=0
34 | - cuda-nsight-compute=11.8.0=0
35 | - cuda-nvcc=11.7.99=0
36 | - cuda-nvdisasm=11.8.86=0
37 | - cuda-nvml-dev=11.7.91=0
38 | - cuda-nvprof=11.8.87=0
39 | - cuda-nvprune=11.7.91=0
40 | - cuda-nvrtc=11.7.99=0
41 | - cuda-nvrtc-dev=11.7.99=0
42 | - cuda-nvtx=11.7.91=0
43 | - cuda-nvvp=11.8.87=0
44 | - cuda-runtime=11.7.1=0
45 | - cuda-sanitizer-api=11.8.86=0
46 | - cuda-toolkit=11.7.1=0
47 | - cuda-tools=11.7.1=0
48 | - cuda-visual-tools=11.7.1=0
49 | - cuda80=1.0=0
50 | - cudatoolkit=11.1.74=h6bb024c_0
51 | - ffmpeg=4.3=hf484d3e_0
52 | - freetype=2.12.1=h4a9f257_0
53 | - gds-tools=1.4.0.31=0
54 | - giflib=5.2.1=h7b6447c_0
55 | - gmp=6.2.1=h295c915_3
56 | - gnutls=3.6.15=he1e5248_0
57 | - idna=3.4=py310h06a4308_0
58 | - intel-openmp=2021.4.0=h06a4308_3561
59 | - jpeg=9e=h7f8727e_0
60 | - lame=3.100=h7b6447c_0
61 | - lcms2=2.12=h3be6417_0
62 | - ld_impl_linux-64=2.38=h1181459_1
63 | - lerc=3.0=h295c915_0
64 | - libcublas=11.11.3.6=0
65 | - libcublas-dev=11.11.3.6=0
66 | - libcufft=10.9.0.58=0
67 | - libcufft-dev=10.9.0.58=0
68 | - libcufile=1.4.0.31=0
69 | - libcufile-dev=1.4.0.31=0
70 | - libcurand=10.3.0.86=0
71 | - libcurand-dev=10.3.0.86=0
72 | - libcusolver=11.4.1.48=0
73 | - libcusolver-dev=11.4.1.48=0
74 | - libcusparse=11.7.5.86=0
75 | - libcusparse-dev=11.7.5.86=0
76 | - libdeflate=1.8=h7f8727e_5
77 | - libffi=3.3=he6710b0_2
78 | - libgcc-ng=11.2.0=h1234567_1
79 | - libgomp=11.2.0=h1234567_1
80 | - libiconv=1.16=h7f8727e_2
81 | - libidn2=2.3.2=h7f8727e_0
82 | - libnpp=11.8.0.86=0
83 | - libnpp-dev=11.8.0.86=0
84 | - libnvjpeg=11.9.0.86=0
85 | - libnvjpeg-dev=11.9.0.86=0
86 | - libpng=1.6.37=hbc83047_0
87 | - libstdcxx-ng=11.2.0=h1234567_1
88 | - libtasn1=4.16.0=h27cfd23_0
89 | - libtiff=4.4.0=hecacb30_2
90 | - libunistring=0.9.10=h27cfd23_0
91 | - libuuid=1.41.5=h5eee18b_0
92 | - libwebp=1.2.4=h11a3e52_0
93 | - libwebp-base=1.2.4=h5eee18b_0
94 | - lz4-c=1.9.3=h295c915_1
95 | - mkl=2021.4.0=h06a4308_640
96 | - mkl-service=2.4.0=py310h7f8727e_0
97 | - mkl_fft=1.3.1=py310hd6ae3a3_0
98 | - mkl_random=1.2.2=py310h00e6091_0
99 | - ncurses=6.3=h5eee18b_3
100 | - nettle=3.7.3=hbbd107a_1
101 | - nsight-compute=2022.3.0.22=0
102 | - numpy=1.23.4=py310hd5efca6_0
103 | - numpy-base=1.23.4=py310h8e6c178_0
104 | - openh264=2.1.1=h4ff587b_0
105 | - openssl=1.1.1s=h7f8727e_0
106 | - pillow=9.2.0=py310hace64e9_1
107 | - pip=22.2.2=py310h06a4308_0
108 | - pycparser=2.21=pyhd3eb1b0_0
109 | - pyopenssl=22.0.0=pyhd3eb1b0_0
110 | - pysocks=1.7.1=py310h06a4308_0
111 | - python=3.10.8=haa1d7c7_0
112 | - pytorch=1.13.0=py3.10_cuda11.7_cudnn8.5.0_0
113 | - pytorch-cuda=11.7=h67b0de4_0
114 | - pytorch-mutex=1.0=cuda
115 | - readline=8.2=h5eee18b_0
116 | - requests=2.28.1=py310h06a4308_0
117 | - setuptools=65.5.0=py310h06a4308_0
118 | - six=1.16.0=pyhd3eb1b0_1
119 | - sqlite=3.39.3=h5082296_0
120 | - tk=8.6.12=h1ccaba5_0
121 | - torchaudio=0.13.0=py310_cu117
122 | - torchvision=0.14.0=py310_cu117
123 | - typing_extensions=4.3.0=py310h06a4308_0
124 | - tzdata=2022f=h04d1e81_0
125 | - urllib3=1.26.12=py310h06a4308_0
126 | - wheel=0.37.1=pyhd3eb1b0_0
127 | - xz=5.2.6=h5eee18b_0
128 | - zlib=1.2.13=h5eee18b_0
129 | - zstd=1.5.2=ha4553b6_0
130 | - pip:
131 | - accelerate==0.14.0
132 | - aiohttp==3.8.3
133 | - aiosignal==1.3.1
134 | - anyio==3.6.2
135 | - argon2-cffi==21.3.0
136 | - argon2-cffi-bindings==21.2.0
137 | - asttokens==2.2.0
138 | - async-timeout==4.0.2
139 | - attrs==22.1.0
140 | - backcall==0.2.0
141 | - beautifulsoup4==4.11.1
142 | - bleach==5.0.1
143 | - charset-normalizer==2.1.1
144 | - contourpy==1.0.6
145 | - cycler==0.11.0
146 | - datasets==2.7.1
147 | - debugpy==1.6.4
148 | - decorator==5.1.1
149 | - defusedxml==0.7.1
150 | - dill==0.3.6
151 | - entrypoints==0.4
152 | - executing==1.2.0
153 | - fastjsonschema==2.16.2
154 | - filelock==3.8.0
155 | - fonttools==4.38.0
156 | - frozenlist==1.3.3
157 | - fsspec==2022.11.0
158 | - huggingface-hub==0.11.0
159 | - ipykernel==6.17.1
160 | - ipython==8.7.0
161 | - ipython-genutils==0.2.0
162 | - jedi==0.18.2
163 | - jinja2==3.1.2
164 | - jsonschema==4.17.3
165 | - jupyter-client==7.4.7
166 | - jupyter-core==5.1.0
167 | - jupyter-server==1.23.3
168 | - jupyterlab-pygments==0.2.2
169 | - kiwisolver==1.4.4
170 | - markupsafe==2.1.1
171 | - matplotlib==3.6.2
172 | - matplotlib-inline==0.1.6
173 | - mistune==2.0.4
174 | - multidict==6.0.2
175 | - multiprocess==0.70.14
176 | - nbclassic==0.4.8
177 | - nbclient==0.7.2
178 | - nbconvert==7.2.5
179 | - nbformat==5.7.0
180 | - nest-asyncio==1.5.6
181 | - notebook==6.5.2
182 | - notebook-shim==0.2.2
183 | - packaging==21.3
184 | - pandas==1.5.1
185 | - pandocfilters==1.5.0
186 | - parso==0.8.3
187 | - pexpect==4.8.0
188 | - pickleshare==0.7.5
189 | - platformdirs==2.5.4
190 | - prometheus-client==0.15.0
191 | - prompt-toolkit==3.0.33
192 | - psutil==5.9.4
193 | - ptyprocess==0.7.0
194 | - pure-eval==0.2.2
195 | - pyarrow==10.0.0
196 | - pygments==2.13.0
197 | - pyparsing==3.0.9
198 | - pyrsistent==0.19.2
199 | - python-dateutil==2.8.2
200 | - pytz==2022.6
201 | - pyyaml==6.0
202 | - pyzmq==24.0.1
203 | - regex==2022.10.31
204 | - responses==0.18.0
205 | - send2trash==1.8.0
206 | - sniffio==1.3.0
207 | - soupsieve==2.3.2.post1
208 | - stack-data==0.6.2
209 | - terminado==0.17.0
210 | - tinycss2==1.2.1
211 | - tokenizers==0.13.2
212 | - tornado==6.2
213 | - tqdm==4.64.1
214 | - traitlets==5.6.0
215 | - transformers==4.24.0
216 | - wcwidth==0.2.5
217 | - webencodings==0.5.1
218 | - websocket-client==1.4.2
219 | - xxhash==3.1.0
220 | - yarl==1.8.1
221 |
--------------------------------------------------------------------------------
/tint_main/utils/activations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | from collections import OrderedDict
17 |
18 | import torch
19 | from packaging import version
20 | from torch import Tensor, nn
21 |
22 |
23 |
24 | class PytorchGELUTanh(nn.Module):
25 | """
26 | A fast C implementation of the tanh approximation of the GeLU activation function. See
27 | https://arxiv.org/abs/1606.08415.
28 |
29 | This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
30 | match due to rounding errors.
31 | """
32 |
33 | def __init__(self):
34 | super().__init__()
35 | if version.parse(torch.__version__) < version.parse("1.12.0"):
36 | raise ImportError(
37 | f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
38 | "PytorchGELUTanh. Please upgrade torch."
39 | )
40 |
41 | def forward(self, input: Tensor) -> Tensor:
42 | return nn.functional.gelu(input, approximate="tanh")
43 |
44 |
45 | class NewGELUActivation(nn.Module):
46 | """
47 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
48 | the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
49 | """
50 |
51 | def forward(self, input: Tensor) -> Tensor:
52 | return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
53 |
54 |
55 | class GELUActivation(nn.Module):
56 | """
57 | Original Implementation of the GELU activation function in Google BERT repo when initially created. For
58 | information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
59 | torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
60 | Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
61 | """
62 |
63 | def __init__(self, use_gelu_python: bool = False):
64 | super().__init__()
65 | if use_gelu_python:
66 | self.act = self._gelu_python
67 | else:
68 | self.act = nn.functional.gelu
69 |
70 | def _gelu_python(self, input: Tensor) -> Tensor:
71 | return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
72 |
73 | def forward(self, input: Tensor) -> Tensor:
74 | return self.act(input)
75 |
76 |
77 | class FastGELUActivation(nn.Module):
78 | """
79 | Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
80 | """
81 |
82 | def forward(self, input: Tensor) -> Tensor:
83 | return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
84 |
85 |
86 | class QuickGELUActivation(nn.Module):
87 | """
88 | Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
89 | """
90 |
91 | def forward(self, input: Tensor) -> Tensor:
92 | return input * torch.sigmoid(1.702 * input)
93 |
94 |
95 | class ClippedGELUActivation(nn.Module):
96 | """
97 | Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
98 | it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
99 | https://arxiv.org/abs/2004.09602.
100 |
101 | Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
102 | initially created.
103 |
104 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
105 | torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
106 | """
107 |
108 | def __init__(self, min: float, max: float):
109 | if min > max:
110 | raise ValueError(f"min should be < max (got min: {min}, max: {max})")
111 |
112 | super().__init__()
113 | self.min = min
114 | self.max = max
115 |
116 | def forward(self, x: Tensor) -> Tensor:
117 | return torch.clip(gelu(x), self.min, self.max)
118 |
119 |
120 | class SiLUActivation(nn.Module):
121 | """
122 | See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
123 | Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
124 | Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
125 | Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
126 | later.
127 | """
128 |
129 | def forward(self, input: Tensor) -> Tensor:
130 | return nn.functional.silu(input)
131 |
132 |
133 | class MishActivation(nn.Module):
134 | """
135 | See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
136 | visit the official repository for the paper: https://github.com/digantamisra98/Mish
137 | """
138 |
139 | def __init__(self):
140 | super().__init__()
141 | if version.parse(torch.__version__) < version.parse("1.9.0"):
142 | self.act = self._mish_python
143 | else:
144 | self.act = nn.functional.mish
145 |
146 | def _mish_python(self, input: Tensor) -> Tensor:
147 | return input * torch.tanh(nn.functional.softplus(input))
148 |
149 | def forward(self, input: Tensor) -> Tensor:
150 | return self.act(input)
151 |
152 |
153 | class LinearActivation(nn.Module):
154 | """
155 | Applies the linear activation function, i.e. forwarding input directly to output.
156 | """
157 |
158 | def forward(self, input: Tensor) -> Tensor:
159 | return input
160 |
161 |
162 | class ClassInstantier(OrderedDict):
163 | def __getitem__(self, key):
164 | content = super().__getitem__(key)
165 | cls, kwargs = content if isinstance(content, tuple) else (content, {})
166 | return cls(**kwargs)
167 |
168 |
169 | ACT2CLS = {
170 | "gelu": GELUActivation,
171 | "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
172 | "gelu_fast": FastGELUActivation,
173 | "gelu_new": NewGELUActivation,
174 | "gelu_python": (GELUActivation, {"use_gelu_python": True}),
175 | "gelu_pytorch_tanh": PytorchGELUTanh,
176 | "linear": LinearActivation,
177 | "mish": MishActivation,
178 | "quick_gelu": QuickGELUActivation,
179 | "relu": nn.ReLU,
180 | "relu6": nn.ReLU6,
181 | "sigmoid": nn.Sigmoid,
182 | "silu": SiLUActivation,
183 | "swish": SiLUActivation,
184 | "tanh": nn.Tanh,
185 | }
186 | ACT2FN = ClassInstantier(ACT2CLS)
187 |
188 |
189 | def get_activation(activation_string):
190 | if activation_string in ACT2FN:
191 | return ACT2FN[activation_string]
192 | else:
193 | raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
194 |
195 |
196 | # For backwards compatibility with: from activations import gelu_python
197 | gelu_python = get_activation("gelu_python")
198 | gelu_new = get_activation("gelu_new")
199 | gelu = get_activation("gelu")
200 | gelu_fast = get_activation("gelu_fast")
201 | quick_gelu = get_activation("quick_gelu")
202 | silu = get_activation("silu")
203 | mish = get_activation("mish")
204 | linear_act = get_activation("linear")
--------------------------------------------------------------------------------
/tint_main/utils/all_arguments.py:
--------------------------------------------------------------------------------
1 | from transformers import HfArgumentParser
2 | from dataclasses import dataclass, field
3 | from typing import Callable, Dict, Optional, Union, List
4 |
5 | @dataclass
6 | class ModelArguments:
7 | """
8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9 | """
10 | model_name_or_path: str = field(
11 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
12 | )
13 | config_name: Optional[str] = field(
14 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
15 | )
16 | cache_dir: Optional[str] = field(
17 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
18 | )
19 |
20 |
21 | construct_save_model: Optional[bool] = field(
22 | default=False, metadata={"help": "Whether to save constructed model"}
23 | )
24 |
25 |
26 | construct_load_model: Optional[bool] = field(
27 | default=False, metadata={"help": "Whether to load constructed model from a specific path"}
28 | )
29 |
30 | construct_model_path: Optional[str] = field(
31 |         default="", metadata={"help": "Path to save or load the constructed model"}
32 | )
33 |
34 |
35 |
36 |
37 |
38 |
39 | @dataclass
40 | class ConstructionArguments:
41 | """
42 |     Arguments pertaining to the TinT construction.
43 | """
44 |
45 | device: Optional[str] = field(
46 | default="cuda", metadata={"help": "cuda/cpu"}
47 | )
48 |
49 | seq_length: Optional[int] = field(
50 | default=1024, metadata={"help": "Sequence length for the smaller model"}
51 | )
52 |
53 | #position_dim: Optional[int] = field(
54 | # default=1024 + 256, metadata={"help": ""}
55 | #)
56 |
57 | num_prefixes: Optional[int] = field(
58 | default=256, metadata={"help": "Number of prefixes to encode at the start of the sequence"}
59 | )
60 |
61 | num_attention_heads: Optional[int] = field(
62 | default=12, metadata={"help": "Number of attention heads"}
63 | )
64 |
65 | scale_embeddings: Optional[float] = field(
66 | default=1000., metadata={"help": "Scale factor to minimize error introduced by multiplications with GeLU"}
67 | )
68 |
69 | inner_lr: Optional[float] = field(
70 | default=0.000001, metadata={"help": "Learning rate to simulate SGD inside the model"}
71 | )
72 |
73 | gate_scale: Optional[float] = field(
74 | default=10., metadata={"help": "Initial scale inside gate weights to simulate 0/1 switch"}
75 | )
76 | hidden_size: Optional[int] = field(
77 | default=4, metadata={"help": "Multiple of Embedding size (of the smaller model) for the construction"}
78 | )
79 |
80 | max_position_embeddings: Optional[int] = field(
81 | default=2048, metadata={"help": "Max sequence length"}
82 | )
83 |
84 | embd_pdrop: Optional[float] = field(
85 |         default=0.1, metadata={"help": "Dropout on embeddings"}
86 | )
87 |
88 | attn_pdrop: Optional[float] = field(
89 | default=0.1, metadata={"help": "Dropout on attention weights"}
90 | )
91 | resid_pdrop: Optional[float] = field(
92 | default=0.1, metadata={"help": "Dropout on residual connections"}
93 | )
94 |
95 | activation_function: Optional[str] = field(
96 | default="gelu", metadata={"help": "Activation: gelu/relu"}
97 | )
98 | epsilon: Optional[float] = field(
99 | default=1e-05, metadata={"help": "Epsilon for layernorm computation"}
100 | )
101 |
102 | scale_attn_weights: Optional[bool] = field(
103 | default=True, metadata={"help": "Whether to scale attention weights"}
104 | )
105 |
106 | n_simulation_layers: Optional[int] = field(
107 | default=-1, metadata={"help": "Number of layers to simulate forward-backward passes"}
108 | )
109 |
110 | n_forward_backward: Optional[int] = field(
111 | default=1, metadata={"help": "Number of forward-backward passes"}
112 | )
113 |
114 | n_debug_layers: Optional[int] = field(
115 | default=-1, metadata={"help": "Number of layers of smaller model that we use (for debugging purposes)"}
116 | )
117 |
118 | projection_paths: Optional[str] = field(
119 | default="./projections", metadata={"help": "Path to all the projection matrices"}
120 | )
121 |
122 | backprop_through_attention: Optional[bool] = field(
123 | default=True, metadata={"help": "Whether to look ahead during backprop through attention"}
124 | )
125 |
126 | restrict_prefixes: Optional[bool] = field(
127 | default=False, metadata={"help": "Whether to restrict attention to blank tokens only in linear forward/backward"}
128 | )
129 |
130 | use_einsum: Optional[bool] = field(
131 | default=True, metadata={"help": "Whether to use einsum in dimension wise linear convolutions"}
132 | )
133 |
134 | use_prediction_loss: Optional[bool] = field(
135 | default=True, metadata={"help": "If true, gradient w.r.t. loss is E(p-q) with p being the softmax prediction, if false (and use_quad is false), gradient w.r.t. loss is E(1 - q)"}
136 | )
137 |
138 | use_quad: Optional[bool] = field(
139 | default=False, metadata={"help": "If true, we use the quad loss to compute the gradient from Saunshi & Malladi et al. (2020)"}
140 | )
141 |
142 |
143 | n_layers_pergpu: Optional[int] = field(
144 | default=100000, metadata={"help": "Number of layers simulated per gpu"}
145 | )
146 |
147 | reuse_forward_blocks: Optional[bool] = field(
148 | default=False, metadata={"help": "Whether to re-use forward blocks"}
149 | )
150 |
151 | reuse_backward_blocks: Optional[bool] = field(
152 | default=False, metadata={"help": "Whether to re-use backward blocks"}
153 | )
154 |
155 | ln_update_bias_only: Optional[bool] = field(
156 | default=True, metadata={"help": "Whether to update only biases in layernorm modules; current weight update of layernorm modules is very noisy, hence it's best to avoid!"}
157 | )
158 |
159 |
160 | @dataclass
161 | class DynamicArguments:
162 |
163 | data_cache_dir: Optional[str] = field(
164 | default='/scratch/gpfs/ap34/icl-as-ft/Dynamic_initialization/data',
165 | metadata={
166 | 'help': 'where to store downloaded model, datasets, and tokenizer'
167 | }
168 | )
169 |
170 | log_dir: Optional[str] = field(
171 | default='/scratch/gpfs/smalladi/icl_as_ft/logs'
172 | )
173 |
174 | dataset: Optional[str] = field(
175 | default='wikitext-103',
176 | metadata={
177 | 'help': 'dataset to use, should be in HF datasets repo'
178 | }
179 | )
180 |
181 | incontext_dataset: Optional[str] = field(
182 | default="", metadata={"help": "Dataset for in-context experiments!"}
183 | )
184 |
185 | incontext_stem: Optional[str] = field(
186 | default="", metadata={"help": "Directory containing the files!"}
187 | )
188 |
189 | incontext_n_shot: Optional[int] = field(
190 |         default=4, metadata={"help": "Number of demonstrations in the in-context prompt!"}
191 | )
192 |
193 | prompt_variant: Optional[int] = field(
194 | default=0, metadata={"help": "Variants to try for sst-2!"}
195 | )
196 |
197 |
198 | use_eval_set: bool = field(
199 | default=True,
200 | metadata={'help': 'when on, uses the eval dataset only for ppl measurement/in-context expts.'}
201 | )
202 |
203 | use_test_set: bool = field(
204 | default=False,
205 | metadata={'help': 'when on, uses the test dataset only for ppl measurement/in-context expts.'}
206 | )
207 |
208 | ### finetuning strategies ###
209 |
210 | chunk_data: bool = field(
211 | default=True,
212 | metadata={'help': 'when on, chunks the data into block_size chunks'}
213 | )
214 |
215 | baseline: bool = field(
216 | default=False
217 | )
218 |
219 | num_subsets: int = field(
220 | default=16
221 | )
222 | data_chunk_index: int = field(
223 | default=0
224 | )
225 |
226 | block_size: Optional[int] = field(
227 | default=1024,
228 | metadata={'help': 'size of chunks to break corpus into'}
229 | )
230 |
231 | num_workers: Optional[int] = field(
232 | default=4,
233 | metadata={'help': 'number of workers'}
234 | )
235 |
236 | data_subset: Optional[int] = field(
237 | default=-1, metadata={"help": "Number of training examples to use in the subset!"}
238 | )
239 |
240 | train_fraction: Optional[float] = field(
241 | default=0.5, metadata={"help": "Fraction of sentence for training!"}
242 | )
243 |
244 |
245 | batch_size: Optional[int] = field(
246 | default=1, metadata={"help": "Batch size for training!"}
247 | )
248 |
249 |
--------------------------------------------------------------------------------
/tint_main/utils/layernorm/backward.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import math
4 | import os
5 | from dataclasses import dataclass
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 | import torch.utils.checkpoint
10 | from torch import nn
11 | from torch.cuda.amp import autocast
12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13 | from ..modules import *
14 | from ..linear import *
15 |
16 |
17 | #Assumption on memory
18 | #It should contain [(x-\mu)/\sigma, x]
19 |
20 | class LayerNormBackward(nn.Module):
21 | def __init__(self, config, din, use_softmax, retain_nablay=False, memory_index=-1):
22 | super(LayerNormBackward, self).__init__()
23 |
24 | assert use_softmax==False ,\
25 | "Currently I only use linear attention in this module"
26 |
27 | assert memory_index == -1 or memory_index >= din, \
28 | "memory crosses current signal"
29 |
30 |
31 | self.linear = LinearBackward(config, \
32 | din=din, \
33 | dout=din, \
34 | use_softmax=use_softmax, \
35 | retain_nablay=retain_nablay, \
36 | memory_index=memory_index, \
37 | )
38 | self.epsilon = config.epsilon
39 | self.memory_index = memory_index
40 | self.config = config
41 |
42 | head_dim = config.hidden_size // config.num_attention_heads
43 | self.c_fc = Conv2D(config.num_attention_heads, head_dim, transpose=True, use_einsum=self.config.use_einsum)
44 | self.proj_fc = Conv2D(config.num_attention_heads, head_dim, transpose=True, use_einsum=self.config.use_einsum)
45 |
46 | self.config = config
47 | self.din = din
48 |
49 |
50 | c_fc_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
51 | c_proj_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
52 |
53 |
54 |
55 |
56 | assert din % head_dim == 0, \
57 | " 'din' should be a multiple of head_dim! "
58 |
59 | num_partitions = din // head_dim
60 |
61 |
62 |
63 | assert self.memory_index % head_dim == 0, \
64 | "Memory should start at a multiple of head_dim!"
65 |
66 | mem_head_start = self.memory_index // head_dim
67 |
68 | if retain_nablay:
69 | start_shift = num_partitions
70 | else:
71 | start_shift = 0
72 |
73 | c_fc_init[:, start_shift: start_shift + num_partitions, start_shift: start_shift + num_partitions] = 1. / config.scale_embeddings * torch.eye(num_partitions)
74 | #1. / config.scale_embeddings
75 | c_fc_init[:, start_shift: start_shift + num_partitions, mem_head_start + num_partitions: mem_head_start + 2*num_partitions] = torch.eye(num_partitions)
76 |
77 |
78 | #Compute GeLU(x + 1/N \nabla y) - GeLU(x)
79 |
80 | c_proj_init[:, start_shift: start_shift + num_partitions, start_shift: start_shift + num_partitions] = config.scale_embeddings * torch.eye(num_partitions)
81 | c_proj_init[:, start_shift: start_shift + num_partitions, mem_head_start: mem_head_start + num_partitions] = -config.scale_embeddings * torch.eye(num_partitions)
82 |
83 |
84 | with torch.no_grad():
85 |
86 | self.c_fc.weight.copy_(torch.swapaxes(c_fc_init, axis0=-1, axis1=-2))
87 | self.proj_fc.weight.copy_(torch.swapaxes(c_proj_init, axis0=-1, axis1=-2))
88 |
89 | #w acts like a gate to decide what portion of the embedding we apply layernorm on
90 | self.w = torch.zeros (( 1, 1, config.hidden_size ))
91 | if retain_nablay:
92 | self.w [:, :, din : 2*din] += config.gate_scale
93 | else:
94 | self.w [:, :, : din] += config.gate_scale
95 | self.gate = torch.nn.Tanh()
96 |
97 |
98 | #mask out normalization on prefixes
99 | self.normalization_gates = Gates (config)
100 | #Initialize Gates
101 | #Ignore the changes for the prefixes!
102 | #w, u, v, w_bias, u_bias, v_bias
103 | w = torch.zeros((1, 2*config.hidden_size))
104 | u = torch.zeros((1, 2*config.hidden_size))
105 | v = torch.zeros((1, 2*config.position_dim))
106 | w_bias = torch.zeros(2)
107 | u_bias = torch.zeros(2)
108 | v_bias = torch.zeros(2)
109 |
110 | #Input Gate is 1 on prefixes and 0 for non-prefixes
111 | v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.num_prefixes)
112 |
113 |
114 | #Change Gate is 0 on prefixes and 1 for non-prefixes
115 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.num_prefixes)
116 | v_bias [1] += config.gate_scale
117 |
118 | self.normalization_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias)
119 |
120 |
121 |
122 | def forward(self, hidden_states, position_states, attention_mask=None, icl_mask=None):
123 |
124 | weights = self.gate ( self.w ).to(hidden_states.device)
125 |
126 | back_gradient = self.linear.forward(hidden_states, position_states)
127 |
128 | #print (back_gradient[0, 12, :72], back_gradient[0, 12, -72:])
129 | #######################################################################
130 | #Next few lines compute the operation:
131 |         # f(x) = (x - \mu(x)) / \sigma(x)
132 | # N (f(x + 1/N \nabla y) - f(x))
133 | #######################################################################
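        #For large N (here N = config.scale_embeddings), N (f(x + 1/N \nabla y) - f(x)) is a
        #first-order approximation of J_f(x) \nabla y; since the layernorm Jacobian J_f(x) is
        #symmetric, this matches the backpropagated gradient J_f(x)^T \nabla y.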
134 | first_layer = self.c_fc.forward ( back_gradient )
135 | first_layer = weights * first_layer + (1. - weights) * back_gradient
136 |
137 | #print (first_layer[0, 12, :72])
138 |
139 | mean = torch.sum(first_layer * weights, dim=-1, keepdim=True) / torch.sum(weights, dim=-1, keepdim=True)
140 | var = ( self.epsilon + torch.sum( (weights * (first_layer - mean)) ** 2, dim=-1, keepdim=True) / torch.sum(weights, dim=-1, keepdim=True) ) ** 0.5
141 |
142 | #print (var)
143 | normalized_states = (first_layer - mean) / var
144 | #print (normalized_states[:, 192, :64])
145 |
146 | normalized_states = weights * normalized_states + (1. - weights) * first_layer
147 |
148 | second_layer = self.proj_fc.forward ( normalized_states )
149 |
150 | second_layer = weights * second_layer + (1. - weights) * normalized_states
151 |
152 | #######################################################################
153 |
154 | gated_output = self.normalization_gates.forward ( hidden_states, second_layer, position_states)
155 |
156 | return gated_output
157 |
158 |
159 |
160 | class LayerNormDescent(nn.Module):
161 | def __init__ (self, config, din, use_softmax, memory_index=-1, debug_zero=False):
162 | super(LayerNormDescent, self).__init__()
163 | self.config=config
164 | self.linear = LinearDescent(config, din=din, dout=din, use_softmax=use_softmax, memory_index=memory_index, debug_zero=debug_zero, update_bias_only=self.config.ln_update_bias_only)
165 |
166 | def forward(self, hidden_states, position_states, attention_mask, activation_memory=None, icl_mask=None):
167 | return self.linear.forward(hidden_states, position_states, attention_mask)
168 |
169 |
170 |
171 | class LayerNormDescent_Backward(nn.Module):
172 | def __init__(self, config, din, use_softmax, debug_zero=False, retain_nablay=False, projection_matrix=None, memory_index=-1):
173 | super(LayerNormDescent_Backward, self).__init__()
174 |
175 |
176 | self.config = config
177 | self.backward = LayerNormBackward(config, \
178 | din=din, \
179 | use_softmax=use_softmax, \
180 | retain_nablay=retain_nablay, \
181 | memory_index=memory_index, \
182 | )
183 |
184 | self.descent = LayerNormDescent(config, \
185 | din=din, \
186 | use_softmax=use_softmax, \
187 | memory_index=memory_index,\
188 | debug_zero=debug_zero, \
189 | )
190 |
191 | def forward(self, hidden_states, position_embeddings, attention_mask, activation_memory=None, icl_mask=None):
192 | backward_out = self.backward(hidden_states, position_embeddings)
193 | descent_out = self.descent (hidden_states, position_embeddings, attention_mask)
194 |
195 | return torch.cat( [ descent_out[:, :self.config.num_prefixes], backward_out[:, self.config.num_prefixes:] ], axis=1)
196 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## TinT: Trainable Transformer in Transformer
2 |
3 | This repository contains the code for our paper Trainable Transformer in Transformer (TinT).
4 |
5 | ## Quick Links
6 |
7 | - [TinT: Trainable Transformer in Transformer](#tint-trainable-transformer-in-transformer)
8 | - [Quick Links](#quick-links)
9 | - [Overview](#overview)
10 | - [Structure of TinT](#structure-of-tint)
11 | - [Creating TinT](#creating-tint)
12 | - [Requirements](#requirements)
13 | - [Perplexity Evaluation](#perplexity-evaluation)
14 | - [Downstream Evaluation](#downstream-evaluation)
15 | - [Hyperparameter Considerations](#hyperparameter-considerations)
16 | - [Unavailable features](#unavailable-features)
17 | - [Bugs or Questions](#bugs-or-questions)
18 |
19 |
20 |
21 | ## Overview
22 |
23 | We propose an efficient construction, Transformer in Transformer (in short, TinT), that allows a transformer to simulate and fine-tune complex models internally during inference (e.g., pre-trained language models). In particular, we introduce innovative approximation techniques that allow a TinT model with fewer than 2 billion parameters to simulate and fine-tune a 125 million parameter transformer model within a single forward pass. TinT accommodates many common transformer variants, and its design ideas also improve the efficiency of past instantiations of simple models inside transformers. We conduct end-to-end experiments to validate the internal fine-tuning procedure of TinT on various language modeling and downstream tasks. For example, even with a limited one-step budget, we observe that TinT for an OPT-125M model improves performance by 4-16% absolute on average compared to OPT-125M. These findings suggest that large pre-trained language models are capable of performing intricate subroutines.
24 |
25 | ### Structure of TinT
26 |
27 | Each Forward, Backward, and Descent module is represented using combinations of linear, self-attention, layernorm, and activation layers. The input consists of prefix embeddings, which represent the relevant auxiliary model parameters of each layer; input token embeddings; and a binary prefix mask that separates the training and evaluation segments of the input. The auxiliary model parameters are updated in the Descent modules using the training segment, and the updated prefix tokens are transferred to the Forward modules via residual connections to evaluate the rest of the segment. A minimal sketch of this input layout is shown after the figure below.
28 |
29 | 
30 |
31 |
32 | ## Creating TinT
33 |
34 | In the following section, we provide instructions on creating and evaluating TinT models with our code.
35 |
36 | ### Requirements
37 | Install the necessary conda environment using
38 |
39 | ```bash
40 | conda env create -n icl_as_ft --file env.yml
41 | ```
42 |
43 | ### Create and store TinT
44 |
45 | ```bash
46 | python -m tests.create_model \
47 | --model_name_or_path $model_path \
48 | --cache_dir $cache_dir \
49 |     --construct_model_path $construct_model_path \
50 | --n_simulation_layers $nsim_layers \
51 | --n_forward_backward $n_forward_backward \
52 | --inner_lr $lr \
53 | --n_layers_pergpu $n_layers_pergpu \
54 | --num_attention_heads $num_attention_heads \
55 | --hidden_size $hidden_size \
56 | --num_prefixes $num_prefixes \
57 | --construct_save_model $construct_save_model \
58 | --reuse_forward_blocks $reuse_forward_blocks \
59 | --reuse_backward_blocks $reuse_backward_blocks \
60 | --restrict_prefixes $restrict_prefixes;
61 | ```
62 |
63 |
64 | * `model_path`: facebook/opt-125m or gpt2, Auxiliary model to create the TinT model
65 | * `cache_dir`: Directory to store and load opt/gpt2 models
66 | * `construct_save_model`: Whether to save the constructed model
67 | * `construct_model_path`: Path to load or save the constructed model
68 | * `n_simulation_layers`: Number of layers to update during dynamic evaluation
69 | * `n_forward_backward`: Number of SGD steps
70 | * `num_attention_heads`: Number of attention heads in constructed model
71 | * `hidden_size`: Embedding size of constructed model
72 | * `num_prefixes`: Number of prefix tokens
73 | * `inner_lr`: Learning rate for dynamic evaluation; note that in our construction, gradients are summed over tokens (and not averaged)
74 | * `n_layers_pergpu`: When using multiple GPUs, partition the layers across GPUs, with n_layers_pergpu layers per GPU
75 | * `reuse_forward_blocks`: True/False, For multi step SGD, reuse transformer blocks for simulating forward pass
76 | * `reuse_backward_blocks`: True/False, For multi step SGD, reuse transformer blocks for simulating backward pass
77 | * `restrict_prefixes`: For linear operations, restrict the linear attention heads to interactions between the prefix tokens and the input embeddings only
78 |
79 |
80 | An example to create a TinT model from auxiliary model gpt2 is as follows:
81 |
82 | ```bash
83 | python -m tests.create_model \
84 | --model_name_or_path gpt2 \
85 | --cache_dir "cache/" \
86 | --construct_model_path "Constructed_model/TinT_gpt2_innerlr04_ngradlayers12_sgdsteps1" \
87 | --n_simulation_layers 12 \
88 | --n_forward_backward 1 \
89 | --inner_lr 1e-04 \
90 | --n_layers_pergpu 36 \
91 | --num_attention_heads 12 \
92 | --hidden_size 3072 \
93 | --num_prefixes 256 \
94 | --construct_save_model True \
95 | --reuse_forward_blocks True \
96 | --reuse_backward_blocks True \
97 | --restrict_prefixes True;
98 | ```
99 |
100 |
101 | ### Perplexity Evaluation
102 | Use the following command to run perplexity evaluation on wikitext-2, wikitext-103, and c4.
103 |
104 |
105 | ```bash
106 | python -m tests.perplexity_eval \
107 | --dataset $dataset \
108 | --model_name_or_path $model_path \
109 | --cache_dir $cache_dir \
110 |     --construct_model_path $construct_model_path \
111 | --train_fraction $train_fraction \
112 | --batch_size $batch_size \
113 | --use_eval_set $use_eval_set\
114 | --use_test_set $use_test_set\
115 | --data_subset $data_subset;
116 | ```
117 |
118 | * `dataset`: c4/wikitext-2/wikitext-103
119 | * `model_path`: facebook/opt-125m or gpt2, Auxiliary model used to create the TinT model
120 | * `cache_dir`: Directory to store and load opt/gpt2 models
121 | * `construct_model_path`: Path to load the constructed model
122 | * `train_fraction`: Fraction of input to use for training (float between 0 and 1)!
123 | * `batch_size`: Batch size for the forward passes
124 | * `use_eval_set`: True/False, Use validation set?
125 | * `use_test_set`: True/False, Use test set? (if both use_eval_set and use_test_set are True, test set is used for evaluation)
126 | * `data_subset`: Evaluation on subset of data (must be a multiple of batch size).
127 |
128 |
129 | The results, along with all the arguments, are stored in JSON format in a file named **log_exp_construct**. An example for perplexity evaluation of the TinT model "Constructed_model/TinT_gpt2_innerlr04_ngradlayers12_sgdsteps1" on wikitext-103 is as follows:
130 |
131 | ```bash
132 | python -m tests.perplexity_eval \
133 | --dataset wikitext-103 \
134 | --model_name_or_path gpt2 \
135 | --cache_dir "cache/" \
136 | --construct_model_path "Constructed_model/TinT_gpt2_innerlr04_ngradlayers12_sgdsteps1" \
137 | --train_fraction 0.5 \
138 | --batch_size 4 \
139 | --use_eval_set True\
140 | --use_test_set False;
141 | ```
142 |
143 |
144 | ### Downstream Evaluation
145 |
146 | Please refer to the README file in [**icl_eval**](https://github.com/abhishekpanigrahi1996/transformer_in_transformer/tree/main/icl_eval) folder.
147 |
148 |
149 | ### Hyperparameter considerations
150 |
151 | The embedding size and number of attention heads of the TinT model depend on the number of weight rows stacked in each prefix, the number of prefix tokens, and the dimensions of the auxiliary model. Multiple assertions in the code enforce these inter-dependencies. We give a set of general rules below to decide on the hyperparameters and provide the hyperparameters that we used to construct the TinT models.
152 |
153 |
154 | There are a few important dependencies to consider (a short sanity-check sketch follows the list).
155 | * Embedding size of TinT (given by hidden_size argument) must be equal to the embedding size of the auxiliary model times (the number of weight rows that we stack per prefix token + 1). The addition of 1 is to include the bias terms in the first prefix token. E.g. for gpt2, whose embedding dimension is 768, if we decide to stack 3 weight rows per prefix token, the embedding dimension of TinT should be 768 * 4.
156 | * The number of weight rows that we stack per prefix token equals the total number of weight rows in a layer (i.e., the auxiliary model's embedding dimension) divided by the number of prefixes (given by the num_prefixes argument). E.g. for gpt2, whose embedding dimension is 768, stacking 3 weight rows per prefix token requires 768 / 3 = 256 prefix tokens.
157 | * hidden_size should be divisible by the number of attention heads (given by num_attention_heads).
158 | * Attention head dimension (given by hidden_size // num_attention_heads) should be a factor of the auxiliary model's embedding dimension. This is to ensure that we can partition the embeddings of the auxiliary model equally across a few attention heads.
159 | * hidden_size must be divisible by num_prefixes. Our current implementation allows an unequal number of attention heads between linear attention (used to simulate linear operations) and softmax attention (used to simulate operations involving self-attention). The number of attention heads in linear attention is given by (hidden_size // num_prefixes).
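
The sketch below (illustrative only; the variable names are not repository arguments) checks these dependencies for the gpt2 setting listed in the tables that follow.

```python
# Sanity-check the hyperparameter dependencies for the gpt2 setting.
aux_dim = 768                                  # auxiliary (gpt2) embedding size
rows_per_prefix = 3                            # weight rows stacked per prefix token
num_prefixes = aux_dim // rows_per_prefix      # 256
hidden_size = aux_dim * (rows_per_prefix + 1)  # 3072; the "+ 1" holds the bias terms
num_attention_heads = 12

head_dim = hidden_size // num_attention_heads  # 256
assert hidden_size % num_attention_heads == 0
assert aux_dim % head_dim == 0                 # head_dim must divide the auxiliary embedding size
assert hidden_size % num_prefixes == 0         # linear attention uses hidden_size // num_prefixes heads
print(hidden_size, num_prefixes, hidden_size // num_prefixes)  # 3072 256 12
```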
160 |
161 |
162 | We use the following hyperparameters to create the TinT models for the perplexity/downstream evaluations. We report inner_lr (the learning rate for dynamic evaluation) only for the models whose numbers we report.
163 |
164 | | | gpt2 | gpt2-medium | gpt2-large | gpt2-xl |
165 | |:--------------|:-----------:|:--------------:|:---------:|:---------:|
166 | | Auxiliary model embedding size | 768 | 1024 | 1280 | 1600 |
167 | | Auxiliary model attention heads | 12 | 16 | 20 | 25 |
168 | | Number of layers | 12 | 24 | 36 | 48 |
169 | | TinT hidden_size | 3072 | 5120 | 6400 | 9600 |
170 | | TinT num_prefixes | 256 | 256 | 320 | 320 |
171 | | TinT num_attention_heads | 12 | 20 | 20 | 30 |
172 | | Inner LR (dynamic eval) | 1e-3, 5e-4, 1e-4, 1e-5 | - | - | - |
173 |
174 |
175 |
176 | | | facebook/opt-125m | facebook/opt-350m* | facebook/opt-1.3b | facebook/opt-2.7b |
177 | |:--------------|:-----------:|:--------------:|:---------:|:---------:|
178 | | Auxiliary model embedding size | 768 | 1024 | 2048 | 2560 |
179 | | Auxiliary model attention heads | 12 | 16 | 32 | 32 |
180 | | Number of layers | 12 | 24 | 24 | 32 |
181 | | TinT hidden_size | 3072 | 5120 | 10240 | 12800 |
182 | | TinT num_prefixes | 256 | 256 | 512 | 640 |
183 | | TinT num_attention_heads | 12 | 20 | 40 | 40 |
184 | | Inner LR (dynamic eval) | 1e-5, 1e-6, 1e-7 | - | - | - |
185 |
186 | *We can't handle post layer norm in facebook/opt-350m in the current code.
187 |
188 | ### Unavailable features
189 |
190 | The current code doesn't contain the following features, which we plan to integrate gradually in the future.
191 | * `Post layer norm`: Currently, our code doesn't handle post layer norm and hence can't create TinT for facebook/opt-350m.
192 | * `Cross self attention`: The self attention module hasn't been modified to handle cross attention.
193 | * `TinT modules for gated linear units (GLUs)`: We will integrate the modules for GLUs soon.
194 | * `Attention variants`: We will integrate attention variants like AliBi and relative attention soon.
195 | * `RMSnorm`: We will integrate modules for RMSnorm soon.
196 | * `TinT for GPT-J, BLOOM, LLaMA`: We will include TinT creators for these models soon.
197 |
198 |
199 | ## Bugs or Questions
200 | If you have any questions related to the code, feel free to email Abhishek or Mengzhou (`{ap34,mengzhou}@cs.princeton.edu`). If you encounter a problem or bug when using the code, you can also open an issue.
201 |
202 |
203 |
--------------------------------------------------------------------------------
/tint_main/utils/self_attention/forward.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import math
4 | import os
5 | from dataclasses import dataclass
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 | import torch.utils.checkpoint
10 | from torch import nn
11 | from torch.cuda.amp import autocast
12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13 | from ..modules import *
14 | from ..linear import *
15 |
16 |
17 |
18 |
19 |
20 |
21 | #The following module simulates the forward operation of a softmax self-attention in the Auxiliary model.
22 | #Output: A module that first computes Query, Key, Value using a single LinearForward module, followed by a softmax attention layer
23 |
24 | #Arguments
25 | #config:            TinT config
26 | #din:               input dimension
27 | #num_attnt_heads:   number of attention heads in the auxiliary model's self-attention
28 | #use_softmax:       must be False; we use linear attention heads for the linear operations
29 | #projection_matrix: projection matrix for the linear operation (not used in the current code)
30 | #separate_QK:       unnecessary argument, to be removed
31 | #memory_index:      index at which activations are stored for the backward and descent passes
32 |
33 | class AttentionForward (nn.Module):
34 | def __init__ (self, \
35 | config, \
36 | din, \
37 | num_attnt_heads, \
38 | use_softmax=False, \
39 | projection_matrix=None, \
40 | separate_QK=False, \
41 | memory_index=0,\
42 | ):
43 | super(AttentionForward, self).__init__()
44 |
45 | assert use_softmax==False ,\
46 | "Currently I only use linear attention for linear operations in this module"
47 |
48 | assert num_attnt_heads <= config.num_attention_heads,\
49 |             "TinT must have at least as many attention heads as the auxiliary self-attention being simulated"
50 |
51 | self.separate_QK = separate_QK
52 | if projection_matrix is not None:
53 | dout = projection_matrix.shape[1]
54 | else:
55 | if separate_QK: dout = 2*din
56 | else: dout = din
57 |
58 | self.linear = LinearForward(config, \
59 | din=din, \
60 | dout=dout, \
61 | use_softmax=use_softmax, \
62 | projection_matrix=projection_matrix, \
63 | memory_index=-1,\
64 | )
65 |
66 | self.key_linear = self.linear
67 | self.value_linear = self.linear
68 |
69 |
70 |
71 | self.gates = Gates (config)
72 |
73 | self.din = din
74 | self.num_attnt_heads = num_attnt_heads
75 | self.config = config
76 | self.memory_index = memory_index
77 |
78 |
79 | head_dim = config.hidden_size // config.num_attention_heads
80 | basemodel_head_dim = din // num_attnt_heads
81 |
82 | self.attnt_module = Attention (config, normalize=True, proj_conv2d=True, proj_conv_dim=head_dim, proj_transpose=True)
83 |
84 | assert din % head_dim == 0, \
85 | "a bug! 'din' should be divisible by head dimensions"
86 |
87 | num_partitions = din // head_dim
88 |
89 | assert num_attnt_heads % num_partitions == 0, \
90 | "Num of attention heads should be divisible by num of partitions"
91 |
92 | num_attnt_heads_per_partition = num_attnt_heads // num_partitions
93 |
94 | #--------------------------------#--------------------------------#
95 | #For all Attention heads on the embeddings
96 | #Query uses the first set of din coordinates and splits them among the first 'num_attnt_heads' attention heads
97 | #Key uses the second set of din coordinates and splits them among the first 'num_attnt_heads' attention heads
98 | #Value uses the third set of din coordinates and splits them among the first 'num_attnt_heads' attention heads
99 | #Key and Query of embeddings ignore the position dependence.
100 | #--------------------------------#--------------------------------#
101 |
102 | q_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
103 | for i in range(num_partitions):
104 | for j in range(num_attnt_heads_per_partition):
105 | q_attn_head[ :, i * num_attnt_heads_per_partition + j, i ] = 1.
106 |
107 |
108 |
109 | q_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim))
110 | for i in range(num_attnt_heads):
111 | partition = i % num_attnt_heads_per_partition
112 | q_attn[ i, :basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim)
113 |
114 |
115 | k_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
116 | for i in range(num_partitions):
117 | for j in range(num_attnt_heads_per_partition):
118 | k_attn_head[ :, i * num_attnt_heads_per_partition + j, i + num_partitions] = 1.
119 |
120 |
121 |
122 | k_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim))
123 | for i in range(num_attnt_heads):
124 | partition = i % num_attnt_heads_per_partition
125 | k_attn[ i, :basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim)
126 |
127 |
128 |
129 | v_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
130 | for i in range(num_partitions):
131 | for j in range(num_attnt_heads_per_partition):
132 | v_attn_head[ :, i * num_attnt_heads_per_partition + j, i + 2 * num_partitions] = 1.
133 |
134 |
135 | v_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim))
136 | for i in range(num_attnt_heads):
137 | partition = i % num_attnt_heads_per_partition
138 | v_attn[ i, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim)
139 |
140 |
141 | #c_attn_init, c_attn_bias = torch.cat([query, key, value], axis=0), torch.zeros(5 * config.hidden_size)
142 |
143 |
144 | #--------------------------------#--------------------------------#
145 | #For all Attention heads on the positions
146 | #Query, Key are set such that we never attend to the blank tokens!
147 | #--------------------------------#--------------------------------#
148 |
149 | #--------------------------------#--------------------------------#
150 | #The projection matrix takes the output of the attention heads, which has the required signal only in its first basemodel_head_dim coordinates
151 | #We merge them together and return them at the head of the embedding
152 | #--------------------------------#--------------------------------#
153 | c_proj_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
154 | for i in range(num_partitions):
155 | c_proj_init[:, i, i*num_attnt_heads_per_partition: (i+1)*num_attnt_heads_per_partition] = 1.
156 |
157 |
158 |
159 | self.attnt_module.initialize_weights(q_attn_init=q_attn,\
160 | q_attn_init_head=q_attn_head,\
161 | k_attn_init=k_attn,\
162 | k_attn_init_head=k_attn_head,\
163 | v_attn_init=v_attn,\
164 | v_attn_init_head=v_attn_head,\
165 | c_proj_init=c_proj_init )
166 |
167 |
168 | #Initialize Gates
169 | #Ignore the changes for the prefixes!
170 | #w, u, v, w_bias, u_bias, v_bias
171 | w = torch.zeros((1, 2*config.hidden_size))
172 | u = torch.zeros((1, 2*config.hidden_size))
173 | v = torch.zeros((1, 2*config.position_dim))
174 | w_bias = torch.zeros(2)
175 | u_bias = torch.zeros(2)
176 | v_bias = torch.zeros(2)
177 |
178 | #Input Gate is 1 on prefixes and 0 for non-prefixes
179 | v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.num_prefixes)
180 |
181 | #Change Gate is 0 on prefixes and 1 for non-prefixes
182 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = - config.gate_scale * torch.ones(config.num_prefixes)
183 | v_bias [1] += config.gate_scale
184 |
185 | self.gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias)
186 |
187 |
188 |
189 |
190 | def forward(self, hidden_states, position_states, key_weights=None, value_weights=None, icl_mask=None):
191 |
192 | linear_output = self.linear.forward(hidden_states, position_states)
193 |
194 |
195 | if not self.separate_QK:
196 | inp_hidden_states = torch.cat( [key_weights, hidden_states[:, self.config.num_prefixes:] ], axis=1)
197 | key_out = self.key_linear(inp_hidden_states, position_states)
198 | assert torch.sum(linear_output[:, self.config.num_prefixes:, self.din:]).item() < 1e-10,\
199 | "Key portion not empty!"
200 | linear_output[:, self.config.num_prefixes:, self.din:] += key_out[:, self.config.num_prefixes:, :-self.din]
201 |
202 |
203 |
204 | inp_hidden_states = torch.cat( [value_weights, hidden_states[:, self.config.num_prefixes:] ], axis=1)
205 | value_out = self.value_linear(inp_hidden_states, position_states)
206 |
207 | assert torch.sum(linear_output[:, self.config.num_prefixes:, 2*self.din:]).item() < 1e-10,\
208 | "Value portion not empty!"
209 | linear_output[:, self.config.num_prefixes:, 2*self.din:] += value_out[:, self.config.num_prefixes:, :-2*self.din]
210 |
211 | #Send a mask such that the tokens don't attend to the blank tokens
212 | normalization_mask = torch.zeros( (1, 1, len(hidden_states[0]), len(hidden_states[0]) ) )
213 | normalization_mask[:, :, :, :self.config.num_prefixes] = torch.finfo(self.attnt_module.p_attn.weight.dtype).min
214 |
215 | #icl_mask needs to be a 3D tensor of shape (batch_size, seqlen, seqlen)
216 | #icl_mask[i, j] = 1 if token i attends to token j
217 |
218 | if icl_mask is not None:
219 |
220 | bt = icl_mask.shape[0]
221 | for i in range( bt ):
222 | sq1 = icl_mask[i].shape[0]
223 | sq2 = icl_mask[i].shape[1]
224 | nb = self.config.num_prefixes
225 |
226 | normalization_mask[i, :, nb: nb+sq1, nb: nb+sq2] = torch.tril( torch.round(torch.clip(1. - icl_mask[i], 0., 1.)) ) * torch.finfo(self.attnt_module.p_attn.weight.dtype).min
227 |
228 | #print ("------Attention-------")
229 | attnt_output = self.attnt_module.forward(linear_output, position_states, normalization_mask=normalization_mask) [0]
230 |
231 | if self.memory_index != -1:
232 | #keep Qx, Kx in memory!
233 | #Keep also x separately afterwards!
234 | assert torch.sum(attnt_output[:, self.config.num_prefixes:, self.memory_index:]).item() < 1e-10,\
235 | "Memory portion not empty!"
236 |
237 | attnt_output[:, self.config.num_prefixes:, self.memory_index: self.memory_index+2*self.din] += linear_output[:, self.config.num_prefixes:, :2*self.din]
238 | attnt_output[:, self.config.num_prefixes:, self.memory_index+2*self.din: self.memory_index+3*self.din] += hidden_states[:, self.config.num_prefixes:, :self.din]
239 |
240 | gate_output = self.gates.forward(linear_output, attnt_output, position_states)
241 |
242 | return gate_output
243 |
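# --- illustration (not part of the original file) ---
# A toy sketch of the embedding layout and partition arithmetic that AttentionForward
# assumes above. All sizes below are hypothetical and chosen only to satisfy the
# asserts in __init__; the TinT module itself builds the routing via q/k/v_attn_head.
import torch

hidden_size, num_attention_heads = 64, 16       # TinT dimensions
din, num_attnt_heads = 16, 4                    # auxiliary-model dimensions

head_dim = hidden_size // num_attention_heads                       # 4
basemodel_head_dim = din // num_attnt_heads                         # 4
num_partitions = din // head_dim                                    # 4
num_attnt_heads_per_partition = num_attnt_heads // num_partitions   # 1

# layout: the first three din-sized blocks hold the auxiliary Qx, Kx and Vx
qx, kx, vx = torch.randn(din), torch.randn(din), torch.randn(din)
emb = torch.zeros(hidden_size)
emb[:din], emb[din:2*din], emb[2*din:3*din] = qx, kx, vx

# TinT head h reads the head_dim slice of partition h // num_attnt_heads_per_partition
for h in range(num_attnt_heads):
    part = h // num_attnt_heads_per_partition
    assert torch.equal(emb[part*head_dim:(part+1)*head_dim], qx[part*head_dim:(part+1)*head_dim])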
--------------------------------------------------------------------------------
/tint_main/utils/self_attention/backward.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import math
4 | import os
5 | from dataclasses import dataclass
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 | import torch.utils.checkpoint
10 | from torch import nn
11 | from torch.cuda.amp import autocast
12 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13 | from ..modules import *
14 | from ..linear import *
15 |
16 |
17 |
18 | #Implements the stop-attention gradient, where we don't compute the gradient w.r.t. the attention scores
19 | #The module contains one attention sub-module, in which the attention scores are re-computed from the stored query and key vectors and transposed before dispersing the gradients.
20 |
21 | #All arguments
22 | #self, \
23 | #config, \ #TinT config file
24 | #din, \ #input dimension of auxiliary's linear layer
25 | #num_attnt_heads, \ #number of attention heads in the Auxiliary's self-attention
26 | #use_softmax=False, \ #linear self_attention used
27 | #retain_nablay=False, \ #Retain nablay for Descent pass
28 | #memory_index=-1,\ #Start index where activations are stored in Linear Forward.
29 |
30 | class AttentionBackward(nn.Module):
31 |
32 | def __init__ (self, \
33 | config, \
34 | din, \
35 | num_attnt_heads, \
36 | use_softmax, \
37 | retain_nablay=False, \
38 | memory_index=-1,\
39 | ):
40 | super(AttentionBackward, self).__init__()
41 |
42 | assert use_softmax==False ,\
43 | "Currently I only use linear attention in this module"
44 |
45 |
46 |
47 | self.attnt_gates = Gates (config)
48 |
49 | self.retain_nablay = retain_nablay
50 | self.memory_index = memory_index
51 | self.config = config
52 | self.din = din
53 |
54 | ##### First attention module #######
55 | ########### Assumption #############
56 | #The memory part has the following format [Qx, Kx] which helps us to re-compute the attention scores
57 | #We compute \sum_j a_{j, i} \nabla y_j
58 | ########### Assumption #############
59 |
60 |
61 | head_dim = config.hidden_size // config.num_attention_heads
62 | basemodel_head_dim = din // num_attnt_heads
63 |
64 | self.attnt_module = Attention (config, peak_into_future=True, normalize=True, attnt_back=True, proj_conv2d=True, proj_conv_dim=head_dim, proj_transpose=True)
65 |
66 |
67 | assert self.memory_index <= config.hidden_size - ( 3 * self.din ), \
68 | "Not enough memory to simulate backward pass"
69 | assert self.memory_index == -1 or self.memory_index >= self.din, \
70 | "Memory is crossing current signal (and additional computation space)!"
71 |
72 |
73 | #--------------------------------#--------------------------------#
74 | #For all Attention heads on the embeddings
75 | #Query uses the first set of din coordinates in memory and splits them among the first 'num_attnt_heads' attention heads
76 | #Key uses the second set of din coordinates in memory and splits them among the first 'num_attnt_heads' attention heads
77 | #Value uses the first set of din coordinates and splits them among the first 'num_attnt_heads' attention heads
78 | #Key and Query of embeddings ignore the position dependence.
79 | #--------------------------------#--------------------------------#
80 |
81 | num_partitions = din // head_dim
82 |
83 | assert num_attnt_heads % num_partitions == 0, \
84 | "Num of attention heads should be divisible by num of partitions"
85 |
86 | num_attnt_heads_per_partition = num_attnt_heads // num_partitions
87 |
88 |
89 | q_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
90 |
91 | assert memory_index % head_dim == 0,\
92 | "Memory index should be multiple of head_dim"
93 | mem_head_start = memory_index // head_dim
94 |
95 |
96 | for i in range(num_partitions):
97 | for j in range(num_attnt_heads_per_partition):
98 | q_attn_head[ :, i * num_attnt_heads_per_partition + j, i + mem_head_start ] = 1.
99 |
100 | q_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim))
101 | for i in range(num_attnt_heads):
102 | partition = i % num_attnt_heads_per_partition
103 | q_attn[ i, :basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim)
104 |
105 |
106 |
107 |
108 | k_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
109 | for i in range(num_partitions):
110 | for j in range(num_attnt_heads_per_partition):
111 | k_attn_head[ :, i * num_attnt_heads_per_partition + j, i + num_partitions + mem_head_start] = 1.
112 |
113 |
114 | k_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim))
115 | for i in range(num_attnt_heads):
116 | partition = i % num_attnt_heads_per_partition
117 | k_attn[ i, :basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim)
118 |
119 |
120 | value_head = 0
121 |
122 |
123 | v_attn_head = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
124 | for i in range(num_partitions):
125 | for j in range(num_attnt_heads_per_partition):
126 | v_attn_head[ :, i * num_attnt_heads_per_partition + j, i + value_head ] = 1.
127 |
128 |
129 |
130 |
131 | v_attn = torch.zeros((config.num_attention_heads, head_dim, head_dim))
132 | for i in range(num_attnt_heads):
133 | partition = i % num_attnt_heads_per_partition
134 | v_attn[ i, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim, partition*basemodel_head_dim: (partition + 1)*basemodel_head_dim ] = torch.eye(basemodel_head_dim)
135 |
136 |
137 | c_proj_init = torch.zeros((head_dim, config.num_attention_heads, config.num_attention_heads))
138 |
139 | for i in range(num_partitions):
140 | c_proj_init[:, i + value_head, i*num_attnt_heads_per_partition: (i+1)*num_attnt_heads_per_partition] = 1.
141 |
142 |
143 |
144 | self.attnt_module.initialize_weights(q_attn_init=q_attn,\
145 | q_attn_init_head=q_attn_head,\
146 | k_attn_init=k_attn,\
147 | k_attn_init_head=k_attn_head,\
148 | v_attn_init=v_attn,\
149 | v_attn_init_head=v_attn_head,\
150 | c_proj_init=c_proj_init )
151 |
152 | #Initialize the first attention Gates
153 | #Ignore the changes for the prefixes!
154 | #w, u, v, w_bias, u_bias, v_bias
155 | w = torch.zeros((1, 2*config.hidden_size))
156 | u = torch.zeros((1, 2*config.hidden_size))
157 | v = torch.zeros((1, 2*config.position_dim))
158 | w_bias = torch.zeros(2)
159 | u_bias = torch.zeros(2)
160 | v_bias = torch.zeros(2)
161 |
162 | #if self.retain_nablay:
163 | #Input Gate is 1
164 | # v_bias [0] += config.gate_scale
165 | #else:
166 | #Input Gate is 1 on prefixes and 0 for non-prefixes
167 | v [0, config.seq_length:config.position_dim] = config.gate_scale * torch.ones(config.num_prefixes)
168 |
169 | #Change Gate is 0 on prefixes and 1 for non-prefixes
170 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.num_prefixes)
171 | v_bias [1] += config.gate_scale
172 |
173 | self.attnt_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias)
174 |
175 |
176 |
177 | def forward(self, hidden_states, position_states, attention_mask, icl_mask=None):
178 |
179 |
180 | #add a mask to avoid attention on blank tokens!
181 |
182 |
183 | normalization_mask = torch.zeros( (1, 1, len(hidden_states[0]), len(hidden_states[0]) ) )
184 | normalization_mask[:, :, :, :self.config.num_prefixes] = torch.finfo(self.attnt_module.p_attn.weight.dtype).min
185 |
186 | #icl_mask needs to be a 3D tensor of shape (batch_size, seqlen, seqlen)
187 | #icl_mask[i, j] = 1 if token i attends to token j
188 |
189 | if icl_mask is not None:
190 | bt = icl_mask.shape[0]
191 | for i in range( bt ):
192 | sq1 = icl_mask[i].shape[0]
193 | sq2 = icl_mask[i].shape[1]
194 | nb = self.config.num_prefixes
195 | normalization_mask[i, :, nb: nb+sq1, nb: nb+sq2] = torch.tril( torch.round(torch.clip(1. - icl_mask[i], 0., 1.)) ) * torch.finfo(self.attnt_module.p_attn.weight.dtype).min
196 |
197 | #print ("----Mask----", attention_mask)
198 | modified_attention_mask = attention_mask.detach().clone()
199 | modified_attention_mask[:, :, :, :self.config.num_prefixes] = 0.
200 |
201 | attnt_output = self.attnt_module.forward(hidden_states, \
202 | position_states, \
203 | attention_mask=modified_attention_mask, \
204 | normalization_mask=normalization_mask\
205 | ) [0]
206 |
207 |
208 | end_dim = self.memory_index + 3*self.din
209 |
210 | attnt_output[:, self.config.num_prefixes:, self.memory_index: end_dim] += hidden_states[:, self.config.num_prefixes:, self.memory_index: end_dim]
211 |
212 |
213 | gate_output = self.attnt_gates.forward(hidden_states, \
214 | attnt_output, \
215 | position_states\
216 | )
217 |
218 | return gate_output
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 | #Implements descent w.r.t. the value matrix
227 | #The module simply calls LinearDescent module on the current embeddings
228 | class AttentionDescent(nn.Module):
229 | def __init__ (self, config, din, num_attnt_heads, use_softmax, memory_index=-1, debug_zero=False, retain_nablay=False):
230 | super(AttentionDescent, self).__init__()
231 | self.linear = LinearDescent(config, din=din, dout=din, use_softmax=use_softmax, memory_index=memory_index+2*din, debug_zero=debug_zero)
232 |
233 |
234 | def forward(self, hidden_states, position_states, attention_mask, icl_mask=None):
235 | return self.linear.forward(hidden_states, position_states, attention_mask)
236 |
237 |
238 |
239 | #Combines the Backward and Descent modules, since the Descent module uses the gradient from the Backward pass.
240 | class AttentionBackward_Descent(nn.Module):
241 | def __init__ (self, config, din, num_attnt_heads, use_softmax, memory_index=-1, debug_zero=False, projection_matrix=None, retain_nablay=False):
242 | super(AttentionBackward_Descent, self).__init__()
243 |
244 | self.config = config
245 | self.memory_index = memory_index
246 | self.din = din
247 | self.attention_back = AttentionBackward(config, \
248 | din=din, \
249 | num_attnt_heads=num_attnt_heads, \
250 | memory_index=memory_index, \
251 | use_softmax=use_softmax, \
252 | retain_nablay=retain_nablay,\
253 | )
254 |
255 |
256 | self.linearback_descent = Linear_Descent_Backward(config, \
257 | din=din, \
258 | dout=din, \
259 | use_softmax=use_softmax, \
260 | memory_index=memory_index+2*din, \
261 | debug_zero=debug_zero, \
262 | projection_matrix=projection_matrix, \
263 | retain_nablay=retain_nablay, \
264 | )
265 |
266 |
267 | def forward(self, hidden_states, position_states, attention_mask, icl_mask=None):
268 | if self.config.backprop_through_attention:
269 | attention_backout = self.attention_back(hidden_states, position_states, attention_mask, icl_mask=icl_mask)
270 | else:
271 | attention_backout = hidden_states
272 |
273 |
274 | attention_descentout = self.linearback_descent(attention_backout, position_states, attention_mask)
275 |
276 |
277 | return attention_descentout
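
# --- illustration (not part of the original file) ---
# A minimal numerical check of the rule AttentionBackward simulates: with attention
# probabilities a_{j,i} (so y_j = sum_i a_{j,i} v_i), the gradient dispersed to
# position i is sum_j a_{j,i} * nabla y_j, i.e. the transposed scores applied to the
# upstream gradients. Random tensors only; this is not the TinT module.
import torch

T, d = 5, 3
A = torch.softmax(torch.randn(T, T), dim=-1)   # attention probabilities a_{j,i}
V = torch.randn(T, d, requires_grad=True)      # value vectors v_i

Y = A @ V                                      # y_j = sum_i a_{j,i} v_i
grad_Y = torch.randn(T, d)                     # upstream gradients nabla y_j
Y.backward(grad_Y)

assert torch.allclose(V.grad, A.transpose(0, 1) @ grad_Y, atol=1e-6)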
--------------------------------------------------------------------------------
/icl_eval/models/modeling_opt.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ PyTorch OPT model."""
16 | import random
17 | from typing import List, Optional, Tuple, Union
18 |
19 | import torch
20 | import torch.utils.checkpoint
21 | from torch import nn
22 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23 |
24 | from transformers.activations import ACT2FN
25 | from transformers.modeling_outputs import (
26 | BaseModelOutputWithPast,
27 | CausalLMOutputWithPast,
28 | QuestionAnsweringModelOutput,
29 | SequenceClassifierOutputWithPast,
30 | )
31 | from transformers.modeling_utils import PreTrainedModel
32 | from transformers.utils import (
33 | add_code_sample_docstrings,
34 | add_start_docstrings,
35 | add_start_docstrings_to_model_forward,
36 | logging,
37 | replace_return_docstrings,
38 | )
39 | from transformers.models.opt.configuration_opt import OPTConfig
40 | from transformers.models.opt.modeling_opt import OPTPreTrainedModel, OPTDecoder, OPTModel, OPTForCausalLM
41 |
42 | logger = logging.get_logger(__name__)
43 |
44 | class LocalOPTDecoder(OPTDecoder):
45 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
46 | def forward(
47 | self,
48 | input_ids: torch.LongTensor = None,
49 | attention_mask: Optional[torch.Tensor] = None,
50 | head_mask: Optional[torch.Tensor] = None,
51 | past_key_values: Optional[List[torch.FloatTensor]] = None,
52 | inputs_embeds: Optional[torch.FloatTensor] = None,
53 | use_cache: Optional[bool] = None,
54 | output_attentions: Optional[bool] = None,
55 | output_hidden_states: Optional[bool] = None,
56 | return_dict: Optional[bool] = None,
57 | twod_attention_mask: Optional[torch.Tensor] = None
58 | ) -> Union[Tuple, BaseModelOutputWithPast]:
59 |
60 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
61 | output_hidden_states = (
62 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
63 | )
64 | use_cache = use_cache if use_cache is not None else self.config.use_cache
65 |
66 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
67 |
68 | # retrieve input_ids and inputs_embeds
69 | if input_ids is not None and inputs_embeds is not None:
70 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
71 | elif input_ids is not None:
72 | input_shape = input_ids.size()
73 | input_ids = input_ids.view(-1, input_shape[-1])
74 | elif inputs_embeds is not None:
75 | input_shape = inputs_embeds.size()[:-1]
76 | else:
77 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
78 |
79 | if inputs_embeds is None:
80 | inputs_embeds = self.embed_tokens(input_ids)
81 |
82 | batch_size, seq_length = input_shape
83 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
84 | # required mask seq length can be calculated via length of past
85 | mask_seq_length = past_key_values_length + seq_length
86 |
87 | # embed positions
88 | if attention_mask is None:
89 | attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
90 | causal_attention_mask = self._prepare_decoder_attention_mask(
91 | attention_mask, input_shape, inputs_embeds, past_key_values_length
92 | )
93 | if twod_attention_mask is not None:
94 | causal_attention_mask += twod_attention_mask
95 | pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
96 |
97 | if self.project_in is not None:
98 | inputs_embeds = self.project_in(inputs_embeds)
99 |
100 | hidden_states = inputs_embeds + pos_embeds
101 |
102 | if self.gradient_checkpointing and self.training:
103 | if use_cache:
104 | logger.warning_once(
105 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
106 | )
107 | use_cache = False
108 |
109 | # decoder layers
110 | all_hidden_states = () if output_hidden_states else None
111 | all_self_attns = () if output_attentions else None
112 | next_decoder_cache = () if use_cache else None
113 |
114 | # check if head_mask has a correct number of layers specified if desired
115 | for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
116 | if attn_mask is not None:
117 | if attn_mask.size()[0] != (len(self.layers)):
118 | raise ValueError(
119 | f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
120 | f" {head_mask.size()[0]}."
121 | )
122 |
123 | for idx, decoder_layer in enumerate(self.layers):
124 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
125 | if output_hidden_states:
126 | all_hidden_states += (hidden_states,)
127 |
128 | dropout_probability = random.uniform(0, 1)
129 | if self.training and (dropout_probability < self.layerdrop):
130 | continue
131 |
132 | past_key_value = past_key_values[idx] if past_key_values is not None else None
133 |
134 | if self.gradient_checkpointing and self.training:
135 |
136 | def create_custom_forward(module):
137 | def custom_forward(*inputs):
138 | # None for past_key_value
139 | return module(*inputs, output_attentions, None)
140 |
141 | return custom_forward
142 |
143 | layer_outputs = torch.utils.checkpoint.checkpoint(
144 | create_custom_forward(decoder_layer),
145 | hidden_states,
146 | causal_attention_mask,
147 | head_mask[idx] if head_mask is not None else None,
148 | None,
149 | )
150 | else:
151 | layer_outputs = decoder_layer(
152 | hidden_states,
153 | attention_mask=causal_attention_mask,
154 | layer_head_mask=(head_mask[idx] if head_mask is not None else None),
155 | past_key_value=past_key_value,
156 | output_attentions=output_attentions,
157 | use_cache=use_cache,
158 | )
159 |
160 | hidden_states = layer_outputs[0]
161 |
162 | if use_cache:
163 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
164 |
165 | if output_attentions:
166 | all_self_attns += (layer_outputs[1],)
167 |
168 | if self.final_layer_norm is not None:
169 | hidden_states = self.final_layer_norm(hidden_states)
170 |
171 | if self.project_out is not None:
172 | hidden_states = self.project_out(hidden_states)
173 |
174 | # add hidden states from the last decoder layer
175 | if output_hidden_states:
176 | all_hidden_states += (hidden_states,)
177 |
178 | next_cache = next_decoder_cache if use_cache else None
179 | if not return_dict:
180 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
181 | return BaseModelOutputWithPast(
182 | last_hidden_state=hidden_states,
183 | past_key_values=next_cache,
184 | hidden_states=all_hidden_states,
185 | attentions=all_self_attns,
186 | )
187 |
188 |
189 | class LocalOPTModel(OPTModel):
190 | def __init__(self, config: OPTConfig):
191 | super().__init__(config)
192 | self.decoder = LocalOPTDecoder(config)
193 | # Initialize weights and apply final processing
194 | self.post_init()
195 |
196 | def forward(
197 | self,
198 | input_ids: torch.LongTensor = None,
199 | attention_mask: Optional[torch.Tensor] = None,
200 | head_mask: Optional[torch.Tensor] = None,
201 | past_key_values: Optional[List[torch.FloatTensor]] = None,
202 | inputs_embeds: Optional[torch.FloatTensor] = None,
203 | use_cache: Optional[bool] = None,
204 | output_attentions: Optional[bool] = None,
205 | output_hidden_states: Optional[bool] = None,
206 | return_dict: Optional[bool] = None,
207 | twod_attention_mask: Optional[torch.Tensor] = None
208 | ) -> Union[Tuple, BaseModelOutputWithPast]:
209 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
210 | output_hidden_states = (
211 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
212 | )
213 | use_cache = use_cache if use_cache is not None else self.config.use_cache
214 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
215 |
216 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
217 | decoder_outputs = self.decoder(
218 | input_ids=input_ids,
219 | attention_mask=attention_mask,
220 | head_mask=head_mask,
221 | past_key_values=past_key_values,
222 | inputs_embeds=inputs_embeds,
223 | use_cache=use_cache,
224 | output_attentions=output_attentions,
225 | output_hidden_states=output_hidden_states,
226 | return_dict=return_dict,
227 | twod_attention_mask=twod_attention_mask
228 | )
229 |
230 | if not return_dict:
231 | return decoder_outputs
232 |
233 | return BaseModelOutputWithPast(
234 | last_hidden_state=decoder_outputs.last_hidden_state,
235 | past_key_values=decoder_outputs.past_key_values,
236 | hidden_states=decoder_outputs.hidden_states,
237 | attentions=decoder_outputs.attentions,
238 | )
239 |
240 | class LocalOPTForCausalLM(OPTForCausalLM):
241 | def __init__(self, config):
242 | super().__init__(config)
243 | self.model = LocalOPTModel(config)
244 |
245 | # the lm_head weight is automatically tied to the embed tokens weight
246 | self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
247 |
248 | # Initialize weights and apply final processing
249 | self.post_init()
250 |
251 | def forward(
252 | self,
253 | input_ids: torch.LongTensor = None,
254 | attention_mask: Optional[torch.Tensor] = None,
255 | head_mask: Optional[torch.Tensor] = None,
256 | past_key_values: Optional[List[torch.FloatTensor]] = None,
257 | inputs_embeds: Optional[torch.FloatTensor] = None,
258 | labels: Optional[torch.LongTensor] = None,
259 | use_cache: Optional[bool] = None,
260 | output_attentions: Optional[bool] = None,
261 | output_hidden_states: Optional[bool] = None,
262 | return_dict: Optional[bool] = None,
263 | twod_attention_mask: Optional[torch.Tensor] = None
264 | ) -> Union[Tuple, CausalLMOutputWithPast]:
265 |
266 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
267 | output_hidden_states = (
268 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
269 | )
270 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
271 |
272 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
273 |
274 | outputs = self.model.decoder(
275 | input_ids=input_ids,
276 | attention_mask=attention_mask,
277 | head_mask=head_mask,
278 | past_key_values=past_key_values,
279 | inputs_embeds=inputs_embeds,
280 | use_cache=use_cache,
281 | output_attentions=output_attentions,
282 | output_hidden_states=output_hidden_states,
283 | return_dict=return_dict,
284 | twod_attention_mask=twod_attention_mask
285 | )
286 |
287 | logits = self.lm_head(outputs[0]).contiguous()
288 |
289 | loss = None
290 | if labels is not None:
291 | # Shift so that tokens < n predict n
292 | shift_logits = logits[..., :-1, :].contiguous()
293 | shift_labels = labels[..., 1:].contiguous()
294 | # Flatten the tokens
295 | loss_fct = CrossEntropyLoss()
296 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
297 |
298 | if not return_dict:
299 | output = (logits,) + outputs[1:]
300 | return (loss,) + output if loss is not None else output
301 |
302 | return CausalLMOutputWithPast(
303 | loss=loss,
304 | logits=logits,
305 | past_key_values=outputs.past_key_values,
306 | hidden_states=outputs.hidden_states,
307 | attentions=outputs.attentions,
308 | )
309 |
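# --- illustration (not part of the original file) ---
# A hedged usage sketch for the `twod_attention_mask` hook added above: the mask is
# summed onto the additive causal mask, so it should broadcast to
# (batch, 1, tgt_len, src_len) and use large negative values to block positions.
# The checkpoint name and import path are assumptions made for this example.
import torch
from transformers import AutoTokenizer
from modeling_opt import LocalOPTForCausalLM   # this file (icl_eval/models/modeling_opt.py)

tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = LocalOPTForCausalLM.from_pretrained("facebook/opt-125m").eval()

enc = tok("in-context learning as implicit finetuning", return_tensors="pt")
L = enc.input_ids.shape[1]

twod = torch.zeros(1, 1, L, L)                        # 0 keeps a position
twod[:, :, -1, 1] = torch.finfo(torch.float32).min    # block the last token from attending to token 1

with torch.no_grad():
    out = model(**enc, twod_attention_mask=twod)
print(out.logits.shape)                               # (1, L, vocab_size)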
--------------------------------------------------------------------------------
/tests/perplexity_dynamiceval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from Conversion import *
4 | from transformers import AutoTokenizer, AutoConfig, GPT2LMHeadModel, OPTForCausalLM
5 | from datasets import load_dataset
6 | import copy
7 | from transformers import HfArgumentParser
8 | from filelock import Timeout, FileLock
9 | from data_utils import *
10 | from dataclasses import dataclass, field
11 | from typing import Callable, Dict, Optional, Union, List
12 | import numpy as np
13 | @dataclass
14 | class ModelArguments:
15 | """
16 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
17 | """
18 | model_name_or_path: str = field(
19 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
20 | )
21 | config_name: Optional[str] = field(
22 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
23 | )
24 | cache_dir: Optional[str] = field(
25 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
26 | )
27 |
28 | device: Optional[str] = field(
29 | default="cuda", metadata={"help": "cuda/cpu"}
30 | )
31 |
32 |
33 | final_layers_to_train: Optional[int] = field(
34 | default=-1, metadata={"help": "Number of layers to train"}
35 | )
36 |
37 | gradient_steps : Optional[int] = field(
38 | default=1, metadata={"help": "Number of gradient steps"}
39 | )
40 |
41 | learning_rate: Optional[float] = field(
42 | default=1e-04, metadata={"help": "Learning rate for projection!"}
43 | )
44 |
45 | batch_size: Optional[int] = field(
46 | default=1, metadata={"help": "Batch size for training!"}
47 | )
48 |
49 | train_fraction: Optional[float] = field(
50 | default=0.5, metadata={"help": "Fraction of sentence for training!"}
51 | )
52 |
53 | dynamic_chunks: Optional[int] = field(
54 | default=1, metadata={"help": "Number of dynamic chunks (before the final test per sequence)!"}
55 | )
56 |
57 |
58 |
59 |
60 | @dataclass
61 | class DynamicArguments:
62 |
63 | data_cache_dir: Optional[str] = field(
64 | default='/scratch/gpfs/ap34/icl-as-ft/Dynamic_initialization/data',
65 | metadata={
66 | 'help': 'where to store downloaded model, datasets, and tokenizer'
67 | }
68 | )
69 |
70 | log_dir: Optional[str] = field(
71 | default='/scratch/gpfs/smalladi/icl_as_ft/logs'
72 | )
73 |
74 | dataset: Optional[str] = field(
75 | default='wikitext-103',
76 | metadata={
77 | 'help': 'dataset to use, should be in HF datasets repo'
78 | }
79 | )
80 |
81 | use_eval_set: bool = field(
82 | default=True,
83 | metadata={'help': 'when on, uses the eval dataset only for ppl measurement.'}
84 | )
85 |
86 | use_test_set: bool = field(
87 | default=False,
88 | metadata={'help': 'when on, uses the test dataset only for ppl measurement.'}
89 | )
90 |
91 | ### finetuning strategies ###
92 |
93 | chunk_data: bool = field(
94 | default=True,
95 | metadata={'help': 'when on, chunks the data into block_size chunks'}
96 | )
97 |
98 | baseline: bool = field(
99 | default=False
100 | )
101 |
102 | num_subsets: int = field(
103 | default=16
104 | )
105 | data_chunk_index: int = field(
106 | default=0
107 | )
108 |
109 | block_size: Optional[int] = field(
110 | default=1024,
111 | metadata={'help': 'size of chunks to break corpus into'}
112 | )
113 |
114 | num_workers: Optional[int] = field(
115 | default=4,
116 | metadata={'help': 'number of workers'}
117 | )
118 |
119 | data_subset: Optional[int] = field(
120 | default=-1, metadata={"help": "Number of training examples to use in the subset!"}
121 | )
122 |
123 |
124 |
125 |
126 | parser = HfArgumentParser((ModelArguments, DynamicArguments))
127 | model_args, data_args = parser.parse_args_into_dataclasses()
128 |
129 | learning_rate = model_args.learning_rate
130 | gradient_steps = model_args.gradient_steps
131 | final_layers_to_train = model_args.final_layers_to_train
132 | train_fraction = model_args.train_fraction
133 | dynamic_chunks = model_args.dynamic_chunks
134 |
135 | device = model_args.device
136 | model_config = AutoConfig.from_pretrained(
137 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
138 | cache_dir=model_args.cache_dir
139 | )
140 |
141 |
142 | if 'gpt' in model_args.model_name_or_path:
143 | model_fn = GPT2LMHeadModel
144 | elif 'opt' in model_args.model_name_or_path:
145 | model_fn = OPTForCausalLM
146 | else:
147 | raise NotImplementedError
148 |
149 | simulated_gpt2 = model_fn.from_pretrained(
150 | model_args.model_name_or_path,
151 | config=model_config,
152 | cache_dir=model_args.cache_dir,
153 | )
154 |
155 | simulated_gpt2.to(device)
156 | simulated_gpt2.eval()
157 |
158 |
159 | #dataset = load_dataset("wikitext", "wikitext-2-raw-v1", cache_dir='data', download_mode='reuse_cache_if_exists')
160 | #test_data = dataset['test']
161 | #test_data = dataset['test']
162 | #valid_data = dataset['validation']
163 |
164 | batch_size=model_args.batch_size
165 |
166 | assert batch_size == 1, \
167 | "Assume batch size as 1 for proper perplexity computation (lazy to do for multi input batch)"
168 |
169 | # Download vocabulary from huggingface.co and cache.
170 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir="../..")
171 | if 'gpt' in model_args.model_name_or_path:
172 | tokenizer.pad_token = tokenizer.eos_token
173 | pad_token_id=tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
174 |
175 | elif 'opt' in model_args.model_name_or_path:
176 | tokenizer.bos_token_id = 0
177 |
178 | dataset = preprocess_dataset(data_args, tokenizer)
179 | #_, simulated_gpt2, model_config, config = Construct_NASgpt()
180 |
181 |
182 | simulated_gpt2.eval()
183 | device=next(simulated_gpt2.parameters()).device
184 |
185 |
186 |
187 | num_valid_batches = len(dataset) // batch_size
188 |
189 | if data_args.data_subset != -1:
190 | num_valid_batches = min(data_args.data_subset, num_valid_batches)
191 |
192 | avg_model_perplexity = 0.
193 | avg_eval_perplexity = 0.
194 | total_words = 0.
195 |
196 |
197 | avg_model_test_perplexity = 0.
198 | avg_eval_test_perplexity = 0.
199 | total_test_words = 0.
200 |
201 |
202 |
203 | for batch_id in tqdm( range( num_valid_batches ) ):
204 | if data_args.dataset == 'c4' and batch_id == 100: break
205 | model_copy = copy.deepcopy(simulated_gpt2)
206 | model_copy.eval()
207 |
208 | trainable_parameters = []
209 | trainable_name = []
210 | if final_layers_to_train == -1:
211 | final_layers_to_train = 12
212 |
213 | #for i in range( 12 - final_layers_to_train, 12 ):
214 | for n, p in model_copy.named_parameters():
215 | if 'ln_' not in n and '.h.' in n:
216 | layer_num = int(n.split('.h.')[-1].split('.')[0])
217 | if layer_num >= 12 - final_layers_to_train:
218 | trainable_parameters += [ p ]
219 | trainable_name += [n]
220 |
221 |
222 | optimizer = torch.optim.SGD(model_copy.parameters(), lr=learning_rate)
223 | model_copy.zero_grad()
224 |
225 |
226 |
227 | data = dataset [ batch_id * batch_size : (batch_id + 1) * batch_size ]
228 | batch_sentences = torch.tensor( data ['input_ids'] )
229 | attention_mask = torch.tensor( data ['attention_mask'] )
230 | labels = torch.tensor( data ['labels'] )
231 |
232 |
233 |
234 |
235 |
236 |
237 | if len(batch_sentences.shape) == 1:
238 | batch_sentences = batch_sentences.view((1, -1))
239 | attention_mask = attention_mask.view((1, -1))
240 | labels = labels.view((1, -1))
241 |
242 | train_subchunk_fraction = train_fraction / dynamic_chunks
243 |
244 | #initialize the dynamic loss
245 | final_loss = 0.
246 | og_loss = 0.
247 | total_terms = 0.
248 | loss_fct = torch.nn.CrossEntropyLoss(reduction='sum')
249 |
250 |
251 | og_test_loss = 0.
252 | final_test_loss = 0.
253 | #original model's loss
254 | #with torch.no_grad():
255 | # train_seq_loss = model_copy(batch_sentences.cuda(), \
256 | # output_hidden_states=False, \
257 | # labels=labels.long().to(device))[0].item()
258 |
259 | #dynamic training
260 | for dynamic_chunk_id in range(dynamic_chunks):
261 |
262 | batch_seq_lengths = torch.sum(attention_mask.int(), dim=-1)
263 | mask = torch.zeros_like(attention_mask)
264 |
265 | for i in range(len(batch_seq_lengths)):
266 | len_chunk = int(batch_seq_lengths[i] * train_subchunk_fraction)
267 | actual_fraction = len_chunk / batch_seq_lengths[i]
268 | #print (actual_fraction, train_subchunk_fraction)
269 | mask[i, dynamic_chunk_id * len_chunk: dynamic_chunk_id * len_chunk + len_chunk] = 1.
270 | bidirection_mask = mask.float()
271 |
272 | target = labels.detach().clone()
273 | target[ torch.where(bidirection_mask == 0.) ] = -100
274 |
275 |
276 | #print (target)
277 | #with torch.no_grad():
278 | input_ids = batch_sentences.to(device)
279 | target = target.to(device)
280 |
281 | #first a simple evaluation on the current model
282 | with torch.no_grad():
283 | output = model_copy(input_ids, \
284 | attention_mask=attention_mask.to(device), \
285 | output_hidden_states=False,
286 | )
287 |
288 | print (output[0].shape, input_ids.shape)
289 | final_loss += loss_fct( output[0][:, :-1].view((-1, model_config.vocab_size)), \
290 | target[:, 1:].long().view((-1,)) \
291 | ).item()
292 |
293 | gpt_output = simulated_gpt2(input_ids, \
294 | attention_mask=attention_mask.to(device), \
295 | output_hidden_states=False, \
296 | )
297 | og_loss += loss_fct( gpt_output[0][:, :-1].view((-1, model_config.vocab_size)), \
298 | target[:, 1:].long().view((-1,)) \
299 | ).item()
300 | total_terms += bidirection_mask.sum()
301 |
302 |
303 | for _ in range(gradient_steps):
304 | simulated_output = model_copy(input_ids, \
305 | attention_mask=attention_mask.to(device), \
306 | output_hidden_states=True, \
307 | labels=target.long()\
308 | )
309 | small_model_loss = simulated_output[0]
310 | #print (small_model_loss.item())
311 | small_model_loss.backward()
312 | optimizer.step()
313 | optimizer.zero_grad()
314 |
315 | for n, p in model_copy.named_parameters():
316 | for n_, p_ in simulated_gpt2.named_parameters():
317 | if n == n_ and n in trainable_name:
318 | print ( n, torch.max (torch.absolute(p - p_)) )
319 |
320 |
321 |
322 |
323 | #print ([p for n, p in model_copy.named_parameters() if 'ln_f' in n])
324 | #test on the remaining chunk
325 |
326 | batch_seq_lengths = torch.sum(attention_mask.int(), dim=-1)
327 | mask = torch.zeros_like(attention_mask)
328 |
329 |
330 | for i in range(len(batch_seq_lengths)):
331 | len_chunk = dynamic_chunks * int(batch_seq_lengths[i] * train_subchunk_fraction)
332 | test_fraction = 1. - len_chunk / (1. * batch_seq_lengths[i])
333 | mask[i, len_chunk:] = 1.
334 | bidirection_mask = mask.float()
335 |
336 |
337 | with torch.no_grad():
338 | target = labels.detach().clone()
339 | target[ torch.where(bidirection_mask == 0.) ] = -100
340 | target = target.to(device)
341 |
342 |
343 |
344 | simulated_output = model_copy(input_ids, \
345 | attention_mask=attention_mask.to(device), \
346 | output_hidden_states=False,
347 | )
348 | loss = loss_fct( simulated_output[0][:, :-1].view((-1, model_config.vocab_size)), \
349 | target[:, 1:].long().view((-1,)) \
350 | ).item()
351 | final_loss += loss
352 | avg_eval_test_perplexity += loss
353 |
354 | gpt_output = simulated_gpt2(input_ids, \
355 | attention_mask=attention_mask.to(device), \
356 | output_hidden_states=False,
357 | )
358 |
359 | loss = loss_fct( gpt_output[0][:, :-1].view((-1, model_config.vocab_size)), \
360 | target[:, 1:].long().view((-1,)) \
361 | ).item()
362 | og_loss += loss
363 | avg_model_test_perplexity += loss
364 |
365 | total_terms += bidirection_mask.sum()
366 | total_test_words += bidirection_mask.sum()
367 |
368 | del (model_copy)
369 | avg_model_perplexity += og_loss
370 | avg_eval_perplexity += final_loss
371 | total_words += total_terms
372 |
373 |
374 | final_result = {}
375 | final_result[ 'Validation Dynamic eval acc' ] = np.exp(avg_eval_perplexity / total_words)
376 | final_result[ 'Validation Model acc' ] = np.exp(avg_model_perplexity / total_words)
377 |
378 | final_result[ 'Validation Dynamic eval acc (on test)' ] = np.exp(avg_eval_test_perplexity / total_test_words)
379 | final_result[ 'Validation Model acc (on test)' ] = np.exp(avg_model_test_perplexity / total_test_words)
380 |
381 |
382 | with FileLock('log_exp.lock'):
383 | with open('log_exp', 'a') as f:
384 | final_result.update(vars(model_args))
385 | final_result.update(vars(data_args))
386 | f.write(str(final_result) + '\n')
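
# --- illustration (not part of the original file) ---
# A toy view (hypothetical lengths) of the chunking used in the loop above:
# `train_fraction` of each sequence is split into `dynamic_chunks` training sub-chunks,
# and the remaining tail is the held-out portion scored after the gradient updates.
import torch

seq_len, train_fraction, dynamic_chunks = 12, 0.5, 2
attention_mask = torch.ones(1, seq_len)

train_subchunk_fraction = train_fraction / dynamic_chunks
len_chunk = int(seq_len * train_subchunk_fraction)       # 3 tokens per sub-chunk

for chunk_id in range(dynamic_chunks):
    mask = torch.zeros_like(attention_mask)
    mask[0, chunk_id * len_chunk: (chunk_id + 1) * len_chunk] = 1.
    print("train sub-chunk", chunk_id, mask.int().tolist())

test_mask = torch.zeros_like(attention_mask)
test_mask[0, dynamic_chunks * len_chunk:] = 1.           # evaluated, never trained on
print("held-out", test_mask.int().tolist())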
--------------------------------------------------------------------------------
/tint_main/utils/linear/forward.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 |
5 | import math
6 | import os
7 | from dataclasses import dataclass
8 | from typing import Optional, Tuple, Union
9 |
10 | import torch
11 | import torch.utils.checkpoint
12 | from torch import nn
13 | from torch.cuda.amp import autocast
14 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
15 | from ..modules import *
16 | import numpy as np
17 |
18 |
19 |
20 | #------------------------------------------------------------------#
21 | #LinearForward module computes Wx_i at every position i
22 | #Important arguments: input dimension (din), output dimension (dout)
23 | #output: TinT LinearForward module, primarily containing a linear self_attention layer. (See figure 2 in the paper)
24 | #We assume that the rows of W have been stacked onto the prefix tokens,
25 | #before calling this module.
26 |
27 | #We also allow a projection matrix on the linear operations; however, we do not
28 | #use it in the current version.
29 | #------------------------------------------------------------------#
30 |
31 |
32 | #All arguments
33 | #self, \
34 | #config, \ #TinT config file
35 | #din, \ #input dimension
36 | #dout, \ #output dimension
37 | #use_softmax=False, \ #use_softmax=False, implying we use linear attention head
38 | #projection_matrix=None, \ #projection_matrix to project the linear operation (not used in the current code)
39 | #shift_top=0, \ #shifts the output to a shifted index if necessary
40 | #memory_index=-1, \ #memory_index to store activations for the backward and descent pass!
41 |
42 | class LinearForward(nn.Module):
43 |
44 | def __init__(self, \
45 | config, \
46 | din, \
47 | dout, \
48 | use_softmax=False, \
49 | projection_matrix=None, \
50 | shift_top=0, \
51 | memory_index=-1, \
52 | ):
53 | super(LinearForward, self).__init__()
54 |
55 | self.attnt_module = None #initialized later
56 | #We use gates to differentiate the operations on prefix embeddings and non-prefix embeddings.
57 | self.gates = Gates (config)
58 | self.din = din
59 | self.config = config
60 | self.projection_matrix = projection_matrix
61 | self.memory_index = memory_index
62 |
63 | #initialized later
64 | self.permutation_conv = None
65 | self.bias_add_conv = None
66 | self.projection_layer = None
67 | self.proj_conv2d = (config.hidden_size % dout == 0)
68 |
69 | assert use_softmax == False, \
70 | "Currently only works without softmax!"
71 |
72 | head_dim = config.num_prefixes
73 | assert config.hidden_size % head_dim == 0,\
74 | "Dimension should perfectly distribute over the prefixes"
75 |
76 | num_attention_heads = config.hidden_size // head_dim
77 |
78 |
79 |
80 | assert config.hidden_size >= din * (dout // config.num_prefixes), \
81 | "Total embedding size must be greater than the dimension necessary to store weights in the prefixes"
82 |
83 | assert dout % config.num_prefixes == 0, \
84 | "I assume uniform distribution of the weights over the prefixes"
85 |
86 | assert din % head_dim == 0,\
87 | "Currently this is a bug! I assume that the input dimension is easily divisible across the heads we want to distribute to"
88 | assert dout % head_dim == 0,\
89 | "Currently this is a bug! I assume that the output dimension is easily divisible across the heads we want to distribute to"
90 |
91 | num_wts_per_blank = dout // config.num_prefixes
92 | #initialize attention module
93 | self.attnt_module = Attention (config, \
94 | num_attention_heads=num_attention_heads, \
95 | normalize=use_softmax, \
96 | proj_conv_dim=config.num_prefixes, \
97 | proj_transpose=True, \
98 | proj_conv2d=self.proj_conv2d\
99 | )
100 |
101 | attnt_head_per_wt = din // head_dim
102 | useful_attnt_heads=attnt_head_per_wt * (dout // config.num_prefixes)
103 | extra_heads=dout // head_dim
104 |
105 | #print (config.num_attention_heads, extra_heads, useful_attnt_heads, attnt_head_per_wt)
106 | assert num_attention_heads >= extra_heads + useful_attnt_heads, \
107 | "Number of attention heads should be atleast the number of weights + biases present in each blank"
108 |
109 | assert config.num_prefixes <= head_dim ,\
110 | "Currently I assume the head dimension is atleast the number of prefixes in the original model"
111 |
112 | #--------------------------------#--------------------------------#
113 | #For all Attention heads on the embeddings
114 | #Query repeats the first din dimensions dout times, so that we can split them among the different attention heads
115 | #Key is Identity
116 | #Value is all zeros
117 | #Key and Query of embeddings ignore the position dependence.
118 |
119 | #Final attention head simply copies the bias present in the first blank
120 | #--------------------------------#--------------------------------#
121 | key_attn = torch.zeros((num_attention_heads, head_dim, head_dim))
122 |
123 | din_partition = din // attnt_head_per_wt
124 |
125 | key_attn[:useful_attnt_heads] = torch.eye(head_dim)
126 |
127 |
128 |
129 | query_attn_head = torch.zeros((head_dim, num_attention_heads, num_attention_heads))
130 | for i in range (dout // config.num_prefixes):
131 | query_attn_head[:, i*attnt_head_per_wt: (i+1)*attnt_head_per_wt, :attnt_head_per_wt] = torch.eye(attnt_head_per_wt)
132 |
133 |
134 | value_attn = torch.zeros((num_attention_heads, head_dim, head_dim))
135 | value_attn[useful_attnt_heads: useful_attnt_heads+extra_heads] = torch.eye(head_dim)
136 |
137 |
138 | #--------------------------------#--------------------------------#
139 | #For all Attention heads on the positions
140 | #Query is Identity (on the component corresponding to one-hot encodings
141 | #of the input sequence to the smaller model)
142 | #Key is Identity (on the component corresponding to one-hot encodings
143 | #of the input sequence to the smaller model) + all-ones on the blank identifiers
144 | #Value moves the blank identifiers to the forefront
145 | #Key and Query ignore dependence on the signal.
146 |
147 | #Final attention head simply copies the bias present in the first blank
148 | #--------------------------------#--------------------------------#
149 |
150 |
151 | query = torch.zeros((head_dim, config.position_dim))
152 | query[0, :config.seq_length] = 1.
153 |
154 | key = torch.zeros((head_dim, config.position_dim))
155 | key[:config.num_prefixes, config.seq_length: config.position_dim] = torch.eye(config.num_prefixes)
156 |
157 | value = torch.zeros((head_dim, config.position_dim))
158 | value[:config.num_prefixes, config.seq_length: config.position_dim] = torch.eye(config.num_prefixes)
159 |
160 | expand_ = torch.zeros((3, 1, num_attention_heads))
161 | #for query, we use position embedding only at heads useful_attnt_heads: useful_attnt_heads+extra_heads
162 | expand_[0, 0, useful_attnt_heads: useful_attnt_heads+extra_heads] = 1.
163 | #for key, we use position embedding only at heads useful_attnt_heads: useful_attnt_heads+extra_heads
164 | expand_[1, 0, useful_attnt_heads: useful_attnt_heads+extra_heads] = 1.
165 | #for value, we use position embedding only at heads :useful_attnt_heads
166 | expand_[2, 0, :useful_attnt_heads] = 1.
167 |
168 |
169 | p_attn_init = torch.cat([query, key, value], axis=0)
170 |
171 | #--------------------------------#--------------------------------#
172 | #The projection matrix after the attention module reorders such that
173 | #the coordinates (Wx)_1, (Wx)_2, ..., (Wx)_dout appear in a sequential order.
174 | #--------------------------------#--------------------------------#
175 |
176 | if not self.proj_conv2d:
177 | c_proj_init = torch.zeros(( config.hidden_size, config.hidden_size ))
178 |
179 | for i in range(dout):
180 | num_useful_heads = dout // config.num_prefixes
181 | desd_loc = head_dim * attnt_head_per_wt * (i % num_useful_heads) + i // num_useful_heads
182 |
183 | for sub_head in range(attnt_head_per_wt):
184 | if use_softmax:
185 | c_proj_init[shift_top+i, desd_loc + sub_head * head_dim] = config.scale_embeddings
186 | else:
187 | c_proj_init[shift_top+i, desd_loc + sub_head * head_dim] = 1.
188 |
189 |
190 | for i in range(dout):
191 | desd_loc = head_dim * num_useful_heads * attnt_head_per_wt + i
192 | c_proj_init[shift_top+i, desd_loc] = 1.
193 |
194 | if projection_matrix is not None:
195 | projection_tensor = torch.zeros((config.hidden_size, config.hidden_size))
196 | projection_tensor[shift_top:shift_top+projection_matrix.shape[0], :projection_matrix.shape[1]] = torch.tensor(projection_matrix, dtype=c_proj_init.dtype)
197 | c_proj_init = projection_tensor @ c_proj_init
198 |
199 | else:
200 | assert head_dim % config.num_prefixes == 0, \
201 | "This is a bug! For simpler operation, I assume head_dim to be divisible by config.num_prefixes"
202 |
203 | num_partitions_head = head_dim // config.num_prefixes
204 | num_channels = config.hidden_size // config.num_prefixes
205 | num_wt_channels = dout // config.num_prefixes
206 | c_proj_init = torch.zeros(( config.num_prefixes, num_channels, num_channels ))
207 | for i in range(num_wt_channels):
208 | for j in range(attnt_head_per_wt * num_partitions_head):
209 | c_proj_init[:, i, i * attnt_head_per_wt * num_partitions_head + j ] = 1.
210 |
211 | c_proj_init[:, num_wt_channels + i, num_wt_channels * attnt_head_per_wt * num_partitions_head + i] = 1.
212 |
213 |
214 | num_abs_heads = config.hidden_size // dout
215 | shift_top_head = shift_top // dout
216 |
217 | #permute the final computation
218 | self.permutation_conv = Conv2D(nf=num_abs_heads, nx=dout, transpose=False)
219 | permutation_wt = torch.zeros((num_abs_heads, dout, dout))
220 | for i in range(dout):
221 | desd_loc = ( i % num_wt_channels ) * config.num_prefixes + i // num_wt_channels
222 | permutation_wt[0, i, desd_loc] = 1.
223 | permutation_wt[1] = torch.eye(dout)
224 | with torch.no_grad():
225 | self.permutation_conv.weight.copy_(permutation_wt.transpose(-1, -2))
226 |
227 |
228 | #add the bias to the permuted output at the shifted position
229 | self.bias_add_conv = Conv2D(nf=num_abs_heads, nx=dout, transpose=True, use_einsum=self.config.use_einsum)
230 | bias_add_wt = torch.zeros((dout, num_abs_heads, num_abs_heads))
231 | bias_add_wt[:, shift_top_head, 0: 2 ] = 1.
232 | with torch.no_grad():
233 | self.bias_add_conv.weight.copy_(bias_add_wt.transpose(-1, -2))
234 |
235 | if projection_matrix is not None:
236 | start_index=shift_top
237 | if projection_matrix.shape[0] >= projection_matrix.shape[1]:
238 | self.projection_layer = up_projection (config, projection_matrix, signal_index=start_index, store_index=start_index)
239 | else:
240 | self.projection_layer = down_projection (config, projection_matrix, signal_index=start_index, store_index=start_index)
241 |
242 |
243 | self.attnt_module.initialize_weights(q_attn_init_head=query_attn_head, \
244 | k_attn_init=key_attn, \
245 | v_attn_init=value_attn,
246 | p_attn_init=p_attn_init, \
247 | p_expand_init=expand_,\
248 | c_proj_init=c_proj_init, \
249 | )
250 |
251 |
252 |
253 | #Initialize Gates
254 | #Ignore the changes for the prefixes!
255 | #w, u, v, w_bias, u_bias, v_bias
256 | w = torch.zeros((1, 2*config.hidden_size))
257 | u = torch.zeros((1, 2*config.hidden_size))
258 | v = torch.zeros((1, 2*config.position_dim))
259 | w_bias = torch.zeros(2)
260 | u_bias = torch.zeros(2)
261 | v_bias = torch.zeros(2)
262 |
263 | #Input Gate is 1 on prefixes and 0 for non-prefixes
264 | v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.num_prefixes)
265 |
266 |
267 | #Change Gate is 0 on prefixes and 1 for non-prefixes
268 | v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.num_prefixes)
269 | v_bias [1] += config.gate_scale
270 |
271 | self.gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias)
272 |
273 |
274 | def forward(self, hidden_states, position_embeddings):
275 | output = self.attnt_module.forward(hidden_states=hidden_states, positions=position_embeddings, restrict_prefixes=self.config.restrict_prefixes)[0]
276 |
277 | if self.permutation_conv is not None:
278 | output = self.permutation_conv(output)
279 | output = self.bias_add_conv( output )
280 | if self.projection_layer is not None:
281 | output = self.projection_layer(output)
282 | #store the input in memory for backward pass later on
283 | if self.memory_index != -1:
284 | assert torch.sum(output[:, self.config.num_prefixes:, self.memory_index: ]).item() < 1e-10,\
285 | "Memory portion not empty!"
286 |
287 | output[:, self.config.num_prefixes:, self.memory_index: self.memory_index + self.din ] += hidden_states[:, self.config.num_prefixes:, :self.din]
288 | return self.gates.forward(hidden_states=hidden_states, \
289 | output_states=output, \
290 | position=position_embeddings\
291 | )
292 |
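# --- illustration (not part of the original file) ---
# A conceptual sketch of the idea behind LinearForward (not the module itself): when
# each prefix token stores one row of W, linear attention scores between a token x and
# the prefixes are exactly the coordinates of W x, and one-hot value vectors write them
# back in order. Sizes below are hypothetical.
import torch

din, dout = 4, 3
W = torch.randn(dout, din)        # auxiliary linear layer (bias omitted here)
x = torch.randn(din)

prefix_embeddings = W             # prefix token j stores row w_j
scores = prefix_embeddings @ x    # linear attention scores <w_j, x>

values = torch.eye(dout)          # one-hot values route <w_j, x> to coordinate j
out = scores @ values
assert torch.allclose(out, W @ x)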
--------------------------------------------------------------------------------
/icl_eval/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import contextlib
4 | from typing import Optional, Union
5 | import numpy as np
6 | from dataclasses import dataclass, is_dataclass, asdict
7 | import logging
8 | import time
9 | from torch.nn import CrossEntropyLoss
10 | import torch.nn.functional as F
11 | from transformers.modeling_outputs import CausalLMOutputWithPast
12 | import torch
13 | from transformers.utils import PaddingStrategy
14 | from transformers import PreTrainedTokenizerBase
15 | from typing import Optional, Union, List, Dict, Any
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | @dataclass
20 | class Prediction:
21 | correct_candidate: Union[int, str]
22 | predicted_candidate: Union[int, str]
23 | scores: List[float] = None
24 |
25 | @contextlib.contextmanager
26 | def count_time(name):
27 | logger.info("%s..." % name)
28 | start_time = time.time()
29 | try:
30 | yield
31 | finally:
32 | logger.info("Done with %.2fs" % (time.time() - start_time))
33 |
34 | @contextlib.contextmanager
35 | def temp_seed(seed):
36 | state = np.random.get_state()
37 | np.random.seed(seed)
38 | try:
39 | yield
40 | finally:
41 | np.random.set_state(state)
42 |
43 | def forward_wrap_with_option_len(self, input_ids=None, labels=None, option_len=None, num_options=None, return_dict=None, **kwargs):
44 | outputs = self.original_forward(input_ids=input_ids, **kwargs)
45 | if labels is None:
46 | return outputs
47 | logits = outputs.logits
48 |
49 | loss = None
50 | # shift so that tokens < n predict n
51 | shift_logits = logits[..., :-1, :].contiguous()
52 | # here we use input_ids (which should always equal labels) because sometimes labels are correct candidate IDs
53 | shift_labels = torch.clone(input_ids)[..., 1:].contiguous()
54 | shift_labels[shift_labels == self.config.pad_token_id] = -100
55 |
56 | # apply option len
57 | for _i, _len in enumerate(option_len):
58 | shift_labels[_i, :-_len] = -100
59 |
60 | # calculate the loss
61 | loss_fct = CrossEntropyLoss(ignore_index=-100)
62 | if num_options is not None:
63 | # Train as classification tasks
64 | log_probs = F.log_softmax(shift_logits, dim=-1)
65 | mask = shift_labels != -100 # option part
66 | shift_labels[~mask] = 0 # so that it doesn't mess up with indexing
67 |
68 | selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1) # (bsz x num_options, len)
69 | selected_log_probs = (selected_log_probs * mask).sum(-1) # (bsz x num_options)
70 |
71 | num_options = num_options[0]
72 | selected_log_probs = selected_log_probs.view(-1, num_options) # (bsz, num_options)
73 | labels = labels.view(-1, num_options)[:, 0] # labels repeat so we only take the first one
74 | loss = loss_fct(selected_log_probs, labels)
75 | else:
76 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
77 | if not return_dict:
78 | output = (logits,) + outputs[1:]
79 | return (loss,) + output if loss is not None else output
80 |
81 | return CausalLMOutputWithPast(
82 | loss=loss,
83 | logits=logits,
84 | past_key_values=outputs.past_key_values,
85 | hidden_states=outputs.hidden_states,
86 | attentions=outputs.attentions,
87 | )
88 |
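# --- illustration (not part of the original file) ---
# A hedged sketch of how the wrapper above is meant to be attached: it calls
# `self.original_forward`, so the original forward must be saved on the model first.
# The checkpoint name is an assumption, and the actual hook used by the evaluation
# driver may differ.
import types
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")   # assumed checkpoint
model.original_forward = model.forward
model.forward = types.MethodType(forward_wrap_with_option_len, model)  # function defined above
# model(input_ids=..., labels=..., option_len=[...], num_options=[...]) now scores
# only the option tokens of each candidate (as a classification loss when num_options is given).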
89 | @dataclass
90 | class DataCollatorWithPaddingAndNesting:
91 |
92 | tokenizer: PreTrainedTokenizerBase
93 | padding: Union[bool, str, PaddingStrategy] = True
94 | max_length: Optional[int] = None
95 | pad_to_multiple_of: Optional[int] = None
96 | return_tensors: str = "pt"
97 |
98 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
99 |         features = [ff for f in features for ff in f]  # each item is a list of per-candidate features; flatten before padding
100 | batch = self.tokenizer.pad(
101 | features,
102 | padding=self.padding,
103 | max_length=self.max_length,
104 | pad_to_multiple_of=self.pad_to_multiple_of,
105 | return_tensors=self.return_tensors,
106 | )
107 | if "label" in batch:
108 | batch["labels"] = batch["label"]
109 | del batch["label"]
110 | if "label_ids" in batch:
111 | batch["labels"] = batch["label_ids"]
112 | del batch["label_ids"]
113 | return batch
114 |
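# Illustrative-only sketch (not part of the original file): each dataset item fed to
# `DataCollatorWithPaddingAndNesting` is a *list* of per-candidate features, which the
# collator flattens before padding. The toy ids below are made up; `tokenizer` is assumed
# to be a Hugging Face tokenizer with a pad token set.
def _example_nested_collation(tokenizer):
    features = [
        [{"input_ids": [2, 10, 11], "attention_mask": [1, 1, 1], "label": 0},
         {"input_ids": [2, 10, 12], "attention_mask": [1, 1, 1], "label": 0}],
        [{"input_ids": [2, 20], "attention_mask": [1, 1], "label": 1},
         {"input_ids": [2, 21, 22, 23], "attention_mask": [1, 1, 1, 1], "label": 1}],
    ]
    collator = DataCollatorWithPaddingAndNesting(tokenizer=tokenizer)
    return collator(features)   # 4 flattened rows, padded to the longest; "label" renamed to "labels"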
115 |
116 | def encode_prompt(task, template, train_samples, eval_sample, tokenizer, max_length, sfc=False, icl_sfc=False, generation=False, generation_with_gold=False, **kwargs):
117 | """
118 | sfc: calibration (surface form competition)
119 | icl_sfc: calibration (surface form competition) with in-context demonstrations
120 | """
121 | train_prompts = [template.verbalize(sample, sample.correct_candidate).strip() for sample in train_samples]
122 |
123 |
124 | ############### New code to include a label mask ##################
125 | train_label_positions = []
126 | correct_label_lengths = []
127 | prompt = ''
128 |
129 | for (sample, sample_cand) in zip( train_samples, train_prompts ):
130 |
131 | example_prompt = (prompt + sample_cand).strip()
132 | example_sent = (prompt + template.encode(sample)).strip()
133 |
134 | label_len = len(tokenizer.encode(example_prompt)) - len(tokenizer.encode(example_sent))
135 | train_label_positions += [ len(tokenizer.encode(example_sent)) ]
136 | correct_label_lengths += [ label_len ]
137 | prompt = example_prompt + task.train_sep
138 | label_mask = None
139 | ############### New code to include a label mask ##################
140 |
141 | train_prompts = task.train_sep.join(train_prompts).strip()
142 |
143 | if sfc or icl_sfc:
144 | encode_fn = template.encode_sfc; verbalize_fn = template.verbalize_sfc
145 | else:
146 | encode_fn = template.encode; verbalize_fn = template.verbalize
147 |
148 | unverbalized_eval_prompt = encode_fn(eval_sample).strip(' ')
149 | if not generation:
150 | verbalized_eval_prompts = [verbalize_fn(eval_sample, cand).strip(' ') for cand in eval_sample.candidates]
151 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))
152 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts]
153 |
154 | if sfc or train_prompts == '':
155 | # without demonstrations
156 | final_prompts = verbalized_eval_prompts
157 | else:
158 | # with demonstrations
159 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts]
160 |
161 | else:
162 | assert not sfc and not icl_sfc, "Generation tasks do not support SFC"
163 | if generation_with_gold:
164 | verbalized_eval_prompts = [verbalize_fn(eval_sample, eval_sample.correct_candidate)]
165 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))
166 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts]
167 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts]
168 | else:
169 | option_lens = [0]
170 | final_prompts = [(train_prompts + task.train_sep + unverbalized_eval_prompt).lstrip().strip(' ')]
171 |
172 | # tokenize
173 | encodings = [tokenizer.encode(final_prompt) for final_prompt in final_prompts]
174 |
175 | ############### New code to include a label mask ##################
176 | if not generation:
177 | if sfc:
178 | label_mask = [np.zeros( len(encoding) ) for encoding in encodings]
179 | else:
180 | label_mask = []
181 |
182 | for enc_order in range(len(encodings)):
183 | mask = np.zeros( len(encodings[enc_order]) )
184 | for (pos, length) in zip( train_label_positions, correct_label_lengths ):
185 | mask [pos: pos+length] = 1.
186 | label_mask += [mask]
187 | # label_mask = np.stack(label_mask)
188 | ############### New code to include a label mask ##################
189 | # print (label_mask)
190 | # truncate
191 | if any([len(encoding) > max_length for encoding in encodings]):
192 |         logger.warning("Encoding exceeds max length; truncating")
193 | if tokenizer.add_bos_token:
194 | encodings = [encoding[0:1] + encoding[1:][-(max_length-1):] for encoding in encodings]
195 | # label_mask = [lmask[0:1] + lmask[1:][-(max_length-1):] for lmask in label_mask]
196 | else:
197 | encodings = [encoding[-max_length:] for encoding in encodings]
198 | # label_mask = [lmask[-max_length:] for lmask in label_mask]
199 |
200 | return encodings, option_lens, label_mask
201 |
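# Illustrative-only sketch (not part of the original file) of the label-mask bookkeeping in
# `encode_prompt`: `train_label_positions` records the token index where each demonstration's
# verbalized label starts, `correct_label_lengths` how many tokens that label spans. The
# positions and lengths below are made up.
def _example_label_mask_arithmetic():
    encoding_len = 20
    train_label_positions = [5, 14]   # demo 1's label starts at token 5, demo 2's at token 14
    correct_label_lengths = [2, 3]    # demo 1's label spans 2 tokens, demo 2's spans 3
    mask = np.zeros(encoding_len)
    for pos, length in zip(train_label_positions, correct_label_lengths):
        mask[pos: pos + length] = 1.
    return mask                       # 1 exactly on demonstration-label tokens, 0 elsewhere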
202 |
203 | def encode_prompt_with_construction(task, template, train_samples, eval_sample, tokenizer, max_length, sfc=False, icl_sfc=False, generation=False, generation_with_gold=False, **kwargs):
204 |
205 | """
206 | sfc: calibration (surface form competition)
207 | icl_sfc: calibration (surface form competition) with in-context demonstrations
208 | """
209 |
210 | restrict_attention_demonstration=kwargs['restrict_attention_demonstration']
211 | position_modify_demonstration=kwargs['position_modify_demonstration']
212 | mask_demonstration_eval=kwargs['mask_demonstration_eval']
213 | position_modify_eval=kwargs['position_modify_eval']
214 |
215 |
216 | assert restrict_attention_demonstration or not position_modify_demonstration, \
217 | "Can't change position if we are not restricting attention of each demonstration to itself"
218 |
219 |
220 | assert mask_demonstration_eval or not position_modify_eval, \
221 | "Can't change position if we are not masking demonstrations"
222 |
223 |
224 |
225 | train_prompts = [template.verbalize(sample, sample.correct_candidate).strip() for sample in train_samples]
226 |
227 | ############### New code to include a label mask, icl mask, and modify positions ##################
228 | train_example_start = []
229 | train_example_end = []
230 |
231 | train_label_positions = []
232 | correct_label_lengths = []
233 |
234 | prompt = ''
235 | if tokenizer.add_bos_token: shift_right = 1
236 | else: shift_right = 0
237 | for (sample, sample_cand) in zip( train_samples, train_prompts ):
238 |
239 | example_prompt = (prompt + sample_cand).lstrip()
240 | example_sent = (prompt + template.encode(sample)).lstrip()
241 |
242 | example_prompt_enc = tokenizer.encode(example_prompt)
243 | example_sent_enc = tokenizer.encode(example_sent)
244 | example_sample_enc = tokenizer.encode(sample_cand.strip())
245 |
246 | label_len = len(example_prompt_enc) - len(example_sent_enc)
247 |
248 | train_example_start += [len(example_prompt_enc) - len(example_sample_enc) + shift_right]
249 | train_example_end += [len(example_prompt_enc)]
250 |
251 | train_label_positions += [ len(example_sent_enc) ]
252 | correct_label_lengths += [ label_len ]
253 |
254 | prompt = example_prompt + task.train_sep
255 | label_mask = None
256 |     test_example_start = len(tokenizer.encode(prompt + 'end')) - len(tokenizer.encode('end')) + shift_right  # token index where the eval example will start; the 'end' sentinel keeps the trailing separator tokenized as in the full prompt
257 |
258 |
259 | ############### New code to include a label mask ##################
260 |
261 | train_prompts = task.train_sep.join(train_prompts).strip()
262 |
263 | if sfc or icl_sfc:
264 | encode_fn = template.encode_sfc; verbalize_fn = template.verbalize_sfc
265 | else:
266 | encode_fn = template.encode; verbalize_fn = template.verbalize
267 |
268 | unverbalized_eval_prompt = encode_fn(eval_sample).strip(' ')
269 | if not generation:
270 | verbalized_eval_prompts = [verbalize_fn(eval_sample, cand).strip(' ') for cand in eval_sample.candidates]
271 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))
272 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts]
273 |
274 | if sfc or train_prompts == '':
275 | # without demonstrations
276 | final_prompts = verbalized_eval_prompts
277 | else:
278 | # with demonstrations
279 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts]
280 |
281 | else:
282 | assert not sfc and not icl_sfc, "Generation tasks do not support SFC"
283 | if generation_with_gold:
284 | verbalized_eval_prompts = [verbalize_fn(eval_sample, eval_sample.correct_candidate)]
285 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt))
286 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for verbalized_eval_prompt in verbalized_eval_prompts]
287 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in verbalized_eval_prompts]
288 | else:
289 | option_lens = [0]
290 | final_prompts = [(train_prompts + task.train_sep + unverbalized_eval_prompt).lstrip().strip(' ')]
291 |
292 | # tokenize
293 | encodings = [tokenizer.encode(final_prompt) for final_prompt in final_prompts]
294 |
295 | ############### New code to include a label mask ##################
296 | if not generation:
297 | if sfc:
298 | label_mask = [np.zeros( len(encoding) ) for encoding in encodings]
299 | else:
300 | label_mask = []
301 |
302 | for enc_order in range(len(encodings)):
303 | mask = np.zeros( len(encodings[enc_order]) )
304 | for (pos, length) in zip( train_label_positions, correct_label_lengths ):
305 | mask [pos: pos+length] = 1.
306 | label_mask += [mask]
307 |             label_mask = np.stack(label_mask)  # note: np.stack requires all candidate encodings to have the same length
308 | ############### New code to include a label mask ##################
309 |
310 | ############### New code to include a icl mask and position ids##################
311 | icl_mask = [ np.ones( (len(encoding), len(encoding)) ) for encoding in encodings ]
312 | if restrict_attention_demonstration:
313 | for (start, end) in zip(train_example_start, train_example_end):
314 | for i in range(len(encodings)):
315 | icl_mask[i][start:end, start:end] = 1.
316 | icl_mask[i][start:end, :start] = 0.
317 | icl_mask[i][start:end, end:] = 0.
318 |
319 | if mask_demonstration_eval:
320 | for i in range(len(encodings)):
321 | icl_mask[i][test_example_start:, test_example_start:] = 1.
322 | icl_mask[i][test_example_start:, :test_example_start] = 0.
323 |
324 | if tokenizer.add_bos_token:
325 | for i in range(len(encodings)): icl_mask[i][:, 0] = 1.
326 |
327 | position_ids = [ np.arange( len(encoding) ) for encoding in encodings ]
328 | if position_modify_demonstration:
329 | if tokenizer.add_bos_token: start_id = 1
330 | else: start_id = 0
331 |
332 | for (start, end) in zip(train_example_start, train_example_end):
333 | example_length = end - start
334 | for i in range(len(encodings)):
335 | position_ids [i][start:end] = np.arange( start_id, start_id + example_length )
336 |
337 | if position_modify_eval:
338 | if tokenizer.add_bos_token: start_id = 1
339 | else: start_id = 0
340 |
341 | for i in range( len(encodings) ):
342 | start = test_example_start
343 | end = len(encodings[i])
344 |
345 | example_length = end - start
346 | position_ids [i][start:end] = np.arange( start_id, start_id + example_length )
347 |
348 | ############### New code to include a icl mask and position ids ##################
349 |
350 |
351 | # truncate
352 | if any([len(encoding) > max_length for encoding in encodings]):
353 |         logger.warning("Encoding exceeds max length; truncating")
354 | if tokenizer.add_bos_token:
355 |
356 | new_position_ids = []
357 | new_label_mask = []
358 | new_icl_mask = []
359 | for i in range( len(encodings) ):
360 | max_len = min(max_length, len(encodings[i]))
361 |
362 | nicm = np.zeros((max_len, max_len))
363 | nicm [ -(max_len-1):, -(max_len-1): ] = icl_mask[i] [-(max_len-1):, -(max_len-1):]
364 | nicm [ :, 0 ] = 1.
365 |
366 | new_icl_mask += [nicm]
367 | new_position_ids += [ np.concatenate( [np.asarray([0]), position_ids[i][ -(max_len-1): ] ], axis=0 ) ]
368 | new_label_mask += [ np.concatenate( [np.asarray([label_mask[i][0]]), label_mask[i][ -(max_len-1): ] ], axis=0 ) ]
369 | icl_mask = new_icl_mask
370 | position_ids = new_position_ids
371 | label_mask = new_label_mask
372 |
373 | encodings = [encoding[0:1] + encoding[1:][-(max_length-1):] for encoding in encodings]
374 | else:
375 | encodings = [encoding[-max_length:] for encoding in encodings]
376 | label_mask = [lmask[-max_length:] for lmask in label_mask]
377 | position_ids = [ position_id[ -max_length: ] for position_id in position_ids]
378 | icl_mask = [ icm[-max_length:, -max_length:] for icm in icl_mask ]
379 |
380 | return encodings, option_lens, label_mask, icl_mask, position_ids
381 |
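# Illustrative-only sketch (not part of the original file) of the `icl_mask` / `position_ids`
# construction in `encode_prompt_with_construction`: with `restrict_attention_demonstration`,
# each demonstration block attends only to itself (plus a shared BOS token), and with
# `position_modify_demonstration` its positions restart right after BOS. The spans are made up.
def _example_restricted_attention_mask():
    seq_len, bos = 12, True
    demo_spans = [(1, 5), (5, 9)]                    # two demonstrations as [start, end) token spans
    icl_mask = np.ones((seq_len, seq_len))
    for start, end in demo_spans:
        icl_mask[start:end, :start] = 0.             # a demonstration cannot attend to earlier tokens
        icl_mask[start:end, end:] = 0.               # ... nor to later tokens
    if bos:
        icl_mask[:, 0] = 1.                          # everything may still attend to BOS
    position_ids = np.arange(seq_len)
    start_id = 1 if bos else 0
    for start, end in demo_spans:
        position_ids[start:end] = np.arange(start_id, start_id + (end - start))
    return icl_mask, position_ids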
382 |
383 | def load_generation():
384 | out_dir = "/scratch/gpfs/mengzhou/space6/out/the_pile_corrected"
385 | json_files = [f"{out_dir}/ft/ft_opt-125m-lr1e-4/generation_downstream/ft_opt-125m-lr1e-4-hf-sample-0.90-len20-num5-copa.json"]
386 | for model in ["350m", "1.3b", "2.7b"]:
387 | json_files.append(f"{out_dir}/kd_pretrained_ce1_layer0_bs512/kd_pretrained_temp1_tmodelft{model}_lr1e-4/generation_downstream/kd_pretrained_temp1_tmodelft{model}_lr1e-4-hf-sample-0.90-len20-num5-copa.json")
388 |
389 | i = 0
390 | for json_file in json_files[3:4]:
391 | name = os.path.basename(os.path.dirname(os.path.dirname(json_file)))
392 | generations = json.load(open(json_file))
393 | gen = generations[i]
394 | print(name)
395 | print("\033[1mprefix\033[0m:", gen["prefix"])
396 | print("\033[1mcorrect_options\033[0m:", gen["correct_options"])
397 | for i in range(len(gen["incorrect_options"])):
398 | print("\033[1mincorrect_options\033[0m" + f" {i}:", end=" ")
399 | print(gen["incorrect_options"][i])
400 | for i in range(len(gen["generated"])):
401 | print("\033[1mgenerated_option\033[0m" + f" {i}:", end=" ")
402 | print(f"[{round(gen['scores'][i], 2)}]:", end=" ")
403 | print(gen["generated"][i])
404 |
405 | def read_jsonl(file):
406 | ds = []
407 | try:
408 | with open(file) as f:
409 | for i, line in enumerate(f):
410 | d = json.loads(line.strip())
411 | ds.append(d)
412 |     except Exception:  # drop into the debugger to inspect the malformed line
413 | import pdb
414 | pdb.set_trace()
415 | return ds
416 |
417 |
418 | from collections.abc import Mapping
419 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
420 | InputDataClass = NewType("InputDataClass", Any)
421 | from dataclasses import dataclass
422 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase
423 | import torch
424 | @dataclass
425 | class ICLCollator:
426 | tokenizer: PreTrainedTokenizerBase
427 |
428 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
429 | if not isinstance(features[0], Mapping):
430 | features = [vars(f) for f in features]
431 | first = features[0]
432 | batch = {}
433 |
434 | pad_id = self.tokenizer.pad_token_id
435 |
436 | pad_ids = {"input_ids": pad_id, "attention_mask": 0, "sfc_input_ids": pad_id, "sfc_attention_mask": 0, "labels": pad_id}
437 | for key in first:
438 | pp = pad_ids[key]
439 | lens = [len(f[key]) for f in features]
440 | max_len = max(lens)
441 | feature = np.stack([np.pad(f[key], (0, max_len - lens[i]), "constant", constant_values=(0, pp)) for i, f in enumerate(features)])
442 | padded_feature = torch.from_numpy(feature).long()
443 | batch[key] = padded_feature
444 |
445 | return batch
446 |
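# Illustrative-only sketch (not part of the original file): `ICLCollator` right-pads each known
# key to the batch maximum. Only `pad_token_id` is read from the tokenizer, so a tiny stand-in
# object (hypothetical, not a real tokenizer) suffices for the illustration.
def _example_icl_collation():
    from types import SimpleNamespace
    collator = ICLCollator(tokenizer=SimpleNamespace(pad_token_id=1))
    features = [
        {"input_ids": [2, 10, 11], "attention_mask": [1, 1, 1]},
        {"input_ids": [2, 20], "attention_mask": [1, 1]},
    ]
    return collator(features)   # input_ids padded with 1, attention_mask padded with 0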
447 |
448 | import json
449 | class EnhancedJSONEncoder(json.JSONEncoder):
450 | def default(self, o):
451 | if is_dataclass(o):
452 | return asdict(o)
453 | return super().default(o)
454 |
455 | def write_predictions_to_file(final_preds, output):
456 | with open(output, "w") as f:
457 | for pred in final_preds:
458 | f.write(json.dumps(pred, cls=EnhancedJSONEncoder) + "\n")
459 |
460 | def write_metrics_to_file(metrics, output):
461 | with open(output, "w") as f:
462 |         if isinstance(metrics, list):
463 | for metric in metrics:
464 | json.dump(metric, f, cls=EnhancedJSONEncoder)
465 | f.write("\n")
466 | else:
467 | json.dump(metrics, f, cls=EnhancedJSONEncoder, indent=4)
468 |
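# Illustrative-only sketch (not part of the original file): `EnhancedJSONEncoder` lets the
# `Prediction` dataclass defined above be serialized directly, which is what
# `write_predictions_to_file` relies on. The output path is arbitrary.
def _example_write_predictions(output_path="/tmp/example_predictions.jsonl"):
    preds = [
        Prediction(correct_candidate=0, predicted_candidate=1, scores=[-1.2, -0.7]),
        Prediction(correct_candidate="yes", predicted_candidate="yes", scores=[-0.3, -2.1]),
    ]
    write_predictions_to_file(preds, output_path)   # one JSON object per line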
469 | if __name__ == "__main__":
470 | load_generation()
471 |
--------------------------------------------------------------------------------