├── .gitignore ├── LICENSE ├── README.md ├── environment.yml ├── setting.jpg └── src ├── base_models.py ├── conf ├── base.yaml ├── decision_tree.yaml ├── linear_regression.yaml ├── models │ ├── small.yaml │ ├── standard.yaml │ └── tiny.yaml ├── relu_2nn_regression.yaml ├── sparse_linear_regression.yaml ├── toy.yaml └── wandb.yaml ├── curriculum.py ├── eval.ipynb ├── eval.py ├── models.py ├── plot_utils.py ├── samplers.py ├── schema.py ├── tasks.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | __pycache__ 3 | models.zip 4 | models 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Dimitris Tsipras, Shivam Garg 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains the code and models for our paper: 2 | 3 | **What Can Transformers Learn In-Context? A Case Study of Simple Function Classes**
4 | *Shivam Garg\*, Dimitris Tsipras\*, Percy Liang, Gregory Valiant*
5 | Paper: http://arxiv.org/abs/2208.01066
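
For a quick end-to-end look at evaluating a pre-trained checkpoint outside the notebook, here is a minimal sketch. It assumes you have followed the *Getting started* steps below, are running from `src/` on a machine with a GPU, and that `<run_id>` is replaced by one of the extracted run directories:

```
import torch

from eval import get_model_from_run
from samplers import get_data_sampler
from tasks import get_task_sampler

run_path = "../models/linear_regression/<run_id>"  # placeholder: pick a run directory from the download
model, conf = get_model_from_run(run_path)
model = model.cuda().eval()

n_dims = conf.model.n_dims
data_sampler = get_data_sampler(conf.training.data, n_dims)
task = get_task_sampler(conf.training.task, n_dims, batch_size=64)()

xs = data_sampler.sample_xs(n_points=conf.model.n_positions, b_size=64)
ys = task.evaluate(xs)
with torch.no_grad():
    # the prediction at index i conditions on examples 0..i-1 and the query x_i
    preds = model(xs.cuda(), ys.cuda()).cpu()
print(task.get_metric()(preds, ys).mean(dim=0))  # average squared error as a function of prompt length
```

The `eval.ipynb` notebook (described below) covers this workflow in more detail.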

6 | 7 | ![](setting.jpg) 8 | 9 | ```bibtex 10 | @InProceedings{garg2022what, 11 | title={What Can Transformers Learn In-Context? A Case Study of Simple Function Classes}, 12 | author={Shivam Garg and Dimitris Tsipras and Percy Liang and Gregory Valiant}, 13 | year={2022}, 14 | booktitle={arXiv preprint} 15 | } 16 | ``` 17 | 18 | ## Getting started 19 | You can start by cloning our repository and following the steps below. 20 | 21 | 1. Install the dependencies for our code using Conda. You may need to adjust the environment YAML file depending on your setup. 22 | 23 | ``` 24 | conda env create -f environment.yml 25 | conda activate in-context-learning 26 | ``` 27 | 28 | 2. Download [model checkpoints](https://github.com/dtsip/in-context-learning/releases/download/initial/models.zip) and extract them in the current directory. 29 | 30 | ``` 31 | wget https://github.com/dtsip/in-context-learning/releases/download/initial/models.zip 32 | unzip models.zip 33 | ``` 34 | 35 | 3. [Optional] If you plan to train, populate `conf/wandb.yaml` with you wandb info. 36 | 37 | That's it! You can now explore our pre-trained models or train your own. The key entry points 38 | are as follows (starting from `src`): 39 | - The `eval.ipynb` notebook contains code to load our own pre-trained models, plot the pre-computed metrics, and evaluate them on new data. 40 | - `train.py` takes as argument a configuration yaml from `conf` and trains the corresponding model. You can try `python train.py --config conf/toy.yaml` for a quick training run. 41 | 42 | # Maintainers 43 | * [Shivam Garg](https://cs.stanford.edu/~shivamg/) 44 | * [Dimitris Tsipras](https://dtsipras.com/) 45 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: in-context-learning 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - pip=21.2.4 7 | - python=3.8.12 8 | - pytorch=1.11.0 9 | - pip: 10 | - jupyter==1.0.0 11 | - matplotlib==3.5.2 12 | - numpy==1.22.3 13 | - pandas==1.4.2 14 | - quinine==0.3.0 15 | - scikit-learn==1.0.2 16 | - seaborn==0.11.2 17 | - tqdm==4.64.0 18 | - transformers==4.17.0 19 | - wandb==0.12.11 20 | - xgboost==1.6.1 21 | - protobuf==3.20.1 22 | -------------------------------------------------------------------------------- /setting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dtsip/in-context-learning/338337192195827c58356f9b627e5f999ea703d2/setting.jpg -------------------------------------------------------------------------------- /src/base_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class NeuralNetwork(nn.Module): 6 | def __init__(self, in_size=50, hidden_size=1000, out_size=1): 7 | super(NeuralNetwork, self).__init__() 8 | 9 | self.net = nn.Sequential( 10 | nn.Linear(in_size, hidden_size), 11 | nn.ReLU(), 12 | nn.Linear(hidden_size, out_size), 13 | ) 14 | 15 | def forward(self, x): 16 | out = self.net(x) 17 | return out 18 | 19 | 20 | class ParallelNetworks(nn.Module): 21 | def __init__(self, num_models, model_class, **model_class_init_args): 22 | super(ParallelNetworks, self).__init__() 23 | self.nets = nn.ModuleList( 24 | [model_class(**model_class_init_args) for i in range(num_models)] 25 | ) 26 | 27 | def forward(self, xs): 28 | assert xs.shape[0] == len(self.nets) 29 | 30 | for i in 
range(len(self.nets)): 31 | out = self.nets[i](xs[i]) 32 | if i == 0: 33 | outs = torch.zeros( 34 | [len(self.nets)] + list(out.shape), device=out.device 35 | ) 36 | outs[i] = out 37 | return outs 38 | -------------------------------------------------------------------------------- /src/conf/base.yaml: -------------------------------------------------------------------------------- 1 | inherit: 2 | - models/standard.yaml 3 | - wandb.yaml 4 | 5 | model: 6 | n_dims: 20 7 | n_positions: 101 8 | 9 | training: 10 | data: gaussian 11 | task_kwargs: {} 12 | batch_size: 64 13 | learning_rate: 0.0001 14 | save_every_steps: 1000 15 | keep_every_steps: 100000 16 | train_steps: 500001 17 | curriculum: 18 | dims: 19 | start: 5 20 | end: 20 21 | inc: 1 22 | interval: 2000 23 | -------------------------------------------------------------------------------- /src/conf/decision_tree.yaml: -------------------------------------------------------------------------------- 1 | inherit: 2 | - base.yaml 3 | 4 | training: 5 | task: decision_tree 6 | task_kwargs: {"depth": 4} 7 | curriculum: 8 | points: 9 | start: 26 10 | end: 101 11 | inc: 5 12 | interval: 2000 13 | 14 | out_dir: ../models/decision_tree 15 | 16 | wandb: 17 | name: "decision_tree_standard" 18 | -------------------------------------------------------------------------------- /src/conf/linear_regression.yaml: -------------------------------------------------------------------------------- 1 | inherit: 2 | - base.yaml 3 | 4 | training: 5 | task: linear_regression 6 | curriculum: 7 | points: 8 | start: 11 9 | end: 41 10 | inc: 2 11 | interval: 2000 12 | 13 | out_dir: ../models/linear_regression 14 | 15 | wandb: 16 | name: "linear_regression_standard" 17 | -------------------------------------------------------------------------------- /src/conf/models/small.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | family: gpt2 3 | n_embd: 128 4 | n_layer: 6 5 | n_head: 4 6 | -------------------------------------------------------------------------------- /src/conf/models/standard.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | family: gpt2 3 | n_embd: 256 4 | n_layer: 12 5 | n_head: 8 6 | -------------------------------------------------------------------------------- /src/conf/models/tiny.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | family: gpt2 3 | n_embd: 64 4 | n_layer: 3 5 | n_head: 2 6 | -------------------------------------------------------------------------------- /src/conf/relu_2nn_regression.yaml: -------------------------------------------------------------------------------- 1 | inherit: 2 | - base.yaml 3 | 4 | training: 5 | task: relu_2nn_regression 6 | task_kwargs: {"hidden_layer_size": 100} 7 | curriculum: 8 | points: 9 | start: 26 10 | end: 101 11 | inc: 5 12 | interval: 2000 13 | 14 | out_dir: ../models/relu_2nn_regression 15 | 16 | wandb: 17 | name: "relu_2nn_regression_standard" 18 | -------------------------------------------------------------------------------- /src/conf/sparse_linear_regression.yaml: -------------------------------------------------------------------------------- 1 | inherit: 2 | - base.yaml 3 | 4 | training: 5 | task: sparse_linear_regression 6 | task_kwargs: {"sparsity": 3} 7 | curriculum: 8 | points: 9 | start: 11 10 | end: 41 11 | inc: 2 12 | interval: 2000 13 | 14 | out_dir: ../models/sparse_linear_regression 15 | 16 | wandb: 17 | name: 
"sparse_regression_standard" 18 | -------------------------------------------------------------------------------- /src/conf/toy.yaml: -------------------------------------------------------------------------------- 1 | inherit: 2 | - models/standard.yaml 3 | - wandb.yaml 4 | 5 | model: 6 | n_dims: 5 7 | n_positions: 11 8 | 9 | training: 10 | task: linear_regression 11 | data: gaussian 12 | task_kwargs: {} 13 | batch_size: 64 14 | learning_rate: 0.0001 15 | save_every_steps: 1000 16 | keep_every_steps: 100000 17 | train_steps: 5001 18 | curriculum: 19 | dims: 20 | start: 5 21 | end: 5 22 | inc: 1 23 | interval: 2000 24 | points: 25 | start: 11 26 | end: 11 27 | inc: 2 28 | interval: 2000 29 | 30 | out_dir: ../models/linear_regression 31 | 32 | wandb: 33 | name: "linear_regression_toy" 34 | -------------------------------------------------------------------------------- /src/conf/wandb.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | project: in-context-training 3 | entity: your-entity 4 | notes: 5 | log_every_steps: 100 6 | -------------------------------------------------------------------------------- /src/curriculum.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class Curriculum: 5 | def __init__(self, args): 6 | # args.dims and args.points each contain start, end, inc, interval attributes 7 | # inc denotes the change in n_dims, 8 | # this change is done every interval, 9 | # and start/end are the limits of the parameter 10 | self.n_dims_truncated = args.dims.start 11 | self.n_points = args.points.start 12 | self.n_dims_schedule = args.dims 13 | self.n_points_schedule = args.points 14 | self.step_count = 0 15 | 16 | def update(self): 17 | self.step_count += 1 18 | self.n_dims_truncated = self.update_var( 19 | self.n_dims_truncated, self.n_dims_schedule 20 | ) 21 | self.n_points = self.update_var(self.n_points, self.n_points_schedule) 22 | 23 | def update_var(self, var, schedule): 24 | if self.step_count % schedule.interval == 0: 25 | var += schedule.inc 26 | 27 | return min(var, schedule.end) 28 | 29 | 30 | # returns the final value of var after applying curriculum. 
31 | def get_final_var(init_var, total_steps, inc, n_steps, lim): 32 | final_var = init_var + math.floor((total_steps) / n_steps) * inc 33 | 34 | return min(final_var, lim) 35 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | 5 | from munch import Munch 6 | import numpy as np 7 | import pandas as pd 8 | from tqdm import tqdm 9 | import torch 10 | import yaml 11 | 12 | import models 13 | from samplers import get_data_sampler, sample_transformation 14 | from tasks import get_task_sampler 15 | 16 | 17 | def get_model_from_run(run_path, step=-1, only_conf=False): 18 | config_path = os.path.join(run_path, "config.yaml") 19 | with open(config_path) as fp: # we don't Quinfig it to avoid inherits 20 | conf = Munch.fromDict(yaml.safe_load(fp)) 21 | if only_conf: 22 | return None, conf 23 | 24 | model = models.build_model(conf.model) 25 | 26 | if step == -1: 27 | state_path = os.path.join(run_path, "state.pt") 28 | state = torch.load(state_path) 29 | model.load_state_dict(state["model_state_dict"]) 30 | else: 31 | model_path = os.path.join(run_path, f"model_{step}.pt") 32 | state_dict = torch.load(model_path) 33 | model.load_state_dict(state_dict) 34 | 35 | return model, conf 36 | 37 | 38 | # Functions for evaluation 39 | 40 | 41 | def eval_batch(model, task_sampler, xs, xs_p=None): 42 | task = task_sampler() 43 | if torch.cuda.is_available() and model.name.split("_")[0] in ["gpt2", "lstm"]: 44 | device = "cuda" 45 | else: 46 | device = "cpu" 47 | 48 | if xs_p is None: 49 | ys = task.evaluate(xs) 50 | pred = model(xs.to(device), ys.to(device)).detach() 51 | metrics = task.get_metric()(pred.cpu(), ys) 52 | else: 53 | b_size, n_points, _ = xs.shape 54 | metrics = torch.zeros(b_size, n_points) 55 | for i in range(n_points): 56 | xs_comb = torch.cat((xs[:, :i, :], xs_p[:, i:, :]), dim=1) 57 | ys = task.evaluate(xs_comb) 58 | 59 | pred = model(xs_comb.to(device), ys.to(device), inds=[i]).detach() 60 | metrics[:, i] = task.get_metric()(pred.cpu(), ys)[:, i] 61 | 62 | return metrics 63 | 64 | 65 | # Functions for generating different kinds of train/test data 66 | 67 | 68 | def gen_standard(data_sampler, n_points, b_size): 69 | xs = data_sampler.sample_xs(n_points, b_size) 70 | 71 | return xs, None 72 | 73 | 74 | def gen_opposite_quadrants(data_sampler, n_points, b_size): 75 | xs = data_sampler.sample_xs(n_points, b_size) 76 | pattern = torch.randn([b_size, 1, xs.shape[2]]).sign() 77 | 78 | xs_train_pre = xs.abs() * pattern 79 | xs_test_post = -xs_train_pre 80 | 81 | return xs_train_pre, xs_test_post 82 | 83 | 84 | def gen_random_quadrants(data_sampler, n_points, b_size): 85 | xs = data_sampler.sample_xs(n_points, b_size) 86 | pattern = torch.randn([b_size, 1, xs.shape[2]]).sign() 87 | 88 | xs_train_pre = xs.abs() * pattern 89 | xs_test_post = xs 90 | 91 | return xs_train_pre, xs_test_post 92 | 93 | 94 | def gen_orthogonal_train_test(data_sampler, n_points, b_size): 95 | xs = data_sampler.sample_xs(n_points, b_size) 96 | n_dim = xs.shape[2] 97 | n_points = min(n_points, n_dim) 98 | # raise ValueError("number of points should be at most the dimension.") 99 | xs_train_pre = xs 100 | xs_test_post = torch.zeros(xs.shape) 101 | for i in range(n_points): 102 | xs_test_post_i = xs[:, i : i + 1, :] 103 | xs_train_pre_i = xs[:, :i, :] 104 | _, _, Vt = torch.linalg.svd(xs_train_pre_i, full_matrices=False) 105 | xs_train_pre_i_projection = 
Vt.transpose(1, 2) @ Vt 106 | xs_test_post_i_orthogonalized = ( 107 | xs_test_post_i - xs_test_post_i @ xs_train_pre_i_projection 108 | ) 109 | xs_test_post_i_normalized = ( 110 | xs_test_post_i_orthogonalized 111 | * xs_test_post_i.norm(dim=2).unsqueeze(2) 112 | / xs_test_post_i_orthogonalized.norm(dim=2).unsqueeze(2) 113 | ) 114 | 115 | xs_test_post[:, i : i + 1, :] = xs_test_post_i_normalized 116 | 117 | return xs_train_pre, xs_test_post 118 | 119 | 120 | def gen_overlapping_train_test(data_sampler, n_points, b_size): 121 | xs = data_sampler.sample_xs(n_points, b_size) 122 | xs_train_pre = xs 123 | xs_test_post = xs.clone() 124 | b_size = xs.shape[0] 125 | for i in range(1, n_points): 126 | xs_train_pre_i = xs[:, :i, :] 127 | perm = torch.stack([torch.randperm(i) for _ in range(b_size)]).unsqueeze(dim=1) 128 | ind_mat = (perm == 0) + 0.0 129 | xs_test_post[:, i : i + 1, :] = ind_mat @ xs_train_pre_i 130 | 131 | return xs_train_pre, xs_test_post 132 | 133 | 134 | def aggregate_metrics(metrics, bootstrap_trials=1000): 135 | """ 136 | Takes as input a tensor of shape (num_eval, n_points) and returns a dict with 137 | per-point mean, stddev, and bootstrap limits 138 | """ 139 | results = {} 140 | results["mean"] = metrics.mean(dim=0) 141 | results["std"] = metrics.std(dim=0, unbiased=True) 142 | n = len(metrics) 143 | bootstrap_indices = torch.randint(n, size=(bootstrap_trials, n)) 144 | bootstrap_means = metrics[bootstrap_indices].mean(dim=1).sort(dim=0)[0] 145 | results["bootstrap_low"] = bootstrap_means[int(0.05 * bootstrap_trials), :] 146 | results["bootstrap_high"] = bootstrap_means[int(0.95 * bootstrap_trials), :] 147 | 148 | return {k: v.tolist() for k, v in results.items()} 149 | 150 | 151 | def eval_model( 152 | model, 153 | task_name, 154 | data_name, 155 | n_dims, 156 | n_points, 157 | prompting_strategy, 158 | num_eval_examples=1280, 159 | batch_size=64, 160 | data_sampler_kwargs={}, 161 | task_sampler_kwargs={}, 162 | ): 163 | """ 164 | Evaluate a model on a task with a variety of strategies. 165 | Args: 166 | - task: which base task we are evaluating on. 
E.g., "linear_regression" 167 | - prompting_strategy: how to construct the prompt, e.g., "random_quadrants" 168 | - num_eval_examples: total number of examples to evaluate on 169 | - **sampler_kwargs: remaining arguments to pass directly to the sampler 170 | """ 171 | 172 | assert num_eval_examples % batch_size == 0 173 | data_sampler = get_data_sampler(data_name, n_dims, **data_sampler_kwargs) 174 | task_sampler = get_task_sampler( 175 | task_name, n_dims, batch_size, **task_sampler_kwargs 176 | ) 177 | 178 | all_metrics = [] 179 | 180 | generating_func = globals()[f"gen_{prompting_strategy}"] 181 | for i in range(num_eval_examples // batch_size): 182 | xs, xs_p = generating_func(data_sampler, n_points, batch_size) 183 | 184 | metrics = eval_batch(model, task_sampler, xs, xs_p) 185 | all_metrics.append(metrics) 186 | 187 | metrics = torch.cat(all_metrics, dim=0) 188 | 189 | return aggregate_metrics(metrics) 190 | 191 | 192 | def build_evals(conf): 193 | n_dims = conf.model.n_dims 194 | n_points = conf.training.curriculum.points.end 195 | batch_size = conf.training.batch_size 196 | 197 | task_name = conf.training.task 198 | data_name = conf.training.data 199 | 200 | base_kwargs = { 201 | "task_name": task_name, 202 | "n_dims": n_dims, 203 | "n_points": n_points, 204 | "batch_size": batch_size, 205 | "data_name": data_name, 206 | "prompting_strategy": "standard", 207 | } 208 | 209 | evaluation_kwargs = {} 210 | 211 | evaluation_kwargs["standard"] = {"prompting_strategy": "standard"} 212 | if task_name != "linear_regression": 213 | if task_name in ["relu_2nn_regression"]: 214 | evaluation_kwargs["linear_regression"] = {"task_name": "linear_regression"} 215 | for name, kwargs in evaluation_kwargs.items(): 216 | # allow kwargs to override base_kwargs values 217 | evaluation_kwargs[name] = base_kwargs.copy() 218 | evaluation_kwargs[name].update(kwargs) 219 | return evaluation_kwargs 220 | 221 | for strategy in [ 222 | "random_quadrants", 223 | "orthogonal_train_test", 224 | "overlapping_train_test", 225 | ]: 226 | evaluation_kwargs[strategy] = {"prompting_strategy": strategy} 227 | 228 | for method in ["half_subspace", "skewed"]: 229 | if "subspace" in method: 230 | eigenvals = torch.zeros(n_dims) 231 | eigenvals[: n_dims // 2] = 1 232 | else: 233 | eigenvals = 1 / (torch.arange(n_dims) + 1) 234 | 235 | scale = sample_transformation(eigenvals, normalize=True) 236 | evaluation_kwargs[f"{method}"] = { 237 | "data_sampler_kwargs": {"scale": scale}, 238 | } 239 | 240 | for dim in ["x", "y"]: 241 | for scale in [0.333, 0.5, 2, 3]: 242 | if dim == "x": 243 | eigenvals = scale * torch.ones(n_dims) 244 | t = sample_transformation(eigenvals) 245 | scaling_args = {"data_sampler_kwargs": {"scale": t}} 246 | else: 247 | eigenvals = scale * torch.ones(n_dims) 248 | scaling_args = {"task_sampler_kwargs": {"scale": scale}} 249 | 250 | evaluation_kwargs[f"scale-{dim}={scale}"] = scaling_args 251 | 252 | evaluation_kwargs[f"noisyLR"] = { 253 | "task_sampler_kwargs": {"renormalize_ys": True, "noise_std": 1}, 254 | "task_name": "noisy_linear_regression", 255 | } 256 | 257 | for name, kwargs in evaluation_kwargs.items(): 258 | # allow kwargs to override base_kwargs values 259 | evaluation_kwargs[name] = base_kwargs.copy() 260 | evaluation_kwargs[name].update(kwargs) 261 | 262 | return evaluation_kwargs 263 | 264 | 265 | def compute_evals(all_models, evaluation_kwargs, save_path=None, recompute=False): 266 | try: 267 | with open(save_path) as fp: 268 | all_metrics = json.load(fp) 269 | except Exception: 270 | 
all_metrics = {} 271 | 272 | for eval_name, kwargs in tqdm(evaluation_kwargs.items()): 273 | metrics = {} 274 | if eval_name in all_metrics and not recompute: 275 | metrics = all_metrics[eval_name] 276 | for model in all_models: 277 | if model.name in metrics and not recompute: 278 | continue 279 | 280 | metrics[model.name] = eval_model(model, **kwargs) 281 | all_metrics[eval_name] = metrics 282 | 283 | if save_path is not None: 284 | with open(save_path, "w") as fp: 285 | json.dump(all_metrics, fp, indent=2) 286 | 287 | return all_metrics 288 | 289 | 290 | def get_run_metrics( 291 | run_path, step=-1, cache=True, skip_model_load=False, skip_baselines=False 292 | ): 293 | if skip_model_load: 294 | _, conf = get_model_from_run(run_path, only_conf=True) 295 | all_models = [] 296 | else: 297 | model, conf = get_model_from_run(run_path, step) 298 | model = model.cuda().eval() 299 | all_models = [model] 300 | if not skip_baselines: 301 | all_models += models.get_relevant_baselines(conf.training.task) 302 | evaluation_kwargs = build_evals(conf) 303 | 304 | if not cache: 305 | save_path = None 306 | elif step == -1: 307 | save_path = os.path.join(run_path, "metrics.json") 308 | else: 309 | save_path = os.path.join(run_path, f"metrics_{step}.json") 310 | 311 | recompute = False 312 | if save_path is not None and os.path.exists(save_path): 313 | checkpoint_created = os.path.getmtime(run_path) 314 | cache_created = os.path.getmtime(save_path) 315 | if checkpoint_created > cache_created: 316 | recompute = True 317 | 318 | all_metrics = compute_evals(all_models, evaluation_kwargs, save_path, recompute) 319 | return all_metrics 320 | 321 | 322 | 323 | def conf_to_model_name(conf): 324 | if conf.model.family == "gpt2": 325 | return { 326 | (3, 2): "Transformer-xs", 327 | (6, 4): "Transformer-small", 328 | (12, 8): "Transformer", 329 | }[(conf.model.n_layer, conf.model.n_head)] 330 | else: 331 | return conf.wandb.name 332 | 333 | 334 | def baseline_names(name): 335 | if "OLS" in name: 336 | return "Least Squares" 337 | if name == "averaging": 338 | return "Averaging" 339 | if "NN" in name: 340 | k = name.split("_")[1].split("=")[1] 341 | return f"{k}-Nearest Neighbors" 342 | if "lasso" in name: 343 | alpha = name.split("_")[1].split("=")[1] 344 | return f"Lasso (alpha={alpha})" 345 | if "gd" in name: 346 | return "2-layer NN, GD" 347 | if "decision_tree" in name: 348 | return "Greedy Tree Learning" 349 | if "xgboost" in name: 350 | return "XGBoost" 351 | return name 352 | 353 | 354 | def read_run_dir(run_dir): 355 | all_runs = {} 356 | for task in os.listdir(run_dir): 357 | task_dir = os.path.join(run_dir, task) 358 | for run_id in os.listdir(task_dir): 359 | run_path = os.path.join(task_dir, run_id) 360 | _, conf = get_model_from_run(run_path, only_conf=True) 361 | params = {} 362 | params["run_id"] = run_id 363 | params["task"] = task 364 | params["model"] = conf_to_model_name(conf) 365 | params["kwargs"] = "_".join( 366 | f"{k}={v}" for k, v in conf.training.task_kwargs.items() 367 | ) 368 | num_tasks = ( 369 | conf.training.num_tasks if "num_tasks" in conf.training else None 370 | ) 371 | params["num_tasks"] = num_tasks if num_tasks is not None else -1 372 | num_examples = ( 373 | conf.training.num_training_examples 374 | if "num_training_examples" in conf.training 375 | else None 376 | ) 377 | params["num_examples"] = num_examples if num_examples is not None else -1 378 | params["n_dims"] = conf.model.n_dims 379 | params["n_layer"] = conf.model.n_layer 380 | params["n_head"] = conf.model.n_head 381 
| params["run_name"] = conf.wandb.name 382 | 383 | for k, v in params.items(): 384 | if k not in all_runs: 385 | all_runs[k] = [] 386 | all_runs[k].append(v) 387 | 388 | df = pd.DataFrame(all_runs).sort_values("run_name") 389 | assert len(df) == len(df.run_name.unique()) 390 | return df 391 | 392 | if __name__ == "__main__": 393 | run_dir = sys.argv[1] 394 | for task in os.listdir(run_dir): 395 | task_dir = os.path.join(run_dir, task) 396 | print(f"Evaluating task {task}") 397 | for run_id in tqdm(os.listdir(task_dir)): 398 | run_path = os.path.join(run_dir, task, run_id) 399 | metrics = get_run_metrics(run_path) 400 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import GPT2Model, GPT2Config 4 | from tqdm import tqdm 5 | from sklearn.svm import LinearSVC 6 | from sklearn.linear_model import LogisticRegression, Lasso 7 | import warnings 8 | from sklearn import tree 9 | import xgboost as xgb 10 | 11 | from base_models import NeuralNetwork, ParallelNetworks 12 | 13 | 14 | def build_model(conf): 15 | if conf.family == "gpt2": 16 | model = TransformerModel( 17 | n_dims=conf.n_dims, 18 | n_positions=conf.n_positions, 19 | n_embd=conf.n_embd, 20 | n_layer=conf.n_layer, 21 | n_head=conf.n_head, 22 | ) 23 | else: 24 | raise NotImplementedError 25 | 26 | return model 27 | 28 | 29 | def get_relevant_baselines(task_name): 30 | task_to_baselines = { 31 | "linear_regression": [ 32 | (LeastSquaresModel, {}), 33 | (NNModel, {"n_neighbors": 3}), 34 | (AveragingModel, {}), 35 | ], 36 | "linear_classification": [ 37 | (NNModel, {"n_neighbors": 3}), 38 | (AveragingModel, {}), 39 | ], 40 | "sparse_linear_regression": [ 41 | (LeastSquaresModel, {}), 42 | (NNModel, {"n_neighbors": 3}), 43 | (AveragingModel, {}), 44 | ] 45 | + [(LassoModel, {"alpha": alpha}) for alpha in [1, 0.1, 0.01, 0.001, 0.0001]], 46 | "relu_2nn_regression": [ 47 | (LeastSquaresModel, {}), 48 | (NNModel, {"n_neighbors": 3}), 49 | (AveragingModel, {}), 50 | ( 51 | GDModel, 52 | { 53 | "model_class": NeuralNetwork, 54 | "model_class_args": { 55 | "in_size": 20, 56 | "hidden_size": 100, 57 | "out_size": 1, 58 | }, 59 | "opt_alg": "adam", 60 | "batch_size": 100, 61 | "lr": 5e-3, 62 | "num_steps": 100, 63 | }, 64 | ), 65 | ], 66 | "decision_tree": [ 67 | (LeastSquaresModel, {}), 68 | (NNModel, {"n_neighbors": 3}), 69 | (DecisionTreeModel, {"max_depth": 4}), 70 | (DecisionTreeModel, {"max_depth": None}), 71 | (XGBoostModel, {}), 72 | (AveragingModel, {}), 73 | ], 74 | } 75 | 76 | models = [model_cls(**kwargs) for model_cls, kwargs in task_to_baselines[task_name]] 77 | return models 78 | 79 | 80 | class TransformerModel(nn.Module): 81 | def __init__(self, n_dims, n_positions, n_embd=128, n_layer=12, n_head=4): 82 | super(TransformerModel, self).__init__() 83 | configuration = GPT2Config( 84 | n_positions=2 * n_positions, 85 | n_embd=n_embd, 86 | n_layer=n_layer, 87 | n_head=n_head, 88 | resid_pdrop=0.0, 89 | embd_pdrop=0.0, 90 | attn_pdrop=0.0, 91 | use_cache=False, 92 | ) 93 | self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}" 94 | 95 | self.n_positions = n_positions 96 | self.n_dims = n_dims 97 | self._read_in = nn.Linear(n_dims, n_embd) 98 | self._backbone = GPT2Model(configuration) 99 | self._read_out = nn.Linear(n_embd, 1) 100 | 101 | @staticmethod 102 | def _combine(xs_b, ys_b): 103 | """Interleaves the x's and the y's into a single 
sequence.""" 104 | bsize, points, dim = xs_b.shape 105 | ys_b_wide = torch.cat( 106 | ( 107 | ys_b.view(bsize, points, 1), 108 | torch.zeros(bsize, points, dim - 1, device=ys_b.device), 109 | ), 110 | axis=2, 111 | ) 112 | zs = torch.stack((xs_b, ys_b_wide), dim=2) 113 | zs = zs.view(bsize, 2 * points, dim) 114 | return zs 115 | 116 | def forward(self, xs, ys, inds=None): 117 | if inds is None: 118 | inds = torch.arange(ys.shape[1]) 119 | else: 120 | inds = torch.tensor(inds) 121 | if max(inds) >= ys.shape[1] or min(inds) < 0: 122 | raise ValueError("inds contain indices where xs and ys are not defined") 123 | zs = self._combine(xs, ys) 124 | embeds = self._read_in(zs) 125 | output = self._backbone(inputs_embeds=embeds).last_hidden_state 126 | prediction = self._read_out(output) 127 | return prediction[:, ::2, 0][:, inds] # predict only on xs 128 | 129 | 130 | class NNModel: 131 | def __init__(self, n_neighbors, weights="uniform"): 132 | # should we be picking k optimally 133 | self.n_neighbors = n_neighbors 134 | self.weights = weights 135 | self.name = f"NN_n={n_neighbors}_{weights}" 136 | 137 | def __call__(self, xs, ys, inds=None): 138 | if inds is None: 139 | inds = range(ys.shape[1]) 140 | else: 141 | if max(inds) >= ys.shape[1] or min(inds) < 0: 142 | raise ValueError("inds contain indices where xs and ys are not defined") 143 | 144 | preds = [] 145 | 146 | for i in inds: 147 | if i == 0: 148 | preds.append(torch.zeros_like(ys[:, 0])) # predict zero for first point 149 | continue 150 | train_xs, train_ys = xs[:, :i], ys[:, :i] 151 | test_x = xs[:, i : i + 1] 152 | dist = (train_xs - test_x).square().sum(dim=2).sqrt() 153 | 154 | if self.weights == "uniform": 155 | weights = torch.ones_like(dist) 156 | else: 157 | weights = 1.0 / dist 158 | inf_mask = torch.isinf(weights).float() # deal with exact match 159 | inf_row = torch.any(inf_mask, axis=1) 160 | weights[inf_row] = inf_mask[inf_row] 161 | 162 | pred = [] 163 | k = min(i, self.n_neighbors) 164 | ranks = dist.argsort()[:, :k] 165 | for y, w, n in zip(train_ys, weights, ranks): 166 | y, w = y[n], w[n] 167 | pred.append((w * y).sum() / w.sum()) 168 | preds.append(torch.stack(pred)) 169 | 170 | return torch.stack(preds, dim=1) 171 | 172 | 173 | # xs and ys should be on cpu for this method. Otherwise the output maybe off in case when train_xs is not full rank due to the implementation of torch.linalg.lstsq. 
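# All baseline models below follow the transformer's call convention: given xs of shape
# (bsize, n_points, n_dims) and ys of shape (bsize, n_points), calling the model returns a
# (bsize, n_points) tensor whose column i is the prediction for xs[:, i] obtained by fitting the
# first i in-context examples (the prediction at i == 0 is zero). An optional `inds` argument
# restricts which indices are predicted. A hypothetical usage sketch:
#     ols = LeastSquaresModel()
#     preds = ols(xs, ys)              # xs, ys on CPU, per the note above
#     pointwise_error = (preds - ys).square()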
174 | class LeastSquaresModel: 175 | def __init__(self, driver=None): 176 | self.driver = driver 177 | self.name = f"OLS_driver={driver}" 178 | 179 | def __call__(self, xs, ys, inds=None): 180 | xs, ys = xs.cpu(), ys.cpu() 181 | if inds is None: 182 | inds = range(ys.shape[1]) 183 | else: 184 | if max(inds) >= ys.shape[1] or min(inds) < 0: 185 | raise ValueError("inds contain indices where xs and ys are not defined") 186 | 187 | preds = [] 188 | 189 | for i in inds: 190 | if i == 0: 191 | preds.append(torch.zeros_like(ys[:, 0])) # predict zero for first point 192 | continue 193 | train_xs, train_ys = xs[:, :i], ys[:, :i] 194 | test_x = xs[:, i : i + 1] 195 | 196 | ws, _, _, _ = torch.linalg.lstsq( 197 | train_xs, train_ys.unsqueeze(2), driver=self.driver 198 | ) 199 | 200 | pred = test_x @ ws 201 | preds.append(pred[:, 0, 0]) 202 | 203 | return torch.stack(preds, dim=1) 204 | 205 | 206 | class AveragingModel: 207 | def __init__(self): 208 | self.name = "averaging" 209 | 210 | def __call__(self, xs, ys, inds=None): 211 | if inds is None: 212 | inds = range(ys.shape[1]) 213 | else: 214 | if max(inds) >= ys.shape[1] or min(inds) < 0: 215 | raise ValueError("inds contain indices where xs and ys are not defined") 216 | 217 | preds = [] 218 | 219 | for i in inds: 220 | if i == 0: 221 | preds.append(torch.zeros_like(ys[:, 0])) # predict zero for first point 222 | continue 223 | train_xs, train_ys = xs[:, :i], ys[:, :i] 224 | test_x = xs[:, i : i + 1] 225 | 226 | train_zs = train_xs * train_ys.unsqueeze(dim=-1) 227 | w_p = train_zs.mean(dim=1).unsqueeze(dim=-1) 228 | pred = test_x @ w_p 229 | preds.append(pred[:, 0, 0]) 230 | 231 | return torch.stack(preds, dim=1) 232 | 233 | 234 | # Lasso regression (for sparse linear regression). 235 | # Seems to take more time as we decrease alpha. 236 | class LassoModel: 237 | def __init__(self, alpha, max_iter=100000): 238 | # the l1 regularizer gets multiplied by alpha. 239 | self.alpha = alpha 240 | self.max_iter = max_iter 241 | self.name = f"lasso_alpha={alpha}_max_iter={max_iter}" 242 | 243 | # inds is a list containing indices where we want the prediction. 244 | # prediction made at all indices by default. 245 | def __call__(self, xs, ys, inds=None): 246 | xs, ys = xs.cpu(), ys.cpu() 247 | 248 | if inds is None: 249 | inds = range(ys.shape[1]) 250 | else: 251 | if max(inds) >= ys.shape[1] or min(inds) < 0: 252 | raise ValueError("inds contain indices where xs and ys are not defined") 253 | 254 | preds = [] # predict one for first point 255 | 256 | # i: loop over num_points 257 | # j: loop over bsize 258 | for i in inds: 259 | pred = torch.zeros_like(ys[:, 0]) 260 | 261 | if i > 0: 262 | pred = torch.zeros_like(ys[:, 0]) 263 | for j in range(ys.shape[0]): 264 | train_xs, train_ys = xs[j, :i], ys[j, :i] 265 | 266 | # If all points till now have the same label, predict that label. 267 | 268 | clf = Lasso( 269 | alpha=self.alpha, fit_intercept=False, max_iter=self.max_iter 270 | ) 271 | 272 | # Check for convergence. 273 | with warnings.catch_warnings(): 274 | warnings.filterwarnings("error") 275 | try: 276 | clf.fit(train_xs, train_ys) 277 | except Warning: 278 | print(f"lasso convergence warning at i={i}, j={j}.") 279 | raise 280 | 281 | w_pred = torch.from_numpy(clf.coef_).unsqueeze(1) 282 | 283 | test_x = xs[j, i : i + 1] 284 | y_pred = (test_x @ w_pred.float()).squeeze(1) 285 | pred[j] = y_pred[0] 286 | 287 | preds.append(pred) 288 | 289 | return torch.stack(preds, dim=1) 290 | 291 | 292 | # Gradient Descent and variants. 
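# GDModel is the "2-layer NN, GD" baseline: for each prompt prefix of length i it trains a fresh
# network per batch element (via ParallelNetworks) on the first i (x, y) pairs with SGD/Adam and a
# squared loss, then uses it to predict y for the next x. It is therefore considerably slower than
# the closed-form baselines above.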
293 | # Example usage: gd_model = GDModel(NeuralNetwork, {'in_size': 50, 'hidden_size':400, 'out_size' :1}, opt_alg = 'adam', batch_size = 100, lr = 5e-3, num_steps = 200) 294 | class GDModel: 295 | def __init__( 296 | self, 297 | model_class, 298 | model_class_args, 299 | opt_alg="sgd", 300 | batch_size=1, 301 | num_steps=1000, 302 | lr=1e-3, 303 | loss_name="squared", 304 | ): 305 | # model_class: torch.nn model class 306 | # model_class_args: a dict containing arguments for model_class 307 | # opt_alg can be 'sgd' or 'adam' 308 | # verbose: whether to print the progress or not 309 | # batch_size: batch size for sgd 310 | self.model_class = model_class 311 | self.model_class_args = model_class_args 312 | self.opt_alg = opt_alg 313 | self.lr = lr 314 | self.batch_size = batch_size 315 | self.num_steps = num_steps 316 | self.loss_name = loss_name 317 | 318 | self.name = f"gd_model_class={model_class}_model_class_args={model_class_args}_opt_alg={opt_alg}_lr={lr}_batch_size={batch_size}_num_steps={num_steps}_loss_name={loss_name}" 319 | 320 | def __call__(self, xs, ys, inds=None, verbose=False, print_step=100): 321 | # inds is a list containing indices where we want the prediction. 322 | # prediction made at all indices by default. 323 | # xs: bsize X npoints X ndim. 324 | # ys: bsize X npoints. 325 | xs, ys = xs.cuda(), ys.cuda() 326 | 327 | if inds is None: 328 | inds = range(ys.shape[1]) 329 | else: 330 | if max(inds) >= ys.shape[1] or min(inds) < 0: 331 | raise ValueError("inds contain indices where xs and ys are not defined") 332 | 333 | preds = [] # predict one for first point 334 | 335 | # i: loop over num_points 336 | for i in tqdm(inds): 337 | pred = torch.zeros_like(ys[:, 0]) 338 | model = ParallelNetworks( 339 | ys.shape[0], self.model_class, **self.model_class_args 340 | ) 341 | model.cuda() 342 | if i > 0: 343 | pred = torch.zeros_like(ys[:, 0]) 344 | 345 | train_xs, train_ys = xs[:, :i], ys[:, :i] 346 | test_xs, test_ys = xs[:, i : i + 1], ys[:, i : i + 1] 347 | 348 | if self.opt_alg == "sgd": 349 | optimizer = torch.optim.SGD(model.parameters(), lr=self.lr) 350 | elif self.opt_alg == "adam": 351 | optimizer = torch.optim.Adam(model.parameters(), lr=self.lr) 352 | else: 353 | raise NotImplementedError(f"{self.opt_alg} not implemented.") 354 | 355 | if self.loss_name == "squared": 356 | loss_criterion = nn.MSELoss() 357 | else: 358 | raise NotImplementedError(f"{self.loss_name} not implemented.") 359 | 360 | # Training loop 361 | for j in range(self.num_steps): 362 | 363 | # Prepare batch 364 | mask = torch.zeros(i).bool() 365 | perm = torch.randperm(i) 366 | mask[perm[: self.batch_size]] = True 367 | train_xs_cur, train_ys_cur = train_xs[:, mask, :], train_ys[:, mask] 368 | 369 | if verbose and j % print_step == 0: 370 | model.eval() 371 | with torch.no_grad(): 372 | outputs = model(train_xs_cur) 373 | loss = loss_criterion( 374 | outputs[:, :, 0], train_ys_cur 375 | ).detach() 376 | outputs_test = model(test_xs) 377 | test_loss = loss_criterion( 378 | outputs_test[:, :, 0], test_ys 379 | ).detach() 380 | print( 381 | f"ind:{i},step:{j}, train_loss:{loss.item()}, test_loss:{test_loss.item()}" 382 | ) 383 | 384 | optimizer.zero_grad() 385 | 386 | model.train() 387 | outputs = model(train_xs_cur) 388 | loss = loss_criterion(outputs[:, :, 0], train_ys_cur) 389 | loss.backward() 390 | optimizer.step() 391 | 392 | model.eval() 393 | pred = model(test_xs).detach() 394 | 395 | assert pred.shape[1] == 1 and pred.shape[2] == 1 396 | pred = pred[:, 0, 0] 397 | 398 | preds.append(pred) 
399 | 400 | return torch.stack(preds, dim=1) 401 | 402 | 403 | class DecisionTreeModel: 404 | def __init__(self, max_depth=None): 405 | self.max_depth = max_depth 406 | self.name = f"decision_tree_max_depth={max_depth}" 407 | 408 | # inds is a list containing indices where we want the prediction. 409 | # prediction made at all indices by default. 410 | def __call__(self, xs, ys, inds=None): 411 | xs, ys = xs.cpu(), ys.cpu() 412 | 413 | if inds is None: 414 | inds = range(ys.shape[1]) 415 | else: 416 | if max(inds) >= ys.shape[1] or min(inds) < 0: 417 | raise ValueError("inds contain indices where xs and ys are not defined") 418 | 419 | preds = [] 420 | 421 | # i: loop over num_points 422 | # j: loop over bsize 423 | for i in inds: 424 | pred = torch.zeros_like(ys[:, 0]) 425 | 426 | if i > 0: 427 | pred = torch.zeros_like(ys[:, 0]) 428 | for j in range(ys.shape[0]): 429 | train_xs, train_ys = xs[j, :i], ys[j, :i] 430 | 431 | clf = tree.DecisionTreeRegressor(max_depth=self.max_depth) 432 | clf = clf.fit(train_xs, train_ys) 433 | test_x = xs[j, i : i + 1] 434 | y_pred = clf.predict(test_x) 435 | pred[j] = y_pred[0] 436 | 437 | preds.append(pred) 438 | 439 | return torch.stack(preds, dim=1) 440 | 441 | 442 | class XGBoostModel: 443 | def __init__(self): 444 | self.name = "xgboost" 445 | 446 | # inds is a list containing indices where we want the prediction. 447 | # prediction made at all indices by default. 448 | def __call__(self, xs, ys, inds=None): 449 | xs, ys = xs.cpu(), ys.cpu() 450 | 451 | if inds is None: 452 | inds = range(ys.shape[1]) 453 | else: 454 | if max(inds) >= ys.shape[1] or min(inds) < 0: 455 | raise ValueError("inds contain indices where xs and ys are not defined") 456 | 457 | preds = [] 458 | 459 | # i: loop over num_points 460 | # j: loop over bsize 461 | for i in tqdm(inds): 462 | pred = torch.zeros_like(ys[:, 0]) 463 | if i > 0: 464 | pred = torch.zeros_like(ys[:, 0]) 465 | for j in range(ys.shape[0]): 466 | train_xs, train_ys = xs[j, :i], ys[j, :i] 467 | 468 | clf = xgb.XGBRegressor() 469 | 470 | clf = clf.fit(train_xs, train_ys) 471 | test_x = xs[j, i : i + 1] 472 | y_pred = clf.predict(test_x) 473 | pred[j] = y_pred[0].item() 474 | 475 | preds.append(pred) 476 | 477 | return torch.stack(preds, dim=1) 478 | -------------------------------------------------------------------------------- /src/plot_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | 6 | from eval import get_run_metrics, baseline_names, get_model_from_run 7 | from models import build_model 8 | 9 | sns.set_theme("notebook", "darkgrid") 10 | palette = sns.color_palette("colorblind") 11 | 12 | 13 | relevant_model_names = { 14 | "linear_regression": [ 15 | "Transformer", 16 | "Least Squares", 17 | "3-Nearest Neighbors", 18 | "Averaging", 19 | ], 20 | "sparse_linear_regression": [ 21 | "Transformer", 22 | "Least Squares", 23 | "3-Nearest Neighbors", 24 | "Averaging", 25 | "Lasso (alpha=0.01)", 26 | ], 27 | "decision_tree": [ 28 | "Transformer", 29 | "3-Nearest Neighbors", 30 | "2-layer NN, GD", 31 | "Greedy Tree Learning", 32 | "XGBoost", 33 | ], 34 | "relu_2nn_regression": [ 35 | "Transformer", 36 | "Least Squares", 37 | "3-Nearest Neighbors", 38 | "2-layer NN, GD", 39 | ], 40 | } 41 | 42 | 43 | def basic_plot(metrics, models=None, trivial=1.0): 44 | fig, ax = plt.subplots(1, 1) 45 | 46 | if models is not None: 47 | metrics = {k: metrics[k] for k in models} 48 | 49 | color = 0 50 | 
ax.axhline(trivial, ls="--", color="gray") 51 | for name, vs in metrics.items(): 52 | ax.plot(vs["mean"], "-", label=name, color=palette[color % 10], lw=2) 53 | low = vs["bootstrap_low"] 54 | high = vs["bootstrap_high"] 55 | ax.fill_between(range(len(low)), low, high, alpha=0.3) 56 | color += 1 57 | ax.set_xlabel("in-context examples") 58 | ax.set_ylabel("squared error") 59 | ax.set_xlim(-1, len(low) + 0.1) 60 | ax.set_ylim(-0.1, 1.25) 61 | 62 | legend = ax.legend(loc="upper left", bbox_to_anchor=(1, 1)) 63 | fig.set_size_inches(4, 3) 64 | for line in legend.get_lines(): 65 | line.set_linewidth(3) 66 | 67 | return fig, ax 68 | 69 | 70 | def collect_results(run_dir, df, valid_row=None, rename_eval=None, rename_model=None): 71 | all_metrics = {} 72 | for _, r in df.iterrows(): 73 | if valid_row is not None and not valid_row(r): 74 | continue 75 | 76 | run_path = os.path.join(run_dir, r.task, r.run_id) 77 | _, conf = get_model_from_run(run_path, only_conf=True) 78 | 79 | print(r.run_name, r.run_id) 80 | metrics = get_run_metrics(run_path, skip_model_load=True) 81 | 82 | for eval_name, results in sorted(metrics.items()): 83 | processed_results = {} 84 | for model_name, m in results.items(): 85 | if "gpt2" in model_name in model_name: 86 | model_name = r.model 87 | if rename_model is not None: 88 | model_name = rename_model(model_name, r) 89 | else: 90 | model_name = baseline_names(model_name) 91 | m_processed = {} 92 | n_dims = conf.model.n_dims 93 | 94 | xlim = 2 * n_dims + 1 95 | if r.task in ["relu_2nn_regression", "decision_tree"]: 96 | xlim = 200 97 | 98 | normalization = n_dims 99 | if r.task == "sparse_linear_regression": 100 | normalization = int(r.kwargs.split("=")[-1]) 101 | if r.task == "decision_tree": 102 | normalization = 1 103 | 104 | for k, v in m.items(): 105 | v = v[:xlim] 106 | v = [vv / normalization for vv in v] 107 | m_processed[k] = v 108 | processed_results[model_name] = m_processed 109 | if rename_eval is not None: 110 | eval_name = rename_eval(eval_name, r) 111 | if eval_name not in all_metrics: 112 | all_metrics[eval_name] = {} 113 | all_metrics[eval_name].update(processed_results) 114 | return all_metrics 115 | -------------------------------------------------------------------------------- /src/samplers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | class DataSampler: 7 | def __init__(self, n_dims): 8 | self.n_dims = n_dims 9 | 10 | def sample_xs(self): 11 | raise NotImplementedError 12 | 13 | 14 | def get_data_sampler(data_name, n_dims, **kwargs): 15 | names_to_classes = { 16 | "gaussian": GaussianSampler, 17 | } 18 | if data_name in names_to_classes: 19 | sampler_cls = names_to_classes[data_name] 20 | return sampler_cls(n_dims, **kwargs) 21 | else: 22 | print("Unknown sampler") 23 | raise NotImplementedError 24 | 25 | 26 | def sample_transformation(eigenvalues, normalize=False): 27 | n_dims = len(eigenvalues) 28 | U, _, _ = torch.linalg.svd(torch.randn(n_dims, n_dims)) 29 | t = U @ torch.diag(eigenvalues) @ torch.transpose(U, 0, 1) 30 | if normalize: 31 | norm_subspace = torch.sum(eigenvalues**2) 32 | t *= math.sqrt(n_dims / norm_subspace) 33 | return t 34 | 35 | 36 | class GaussianSampler(DataSampler): 37 | def __init__(self, n_dims, bias=None, scale=None): 38 | super().__init__(n_dims) 39 | self.bias = bias 40 | self.scale = scale 41 | 42 | def sample_xs(self, n_points, b_size, n_dims_truncated=None, seeds=None): 43 | if seeds is None: 44 | xs_b = torch.randn(b_size, 
n_points, self.n_dims) 45 | else: 46 | xs_b = torch.zeros(b_size, n_points, self.n_dims) 47 | generator = torch.Generator() 48 | assert len(seeds) == b_size 49 | for i, seed in enumerate(seeds): 50 | generator.manual_seed(seed) 51 | xs_b[i] = torch.randn(n_points, self.n_dims, generator=generator) 52 | if self.scale is not None: 53 | xs_b = xs_b @ self.scale 54 | if self.bias is not None: 55 | xs_b += self.bias 56 | if n_dims_truncated is not None: 57 | xs_b[:, :, n_dims_truncated:] = 0 58 | return xs_b 59 | -------------------------------------------------------------------------------- /src/schema.py: -------------------------------------------------------------------------------- 1 | from quinine import ( 2 | tstring, 3 | tinteger, 4 | tfloat, 5 | tboolean, 6 | stdict, 7 | tdict, 8 | default, 9 | required, 10 | allowed, 11 | nullable, 12 | ) 13 | from funcy import merge 14 | 15 | 16 | model_schema = { 17 | "family": merge(tstring, allowed(["gpt2", "lstm"])), 18 | "n_positions": merge(tinteger, required), # maximum context length 19 | "n_dims": merge(tinteger, required), # latent dimension 20 | "n_embd": merge(tinteger, required), 21 | "n_layer": merge(tinteger, required), 22 | "n_head": merge(tinteger, required), 23 | } 24 | 25 | curriculum_base_schema = { 26 | "start": merge(tinteger, required), # initial parameter 27 | "end": merge(tinteger, required), # limit of final value 28 | "inc": merge(tinteger, required), # how much to increment each time 29 | "interval": merge(tinteger, required), # increment every how many steps 30 | } 31 | 32 | curriculum_schema = { 33 | "dims": stdict(curriculum_base_schema), 34 | "points": stdict(curriculum_base_schema), 35 | } 36 | 37 | TASK_LIST = [ 38 | "linear_regression", 39 | "sparse_linear_regression", 40 | "linear_classification", 41 | "relu_2nn_regression", 42 | "decision_tree", 43 | ] 44 | 45 | training_schema = { 46 | "task": merge(tstring, allowed(TASK_LIST)), 47 | "task_kwargs": merge(tdict, required), 48 | "num_tasks": merge(tinteger, nullable, default(None)), 49 | "num_training_examples": merge(tinteger, nullable, default(None)), 50 | "data": merge(tstring, allowed(["gaussian"])), 51 | "batch_size": merge(tinteger, default(64)), 52 | "learning_rate": merge(tfloat, default(3e-4)), 53 | "train_steps": merge(tinteger, default(1000)), 54 | "save_every_steps": merge(tinteger, default(1000)), # how often to checkpoint 55 | "keep_every_steps": merge(tinteger, default(-1)), # permanent checkpoints 56 | "resume_id": merge(tstring, nullable, default(None)), # run uuid64 57 | "curriculum": stdict(curriculum_schema), 58 | } 59 | 60 | wandb_schema = { 61 | "project": merge(tstring, default("in-context-training")), 62 | "entity": merge(tstring, default("in-context")), 63 | "notes": merge(tstring, default("")), 64 | "name": merge(tstring, nullable, default(None)), 65 | "log_every_steps": merge(tinteger, default(10)), 66 | } 67 | 68 | schema = { 69 | "out_dir": merge(tstring, required), 70 | "model": stdict(model_schema), 71 | "training": stdict(training_schema), 72 | "wandb": stdict(wandb_schema), 73 | "test_run": merge(tboolean, default(False)), 74 | } 75 | -------------------------------------------------------------------------------- /src/tasks.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def squared_error(ys_pred, ys): 7 | return (ys - ys_pred).square() 8 | 9 | 10 | def mean_squared_error(ys_pred, ys): 11 | return (ys - ys_pred).square().mean() 12 | 13 | 14 | def 
accuracy(ys_pred, ys): 15 | return (ys == ys_pred.sign()).float() 16 | 17 | 18 | sigmoid = torch.nn.Sigmoid() 19 | bce_loss = torch.nn.BCELoss() 20 | 21 | 22 | def cross_entropy(ys_pred, ys): 23 | output = sigmoid(ys_pred) 24 | target = (ys + 1) / 2 25 | return bce_loss(output, target) 26 | 27 | 28 | class Task: 29 | def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None): 30 | self.n_dims = n_dims 31 | self.b_size = batch_size 32 | self.pool_dict = pool_dict 33 | self.seeds = seeds 34 | assert pool_dict is None or seeds is None 35 | 36 | def evaluate(self, xs): 37 | raise NotImplementedError 38 | 39 | @staticmethod 40 | def generate_pool_dict(n_dims, num_tasks): 41 | raise NotImplementedError 42 | 43 | @staticmethod 44 | def get_metric(): 45 | raise NotImplementedError 46 | 47 | @staticmethod 48 | def get_training_metric(): 49 | raise NotImplementedError 50 | 51 | 52 | def get_task_sampler( 53 | task_name, n_dims, batch_size, pool_dict=None, num_tasks=None, **kwargs 54 | ): 55 | task_names_to_classes = { 56 | "linear_regression": LinearRegression, 57 | "sparse_linear_regression": SparseLinearRegression, 58 | "linear_classification": LinearClassification, 59 | "noisy_linear_regression": NoisyLinearRegression, 60 | "quadratic_regression": QuadraticRegression, 61 | "relu_2nn_regression": Relu2nnRegression, 62 | "decision_tree": DecisionTree, 63 | } 64 | if task_name in task_names_to_classes: 65 | task_cls = task_names_to_classes[task_name] 66 | if num_tasks is not None: 67 | if pool_dict is not None: 68 | raise ValueError("Either pool_dict or num_tasks should be None.") 69 | pool_dict = task_cls.generate_pool_dict(n_dims, num_tasks, **kwargs) 70 | return lambda **args: task_cls(n_dims, batch_size, pool_dict, **args, **kwargs) 71 | else: 72 | print("Unknown task") 73 | raise NotImplementedError 74 | 75 | 76 | class LinearRegression(Task): 77 | def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None, scale=1): 78 | """scale: a constant by which to scale the randomly sampled weights.""" 79 | super(LinearRegression, self).__init__(n_dims, batch_size, pool_dict, seeds) 80 | self.scale = scale 81 | 82 | if pool_dict is None and seeds is None: 83 | self.w_b = torch.randn(self.b_size, self.n_dims, 1) 84 | elif seeds is not None: 85 | self.w_b = torch.zeros(self.b_size, self.n_dims, 1) 86 | generator = torch.Generator() 87 | assert len(seeds) == self.b_size 88 | for i, seed in enumerate(seeds): 89 | generator.manual_seed(seed) 90 | self.w_b[i] = torch.randn(self.n_dims, 1, generator=generator) 91 | else: 92 | assert "w" in pool_dict 93 | indices = torch.randperm(len(pool_dict["w"]))[:batch_size] 94 | self.w_b = pool_dict["w"][indices] 95 | 96 | def evaluate(self, xs_b): 97 | w_b = self.w_b.to(xs_b.device) 98 | ys_b = self.scale * (xs_b @ w_b)[:, :, 0] 99 | return ys_b 100 | 101 | @staticmethod 102 | def generate_pool_dict(n_dims, num_tasks, **kwargs): # ignore extra args 103 | return {"w": torch.randn(num_tasks, n_dims, 1)} 104 | 105 | @staticmethod 106 | def get_metric(): 107 | return squared_error 108 | 109 | @staticmethod 110 | def get_training_metric(): 111 | return mean_squared_error 112 | 113 | 114 | class SparseLinearRegression(LinearRegression): 115 | def __init__( 116 | self, 117 | n_dims, 118 | batch_size, 119 | pool_dict=None, 120 | seeds=None, 121 | scale=1, 122 | sparsity=3, 123 | valid_coords=None, 124 | ): 125 | """scale: a constant by which to scale the randomly sampled weights.""" 126 | super(SparseLinearRegression, self).__init__( 127 | n_dims, batch_size, 
pool_dict, seeds, scale 128 | ) 129 | self.sparsity = sparsity 130 | if valid_coords is None: 131 | valid_coords = n_dims 132 | assert valid_coords <= n_dims 133 | 134 | for i, w in enumerate(self.w_b): 135 | mask = torch.ones(n_dims).bool() 136 | if seeds is None: 137 | perm = torch.randperm(valid_coords) 138 | else: 139 | generator = torch.Generator() 140 | generator.manual_seed(seeds[i]) 141 | perm = torch.randperm(valid_coords, generator=generator) 142 | mask[perm[:sparsity]] = False 143 | w[mask] = 0 144 | 145 | def evaluate(self, xs_b): 146 | w_b = self.w_b.to(xs_b.device) 147 | ys_b = self.scale * (xs_b @ w_b)[:, :, 0] 148 | return ys_b 149 | 150 | @staticmethod 151 | def get_metric(): 152 | return squared_error 153 | 154 | @staticmethod 155 | def get_training_metric(): 156 | return mean_squared_error 157 | 158 | 159 | class LinearClassification(LinearRegression): 160 | def evaluate(self, xs_b): 161 | ys_b = super().evaluate(xs_b) 162 | return ys_b.sign() 163 | 164 | @staticmethod 165 | def get_metric(): 166 | return accuracy 167 | 168 | @staticmethod 169 | def get_training_metric(): 170 | return cross_entropy 171 | 172 | 173 | class NoisyLinearRegression(LinearRegression): 174 | def __init__( 175 | self, 176 | n_dims, 177 | batch_size, 178 | pool_dict=None, 179 | seeds=None, 180 | scale=1, 181 | noise_std=0, 182 | renormalize_ys=False, 183 | ): 184 | """noise_std: standard deviation of noise added to the prediction.""" 185 | super(NoisyLinearRegression, self).__init__( 186 | n_dims, batch_size, pool_dict, seeds, scale 187 | ) 188 | self.noise_std = noise_std 189 | self.renormalize_ys = renormalize_ys 190 | 191 | def evaluate(self, xs_b): 192 | ys_b = super().evaluate(xs_b) 193 | ys_b_noisy = ys_b + torch.randn_like(ys_b) * self.noise_std 194 | if self.renormalize_ys: 195 | ys_b_noisy = ys_b_noisy * math.sqrt(self.n_dims) / ys_b_noisy.std() 196 | 197 | return ys_b_noisy 198 | 199 | 200 | class QuadraticRegression(LinearRegression): 201 | def evaluate(self, xs_b): 202 | w_b = self.w_b.to(xs_b.device) 203 | ys_b_quad = ((xs_b**2) @ w_b)[:, :, 0] 204 | # ys_b_quad = ys_b_quad * math.sqrt(self.n_dims) / ys_b_quad.std() 205 | # Renormalize to Linear Regression Scale 206 | ys_b_quad = ys_b_quad / math.sqrt(3) 207 | ys_b_quad = self.scale * ys_b_quad 208 | return ys_b_quad 209 | 210 | 211 | class Relu2nnRegression(Task): 212 | def __init__( 213 | self, 214 | n_dims, 215 | batch_size, 216 | pool_dict=None, 217 | seeds=None, 218 | scale=1, 219 | hidden_layer_size=100, 220 | ): 221 | """scale: a constant by which to scale the randomly sampled weights.""" 222 | super(Relu2nnRegression, self).__init__(n_dims, batch_size, pool_dict, seeds) 223 | self.scale = scale 224 | self.hidden_layer_size = hidden_layer_size 225 | 226 | if pool_dict is None and seeds is None: 227 | self.W1 = torch.randn(self.b_size, self.n_dims, hidden_layer_size) 228 | self.W2 = torch.randn(self.b_size, hidden_layer_size, 1) 229 | elif seeds is not None: 230 | self.W1 = torch.zeros(self.b_size, self.n_dims, hidden_layer_size) 231 | self.W2 = torch.zeros(self.b_size, hidden_layer_size, 1) 232 | generator = torch.Generator() 233 | assert len(seeds) == self.b_size 234 | for i, seed in enumerate(seeds): 235 | generator.manual_seed(seed) 236 | self.W1[i] = torch.randn( 237 | self.n_dims, hidden_layer_size, generator=generator 238 | ) 239 | self.W2[i] = torch.randn(hidden_layer_size, 1, generator=generator) 240 | else: 241 | assert "W1" in pool_dict and "W2" in pool_dict 242 | assert len(pool_dict["W1"]) == len(pool_dict["W2"]) 
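# Sample batch_size distinct networks from the pre-generated pool for this batch of prompts.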
243 | indices = torch.randperm(len(pool_dict["W1"]))[:batch_size] 244 | self.W1 = pool_dict["W1"][indices] 245 | self.W2 = pool_dict["W2"][indices] 246 | 247 | def evaluate(self, xs_b): 248 | W1 = self.W1.to(xs_b.device) 249 | W2 = self.W2.to(xs_b.device) 250 | # Renormalize to Linear Regression Scale 251 | ys_b_nn = (torch.nn.functional.relu(xs_b @ W1) @ W2)[:, :, 0] 252 | ys_b_nn = ys_b_nn * math.sqrt(2 / self.hidden_layer_size) 253 | ys_b_nn = self.scale * ys_b_nn 254 | # ys_b_nn = ys_b_nn * math.sqrt(self.n_dims) / ys_b_nn.std() 255 | return ys_b_nn 256 | 257 | @staticmethod 258 | def generate_pool_dict(n_dims, num_tasks, hidden_layer_size=4, **kwargs): 259 | return { 260 | "W1": torch.randn(num_tasks, n_dims, hidden_layer_size), 261 | "W2": torch.randn(num_tasks, hidden_layer_size, 1), 262 | } 263 | 264 | @staticmethod 265 | def get_metric(): 266 | return squared_error 267 | 268 | @staticmethod 269 | def get_training_metric(): 270 | return mean_squared_error 271 | 272 | 273 | class DecisionTree(Task): 274 | def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None, depth=4): 275 | 276 | super(DecisionTree, self).__init__(n_dims, batch_size, pool_dict, seeds) 277 | self.depth = depth 278 | 279 | if pool_dict is None: 280 | 281 | # We represent the tree using an array (tensor). Root node is at index 0, its 2 children at index 1 and 2... 282 | # dt_tensor stores the coordinate used at each node of the decision tree. 283 | # Only indices corresponding to non-leaf nodes are relevant 284 | self.dt_tensor = torch.randint( 285 | low=0, high=n_dims, size=(batch_size, 2 ** (depth + 1) - 1) 286 | ) 287 | 288 | # Target value at the leaf nodes. 289 | # Only indices corresponding to leaf nodes are relevant. 290 | self.target_tensor = torch.randn(self.dt_tensor.shape) 291 | elif seeds is not None: 292 | self.dt_tensor = torch.zeros(batch_size, 2 ** (depth + 1) - 1, dtype=torch.long) 293 | self.target_tensor = torch.zeros(batch_size, 2 ** (depth + 1) - 1) 294 | generator = torch.Generator() 295 | assert len(seeds) == self.b_size 296 | for i, seed in enumerate(seeds): 297 | generator.manual_seed(seed) 298 | self.dt_tensor[i] = torch.randint( 299 | low=0, 300 | high=n_dims, 301 | size=(2 ** (depth + 1) - 1,), 302 | generator=generator, 303 | ) 304 | self.target_tensor[i] = torch.randn( 305 | self.dt_tensor[i].shape, generator=generator 306 | ) 307 | else: 308 | raise NotImplementedError 309 | 310 | def evaluate(self, xs_b): 311 | dt_tensor = self.dt_tensor.to(xs_b.device) 312 | target_tensor = self.target_tensor.to(xs_b.device) 313 | ys_b = torch.zeros(xs_b.shape[0], xs_b.shape[1], device=xs_b.device) 314 | for i in range(xs_b.shape[0]): 315 | xs_bool = xs_b[i] > 0 316 | # If a single decision tree is present, use it for all the xs in the batch. 
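# Traversal below: the tree is stored as a flat array with node k's children at indices 2k + 1 and 2k + 2.
# At every level we look up the split coordinate of the current node and move to the right child when that
# coordinate of x is positive, otherwise to the left child; after `depth` steps cur_nodes holds leaf indices
# and the corresponding entries of target are the labels.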
317 | if self.b_size == 1: 318 | dt = dt_tensor[0] 319 | target = target_tensor[0] 320 | else: 321 | dt = dt_tensor[i] 322 | target = target_tensor[i] 323 | 324 | cur_nodes = torch.zeros(xs_b.shape[1], device=xs_b.device).long() 325 | for j in range(self.depth): 326 | cur_coords = dt[cur_nodes] 327 | cur_decisions = xs_bool[torch.arange(xs_bool.shape[0]), cur_coords] 328 | cur_nodes = 2 * cur_nodes + 1 + cur_decisions 329 | 330 | ys_b[i] = target[cur_nodes] 331 | 332 | return ys_b 333 | 334 | @staticmethod 335 | def generate_pool_dict(n_dims, num_tasks, hidden_layer_size=4, **kwargs): 336 | raise NotImplementedError 337 | 338 | @staticmethod 339 | def get_metric(): 340 | return squared_error 341 | 342 | @staticmethod 343 | def get_training_metric(): 344 | return mean_squared_error 345 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from random import randint 3 | import uuid 4 | 5 | from quinine import QuinineArgumentParser 6 | from tqdm import tqdm 7 | import torch 8 | import yaml 9 | 10 | from eval import get_run_metrics 11 | from tasks import get_task_sampler 12 | from samplers import get_data_sampler 13 | from curriculum import Curriculum 14 | from schema import schema 15 | from models import build_model 16 | 17 | import wandb 18 | 19 | torch.backends.cudnn.benchmark = True 20 | 21 | 22 | def train_step(model, xs, ys, optimizer, loss_func): 23 | optimizer.zero_grad() 24 | output = model(xs, ys) 25 | loss = loss_func(output, ys) 26 | loss.backward() 27 | optimizer.step() 28 | return loss.detach().item(), output.detach() 29 | 30 | 31 | def sample_seeds(total_seeds, count): 32 | seeds = set() 33 | while len(seeds) < count: 34 | seeds.add(randint(0, total_seeds - 1)) 35 | return seeds 36 | 37 | 38 | def train(model, args): 39 | optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate) 40 | curriculum = Curriculum(args.training.curriculum) 41 | 42 | starting_step = 0 43 | state_path = os.path.join(args.out_dir, "state.pt") 44 | if os.path.exists(state_path): 45 | state = torch.load(state_path) 46 | model.load_state_dict(state["model_state_dict"]) 47 | optimizer.load_state_dict(state["optimizer_state_dict"]) 48 | starting_step = state["train_step"] 49 | for i in range(state["train_step"] + 1): 50 | curriculum.update() 51 | 52 | n_dims = model.n_dims 53 | bsize = args.training.batch_size 54 | data_sampler = get_data_sampler(args.training.data, n_dims=n_dims) 55 | task_sampler = get_task_sampler( 56 | args.training.task, 57 | n_dims, 58 | bsize, 59 | num_tasks=args.training.num_tasks, 60 | **args.training.task_kwargs, 61 | ) 62 | pbar = tqdm(range(starting_step, args.training.train_steps)) 63 | 64 | num_training_examples = args.training.num_training_examples 65 | 66 | for i in pbar: 67 | data_sampler_args = {} 68 | task_sampler_args = {} 69 | 70 | if "sparse" in args.training.task: 71 | task_sampler_args["valid_coords"] = curriculum.n_dims_truncated 72 | if num_training_examples is not None: 73 | assert num_training_examples >= bsize 74 | seeds = sample_seeds(num_training_examples, bsize) 75 | data_sampler_args["seeds"] = seeds 76 | task_sampler_args["seeds"] = [s + 1 for s in seeds] 77 | 78 | xs = data_sampler.sample_xs( 79 | curriculum.n_points, 80 | bsize, 81 | curriculum.n_dims_truncated, 82 | **data_sampler_args, 83 | ) 84 | task = task_sampler(**task_sampler_args) 85 | ys = task.evaluate(xs) 86 | 87 | loss_func = 
task.get_training_metric() 88 | 89 | loss, output = train_step(model, xs.cuda(), ys.cuda(), optimizer, loss_func) 90 | 91 | point_wise_tags = list(range(curriculum.n_points)) 92 | point_wise_loss_func = task.get_metric() 93 | point_wise_loss = point_wise_loss_func(output, ys.cuda()).mean(dim=0) 94 | 95 | baseline_loss = ( 96 | sum( 97 | max(curriculum.n_dims_truncated - ii, 0) 98 | for ii in range(curriculum.n_points) 99 | ) 100 | / curriculum.n_points 101 | ) 102 | 103 | if i % args.wandb.log_every_steps == 0 and not args.test_run: 104 | wandb.log( 105 | { 106 | "overall_loss": loss, 107 | "excess_loss": loss / baseline_loss, 108 | "pointwise/loss": dict( 109 | zip(point_wise_tags, point_wise_loss.cpu().numpy()) 110 | ), 111 | "n_points": curriculum.n_points, 112 | "n_dims": curriculum.n_dims_truncated, 113 | }, 114 | step=i, 115 | ) 116 | 117 | curriculum.update() 118 | 119 | pbar.set_description(f"loss {loss}") 120 | if i % args.training.save_every_steps == 0 and not args.test_run: 121 | training_state = { 122 | "model_state_dict": model.state_dict(), 123 | "optimizer_state_dict": optimizer.state_dict(), 124 | "train_step": i, 125 | } 126 | torch.save(training_state, state_path) 127 | 128 | if ( 129 | args.training.keep_every_steps > 0 130 | and i % args.training.keep_every_steps == 0 131 | and not args.test_run 132 | and i > 0 133 | ): 134 | torch.save(model.state_dict(), os.path.join(args.out_dir, f"model_{i}.pt")) 135 | 136 | 137 | def main(args): 138 | if args.test_run: 139 | curriculum_args = args.training.curriculum 140 | curriculum_args.points.start = curriculum_args.points.end 141 | curriculum_args.dims.start = curriculum_args.dims.end 142 | args.training.train_steps = 100 143 | else: 144 | wandb.init( 145 | dir=args.out_dir, 146 | project=args.wandb.project, 147 | entity=args.wandb.entity, 148 | config=args.__dict__, 149 | notes=args.wandb.notes, 150 | name=args.wandb.name, 151 | resume=True, 152 | ) 153 | 154 | model = build_model(args.model) 155 | model.cuda() 156 | model.train() 157 | 158 | train(model, args) 159 | 160 | if not args.test_run: 161 | _ = get_run_metrics(args.out_dir) # precompute metrics for eval 162 | 163 | 164 | if __name__ == "__main__": 165 | parser = QuinineArgumentParser(schema=schema) 166 | args = parser.parse_quinfig() 167 | assert args.model.family in ["gpt2", "lstm"] 168 | print(f"Running with: {args}") 169 | 170 | if not args.test_run: 171 | run_id = args.training.resume_id 172 | if run_id is None: 173 | run_id = str(uuid.uuid4()) 174 | 175 | out_dir = os.path.join(args.out_dir, run_id) 176 | if not os.path.exists(out_dir): 177 | os.makedirs(out_dir) 178 | args.out_dir = out_dir 179 | 180 | with open(os.path.join(out_dir, "config.yaml"), "w") as yaml_file: 181 | yaml.dump(args.__dict__, yaml_file, default_flow_style=False) 182 | 183 | main(args) 184 | --------------------------------------------------------------------------------
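
For a quick sanity check of the data/task pipeline outside of `train.py`, the samplers and tasks above can also be used directly. The sketch below mirrors the flow in the training loop (`sample_xs` → `task.evaluate` → `task.get_metric()`). It is only a sketch: the `"gaussian"` data-sampler name and `"linear_regression"` task name are assumptions inferred from the files in `conf/` and may need adjusting, and the zero predictor is just a self-contained stand-in for a trained model built with `models.build_model`.

```python
# Minimal pipeline sketch (run from src/); sampler/task names are assumed.
import torch

from samplers import get_data_sampler
from tasks import get_task_sampler

n_dims, batch_size, n_points = 20, 64, 41

data_sampler = get_data_sampler("gaussian", n_dims=n_dims)            # assumed name
task_sampler = get_task_sampler("linear_regression", n_dims, batch_size)  # assumed name

xs = data_sampler.sample_xs(n_points, batch_size, n_dims)  # (batch, points, dims)
task = task_sampler()
ys = task.evaluate(xs)                                     # (batch, points)

# Score a predictor with the task's pointwise metric (squared error here).
# A trained transformer would be called as model(xs, ys); a zero predictor
# keeps the snippet self-contained.
preds = torch.zeros_like(ys)
pointwise_loss = task.get_metric()(preds, ys).mean(dim=0)  # loss per in-context example
print(pointwise_loss)
```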