├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── setting.jpg
└── src
    ├── base_models.py
    ├── conf
    │   ├── base.yaml
    │   ├── decision_tree.yaml
    │   ├── linear_regression.yaml
    │   ├── models
    │   │   ├── small.yaml
    │   │   ├── standard.yaml
    │   │   └── tiny.yaml
    │   ├── relu_2nn_regression.yaml
    │   ├── sparse_linear_regression.yaml
    │   ├── toy.yaml
    │   └── wandb.yaml
    ├── curriculum.py
    ├── eval.ipynb
    ├── eval.py
    ├── models.py
    ├── plot_utils.py
    ├── samplers.py
    ├── schema.py
    ├── tasks.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | __pycache__
3 | models.zip
4 | models
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Dimitris Tsipras, Shivam Garg
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repository contains the code and models for our paper:
2 |
3 | **What Can Transformers Learn In-Context? A Case Study of Simple Function Classes**
4 | *Shivam Garg\*, Dimitris Tsipras\*, Percy Liang, Gregory Valiant*
5 | Paper: http://arxiv.org/abs/2208.01066
6 |
7 | 
8 |
9 | ```bibtex
10 | @InProceedings{garg2022what,
11 | title={What Can Transformers Learn In-Context? A Case Study of Simple Function Classes},
12 | author={Shivam Garg and Dimitris Tsipras and Percy Liang and Gregory Valiant},
13 | year={2022},
14 | booktitle={arXiv preprint}
15 | }
16 | ```
17 |
18 | ## Getting started
19 | You can start by cloning our repository and following the steps below.
20 |
21 | 1. Install the dependencies for our code using Conda. You may need to adjust the environment YAML file depending on your setup.
22 |
23 | ```
24 | conda env create -f environment.yml
25 | conda activate in-context-learning
26 | ```
27 |
28 | 2. Download [model checkpoints](https://github.com/dtsip/in-context-learning/releases/download/initial/models.zip) and extract them in the current directory.
29 |
30 | ```
31 | wget https://github.com/dtsip/in-context-learning/releases/download/initial/models.zip
32 | unzip models.zip
33 | ```
34 |
35 | 3. [Optional] If you plan to train, populate `conf/wandb.yaml` with your wandb info (the `entity` value it ships with is a placeholder).
36 |
37 | That's it! You can now explore our pre-trained models or train your own. The key entry points
38 | are as follows (starting from `src`):
39 | - The `eval.ipynb` notebook contains code to load our pre-trained models, plot the pre-computed metrics, and evaluate them on new data (a minimal loading sketch is also included below).
40 | - `train.py` takes as an argument a configuration YAML from `conf` and trains the corresponding model. You can try `python train.py --config conf/toy.yaml` for a quick training run.
41 |
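For a quick programmatic start, the sketch below (run from inside `src`) loads a pre-trained run and computes its evaluation metrics. This is a minimal example, assuming the checkpoints were extracted to `models/` in the repository root and a CUDA GPU is available; `<run_id>` is a placeholder for one of the extracted run directories.

```python
# Minimal sketch: load a run and compute (or read back cached) metrics.
from eval import get_model_from_run, get_run_metrics

run_path = "../models/linear_regression/<run_id>"  # placeholder run directory
model, conf = get_model_from_run(run_path)  # model weights + training config
metrics = get_run_metrics(run_path)         # writes/reads metrics.json under run_path
```
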
42 | # Maintainers
43 | * [Shivam Garg](https://cs.stanford.edu/~shivamg/)
44 | * [Dimitris Tsipras](https://dtsipras.com/)
45 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: in-context-learning
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - pip=21.2.4
7 | - python=3.8.12
8 | - pytorch=1.11.0
9 | - pip:
10 | - jupyter==1.0.0
11 | - matplotlib==3.5.2
12 | - numpy==1.22.3
13 | - pandas==1.4.2
14 | - quinine==0.3.0
15 | - scikit-learn==1.0.2
16 | - seaborn==0.11.2
17 | - tqdm==4.64.0
18 | - transformers==4.17.0
19 | - wandb==0.12.11
20 | - xgboost==1.6.1
21 | - protobuf==3.20.1
22 |
--------------------------------------------------------------------------------
/setting.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dtsip/in-context-learning/338337192195827c58356f9b627e5f999ea703d2/setting.jpg
--------------------------------------------------------------------------------
/src/base_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class NeuralNetwork(nn.Module):
6 | def __init__(self, in_size=50, hidden_size=1000, out_size=1):
7 | super(NeuralNetwork, self).__init__()
8 |
9 | self.net = nn.Sequential(
10 | nn.Linear(in_size, hidden_size),
11 | nn.ReLU(),
12 | nn.Linear(hidden_size, out_size),
13 | )
14 |
15 | def forward(self, x):
16 | out = self.net(x)
17 | return out
18 |
19 |
20 | class ParallelNetworks(nn.Module):
21 | def __init__(self, num_models, model_class, **model_class_init_args):
22 | super(ParallelNetworks, self).__init__()
23 | self.nets = nn.ModuleList(
24 | [model_class(**model_class_init_args) for i in range(num_models)]
25 | )
26 |
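    # Applies sub-network i to its own input slice xs[i]; xs is expected to have
    # shape (num_models, batch, in_size) and the stacked output has shape
    # (num_models, batch, out_size).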
27 | def forward(self, xs):
28 | assert xs.shape[0] == len(self.nets)
29 |
30 | for i in range(len(self.nets)):
31 | out = self.nets[i](xs[i])
32 | if i == 0:
33 | outs = torch.zeros(
34 | [len(self.nets)] + list(out.shape), device=out.device
35 | )
36 | outs[i] = out
37 | return outs
38 |
--------------------------------------------------------------------------------
/src/conf/base.yaml:
--------------------------------------------------------------------------------
1 | inherit:
2 | - models/standard.yaml
3 | - wandb.yaml
4 |
5 | model:
6 | n_dims: 20
7 | n_positions: 101
8 |
9 | training:
10 | data: gaussian
11 | task_kwargs: {}
12 | batch_size: 64
13 | learning_rate: 0.0001
14 | save_every_steps: 1000
15 | keep_every_steps: 100000
16 | train_steps: 500001
17 | curriculum:
18 | dims:
19 | start: 5
20 | end: 20
21 | inc: 1
22 | interval: 2000
23 |
--------------------------------------------------------------------------------
/src/conf/decision_tree.yaml:
--------------------------------------------------------------------------------
1 | inherit:
2 | - base.yaml
3 |
4 | training:
5 | task: decision_tree
6 | task_kwargs: {"depth": 4}
7 | curriculum:
8 | points:
9 | start: 26
10 | end: 101
11 | inc: 5
12 | interval: 2000
13 |
14 | out_dir: ../models/decision_tree
15 |
16 | wandb:
17 | name: "decision_tree_standard"
18 |
--------------------------------------------------------------------------------
/src/conf/linear_regression.yaml:
--------------------------------------------------------------------------------
1 | inherit:
2 | - base.yaml
3 |
4 | training:
5 | task: linear_regression
6 | curriculum:
7 | points:
8 | start: 11
9 | end: 41
10 | inc: 2
11 | interval: 2000
12 |
13 | out_dir: ../models/linear_regression
14 |
15 | wandb:
16 | name: "linear_regression_standard"
17 |
--------------------------------------------------------------------------------
/src/conf/models/small.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | family: gpt2
3 | n_embd: 128
4 | n_layer: 6
5 | n_head: 4
6 |
--------------------------------------------------------------------------------
/src/conf/models/standard.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | family: gpt2
3 | n_embd: 256
4 | n_layer: 12
5 | n_head: 8
6 |
--------------------------------------------------------------------------------
/src/conf/models/tiny.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | family: gpt2
3 | n_embd: 64
4 | n_layer: 3
5 | n_head: 2
6 |
--------------------------------------------------------------------------------
/src/conf/relu_2nn_regression.yaml:
--------------------------------------------------------------------------------
1 | inherit:
2 | - base.yaml
3 |
4 | training:
5 | task: relu_2nn_regression
6 | task_kwargs: {"hidden_layer_size": 100}
7 | curriculum:
8 | points:
9 | start: 26
10 | end: 101
11 | inc: 5
12 | interval: 2000
13 |
14 | out_dir: ../models/relu_2nn_regression
15 |
16 | wandb:
17 | name: "relu_2nn_regression_standard"
18 |
--------------------------------------------------------------------------------
/src/conf/sparse_linear_regression.yaml:
--------------------------------------------------------------------------------
1 | inherit:
2 | - base.yaml
3 |
4 | training:
5 | task: sparse_linear_regression
6 | task_kwargs: {"sparsity": 3}
7 | curriculum:
8 | points:
9 | start: 11
10 | end: 41
11 | inc: 2
12 | interval: 2000
13 |
14 | out_dir: ../models/sparse_linear_regression
15 |
16 | wandb:
17 | name: "sparse_regression_standard"
18 |
--------------------------------------------------------------------------------
/src/conf/toy.yaml:
--------------------------------------------------------------------------------
1 | inherit:
2 | - models/standard.yaml
3 | - wandb.yaml
4 |
5 | model:
6 | n_dims: 5
7 | n_positions: 11
8 |
9 | training:
10 | task: linear_regression
11 | data: gaussian
12 | task_kwargs: {}
13 | batch_size: 64
14 | learning_rate: 0.0001
15 | save_every_steps: 1000
16 | keep_every_steps: 100000
17 | train_steps: 5001
18 | curriculum:
19 | dims:
20 | start: 5
21 | end: 5
22 | inc: 1
23 | interval: 2000
24 | points:
25 | start: 11
26 | end: 11
27 | inc: 2
28 | interval: 2000
29 |
30 | out_dir: ../models/linear_regression
31 |
32 | wandb:
33 | name: "linear_regression_toy"
34 |
--------------------------------------------------------------------------------
/src/conf/wandb.yaml:
--------------------------------------------------------------------------------
1 | wandb:
2 | project: in-context-training
3 | entity: your-entity
4 | notes:
5 | log_every_steps: 100
6 |
--------------------------------------------------------------------------------
/src/curriculum.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | class Curriculum:
5 | def __init__(self, args):
6 | # args.dims and args.points each contain start, end, inc, interval attributes
7 | # inc denotes the change in n_dims,
8 | # this change is done every interval,
9 | # and start/end are the limits of the parameter
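        # Example (values from conf/base.yaml): with start=5, end=20, inc=1 and
        # interval=2000, the truncated dimension stays at 5 for the first 2000
        # steps, increases by 1 every 2000 steps, and is capped at 20 thereafter.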
10 | self.n_dims_truncated = args.dims.start
11 | self.n_points = args.points.start
12 | self.n_dims_schedule = args.dims
13 | self.n_points_schedule = args.points
14 | self.step_count = 0
15 |
16 | def update(self):
17 | self.step_count += 1
18 | self.n_dims_truncated = self.update_var(
19 | self.n_dims_truncated, self.n_dims_schedule
20 | )
21 | self.n_points = self.update_var(self.n_points, self.n_points_schedule)
22 |
23 | def update_var(self, var, schedule):
24 | if self.step_count % schedule.interval == 0:
25 | var += schedule.inc
26 |
27 | return min(var, schedule.end)
28 |
29 |
30 | # Returns the final value of var after total_steps steps of the curriculum
31 | # (n_steps is the increment interval, lim the cap).
31 | def get_final_var(init_var, total_steps, inc, n_steps, lim):
32 | final_var = init_var + math.floor((total_steps) / n_steps) * inc
33 |
34 | return min(final_var, lim)
35 |
--------------------------------------------------------------------------------
/src/eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 |
5 | from munch import Munch
6 | import numpy as np
7 | import pandas as pd
8 | from tqdm import tqdm
9 | import torch
10 | import yaml
11 |
12 | import models
13 | from samplers import get_data_sampler, sample_transformation
14 | from tasks import get_task_sampler
15 |
16 |
17 | def get_model_from_run(run_path, step=-1, only_conf=False):
18 | config_path = os.path.join(run_path, "config.yaml")
19 | with open(config_path) as fp: # we don't Quinfig it to avoid inherits
20 | conf = Munch.fromDict(yaml.safe_load(fp))
21 | if only_conf:
22 | return None, conf
23 |
24 | model = models.build_model(conf.model)
25 |
26 | if step == -1:
27 | state_path = os.path.join(run_path, "state.pt")
28 | state = torch.load(state_path)
29 | model.load_state_dict(state["model_state_dict"])
30 | else:
31 | model_path = os.path.join(run_path, f"model_{step}.pt")
32 | state_dict = torch.load(model_path)
33 | model.load_state_dict(state_dict)
34 |
35 | return model, conf
36 |
37 |
38 | # Functions for evaluation
39 |
40 |
41 | def eval_batch(model, task_sampler, xs, xs_p=None):
42 | task = task_sampler()
43 | if torch.cuda.is_available() and model.name.split("_")[0] in ["gpt2", "lstm"]:
44 | device = "cuda"
45 | else:
46 | device = "cpu"
47 |
48 | if xs_p is None:
49 | ys = task.evaluate(xs)
50 | pred = model(xs.to(device), ys.to(device)).detach()
51 | metrics = task.get_metric()(pred.cpu(), ys)
52 | else:
53 | b_size, n_points, _ = xs.shape
54 | metrics = torch.zeros(b_size, n_points)
55 | for i in range(n_points):
56 | xs_comb = torch.cat((xs[:, :i, :], xs_p[:, i:, :]), dim=1)
57 | ys = task.evaluate(xs_comb)
58 |
59 | pred = model(xs_comb.to(device), ys.to(device), inds=[i]).detach()
60 | metrics[:, i] = task.get_metric()(pred.cpu(), ys)[:, i]
61 |
62 | return metrics
63 |
64 |
65 | # Functions for generating different kinds of train/test data
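# Each generator returns a pair (xs_train_pre, xs_test_post). When xs_test_post
# is not None, eval_batch scores point i using the first i points of
# xs_train_pre as the prompt and the i-th point of xs_test_post as the query.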
66 |
67 |
68 | def gen_standard(data_sampler, n_points, b_size):
69 | xs = data_sampler.sample_xs(n_points, b_size)
70 |
71 | return xs, None
72 |
73 |
74 | def gen_opposite_quadrants(data_sampler, n_points, b_size):
75 | xs = data_sampler.sample_xs(n_points, b_size)
76 | pattern = torch.randn([b_size, 1, xs.shape[2]]).sign()
77 |
78 | xs_train_pre = xs.abs() * pattern
79 | xs_test_post = -xs_train_pre
80 |
81 | return xs_train_pre, xs_test_post
82 |
83 |
84 | def gen_random_quadrants(data_sampler, n_points, b_size):
85 | xs = data_sampler.sample_xs(n_points, b_size)
86 | pattern = torch.randn([b_size, 1, xs.shape[2]]).sign()
87 |
88 | xs_train_pre = xs.abs() * pattern
89 | xs_test_post = xs
90 |
91 | return xs_train_pre, xs_test_post
92 |
93 |
94 | def gen_orthogonal_train_test(data_sampler, n_points, b_size):
95 | xs = data_sampler.sample_xs(n_points, b_size)
96 | n_dim = xs.shape[2]
97 | n_points = min(n_points, n_dim)
98 | # raise ValueError("number of points should be at most the dimension.")
99 | xs_train_pre = xs
100 | xs_test_post = torch.zeros(xs.shape)
101 | for i in range(n_points):
102 | xs_test_post_i = xs[:, i : i + 1, :]
103 | xs_train_pre_i = xs[:, :i, :]
104 | _, _, Vt = torch.linalg.svd(xs_train_pre_i, full_matrices=False)
105 | xs_train_pre_i_projection = Vt.transpose(1, 2) @ Vt
106 | xs_test_post_i_orthogonalized = (
107 | xs_test_post_i - xs_test_post_i @ xs_train_pre_i_projection
108 | )
109 | xs_test_post_i_normalized = (
110 | xs_test_post_i_orthogonalized
111 | * xs_test_post_i.norm(dim=2).unsqueeze(2)
112 | / xs_test_post_i_orthogonalized.norm(dim=2).unsqueeze(2)
113 | )
114 |
115 | xs_test_post[:, i : i + 1, :] = xs_test_post_i_normalized
116 |
117 | return xs_train_pre, xs_test_post
118 |
119 |
120 | def gen_overlapping_train_test(data_sampler, n_points, b_size):
121 | xs = data_sampler.sample_xs(n_points, b_size)
122 | xs_train_pre = xs
123 | xs_test_post = xs.clone()
124 | b_size = xs.shape[0]
125 | for i in range(1, n_points):
126 | xs_train_pre_i = xs[:, :i, :]
127 | perm = torch.stack([torch.randperm(i) for _ in range(b_size)]).unsqueeze(dim=1)
128 | ind_mat = (perm == 0) + 0.0
129 | xs_test_post[:, i : i + 1, :] = ind_mat @ xs_train_pre_i
130 |
131 | return xs_train_pre, xs_test_post
132 |
133 |
134 | def aggregate_metrics(metrics, bootstrap_trials=1000):
135 | """
136 | Takes as input a tensor of shape (num_eval, n_points) and returns a dict with
137 | per-point mean, stddev, and bootstrap limits
138 | """
139 | results = {}
140 | results["mean"] = metrics.mean(dim=0)
141 | results["std"] = metrics.std(dim=0, unbiased=True)
142 | n = len(metrics)
143 | bootstrap_indices = torch.randint(n, size=(bootstrap_trials, n))
144 | bootstrap_means = metrics[bootstrap_indices].mean(dim=1).sort(dim=0)[0]
145 | results["bootstrap_low"] = bootstrap_means[int(0.05 * bootstrap_trials), :]
146 | results["bootstrap_high"] = bootstrap_means[int(0.95 * bootstrap_trials), :]
147 |
148 | return {k: v.tolist() for k, v in results.items()}
149 |
150 |
151 | def eval_model(
152 | model,
153 | task_name,
154 | data_name,
155 | n_dims,
156 | n_points,
157 | prompting_strategy,
158 | num_eval_examples=1280,
159 | batch_size=64,
160 | data_sampler_kwargs={},
161 | task_sampler_kwargs={},
162 | ):
163 | """
164 | Evaluate a model on a task with a variety of strategies.
165 | Args:
166 | - task: which base task we are evaluating on. E.g., "linear_regression"
167 | - prompting_strategy: how to construct the prompt, e.g., "random_quadrants"
168 | - num_eval_examples: total number of examples to evaluate on
169 | - **sampler_kwargs: remaining arguments to pass directly to the sampler
170 | """
171 |
172 | assert num_eval_examples % batch_size == 0
173 | data_sampler = get_data_sampler(data_name, n_dims, **data_sampler_kwargs)
174 | task_sampler = get_task_sampler(
175 | task_name, n_dims, batch_size, **task_sampler_kwargs
176 | )
177 |
178 | all_metrics = []
179 |
180 | generating_func = globals()[f"gen_{prompting_strategy}"]
181 | for i in range(num_eval_examples // batch_size):
182 | xs, xs_p = generating_func(data_sampler, n_points, batch_size)
183 |
184 | metrics = eval_batch(model, task_sampler, xs, xs_p)
185 | all_metrics.append(metrics)
186 |
187 | metrics = torch.cat(all_metrics, dim=0)
188 |
189 | return aggregate_metrics(metrics)
190 |
191 |
192 | def build_evals(conf):
193 | n_dims = conf.model.n_dims
194 | n_points = conf.training.curriculum.points.end
195 | batch_size = conf.training.batch_size
196 |
197 | task_name = conf.training.task
198 | data_name = conf.training.data
199 |
200 | base_kwargs = {
201 | "task_name": task_name,
202 | "n_dims": n_dims,
203 | "n_points": n_points,
204 | "batch_size": batch_size,
205 | "data_name": data_name,
206 | "prompting_strategy": "standard",
207 | }
208 |
209 | evaluation_kwargs = {}
210 |
211 | evaluation_kwargs["standard"] = {"prompting_strategy": "standard"}
212 | if task_name != "linear_regression":
213 | if task_name in ["relu_2nn_regression"]:
214 | evaluation_kwargs["linear_regression"] = {"task_name": "linear_regression"}
215 | for name, kwargs in evaluation_kwargs.items():
216 | # allow kwargs to override base_kwargs values
217 | evaluation_kwargs[name] = base_kwargs.copy()
218 | evaluation_kwargs[name].update(kwargs)
219 | return evaluation_kwargs
220 |
221 | for strategy in [
222 | "random_quadrants",
223 | "orthogonal_train_test",
224 | "overlapping_train_test",
225 | ]:
226 | evaluation_kwargs[strategy] = {"prompting_strategy": strategy}
227 |
228 | for method in ["half_subspace", "skewed"]:
229 | if "subspace" in method:
230 | eigenvals = torch.zeros(n_dims)
231 | eigenvals[: n_dims // 2] = 1
232 | else:
233 | eigenvals = 1 / (torch.arange(n_dims) + 1)
234 |
235 | scale = sample_transformation(eigenvals, normalize=True)
236 | evaluation_kwargs[f"{method}"] = {
237 | "data_sampler_kwargs": {"scale": scale},
238 | }
239 |
240 | for dim in ["x", "y"]:
241 | for scale in [0.333, 0.5, 2, 3]:
242 | if dim == "x":
243 | eigenvals = scale * torch.ones(n_dims)
244 | t = sample_transformation(eigenvals)
245 | scaling_args = {"data_sampler_kwargs": {"scale": t}}
246 | else:
247 | eigenvals = scale * torch.ones(n_dims)
248 | scaling_args = {"task_sampler_kwargs": {"scale": scale}}
249 |
250 | evaluation_kwargs[f"scale-{dim}={scale}"] = scaling_args
251 |
252 | evaluation_kwargs[f"noisyLR"] = {
253 | "task_sampler_kwargs": {"renormalize_ys": True, "noise_std": 1},
254 | "task_name": "noisy_linear_regression",
255 | }
256 |
257 | for name, kwargs in evaluation_kwargs.items():
258 | # allow kwargs to override base_kwargs values
259 | evaluation_kwargs[name] = base_kwargs.copy()
260 | evaluation_kwargs[name].update(kwargs)
261 |
262 | return evaluation_kwargs
263 |
264 |
265 | def compute_evals(all_models, evaluation_kwargs, save_path=None, recompute=False):
266 | try:
267 | with open(save_path) as fp:
268 | all_metrics = json.load(fp)
269 | except Exception:
270 | all_metrics = {}
271 |
272 | for eval_name, kwargs in tqdm(evaluation_kwargs.items()):
273 | metrics = {}
274 | if eval_name in all_metrics and not recompute:
275 | metrics = all_metrics[eval_name]
276 | for model in all_models:
277 | if model.name in metrics and not recompute:
278 | continue
279 |
280 | metrics[model.name] = eval_model(model, **kwargs)
281 | all_metrics[eval_name] = metrics
282 |
283 | if save_path is not None:
284 | with open(save_path, "w") as fp:
285 | json.dump(all_metrics, fp, indent=2)
286 |
287 | return all_metrics
288 |
289 |
290 | def get_run_metrics(
291 | run_path, step=-1, cache=True, skip_model_load=False, skip_baselines=False
292 | ):
293 | if skip_model_load:
294 | _, conf = get_model_from_run(run_path, only_conf=True)
295 | all_models = []
296 | else:
297 | model, conf = get_model_from_run(run_path, step)
298 | model = model.cuda().eval()
299 | all_models = [model]
300 | if not skip_baselines:
301 | all_models += models.get_relevant_baselines(conf.training.task)
302 | evaluation_kwargs = build_evals(conf)
303 |
304 | if not cache:
305 | save_path = None
306 | elif step == -1:
307 | save_path = os.path.join(run_path, "metrics.json")
308 | else:
309 | save_path = os.path.join(run_path, f"metrics_{step}.json")
310 |
311 | recompute = False
312 | if save_path is not None and os.path.exists(save_path):
313 | checkpoint_created = os.path.getmtime(run_path)
314 | cache_created = os.path.getmtime(save_path)
315 | if checkpoint_created > cache_created:
316 | recompute = True
317 |
318 | all_metrics = compute_evals(all_models, evaluation_kwargs, save_path, recompute)
319 | return all_metrics
320 |
321 |
322 |
323 | def conf_to_model_name(conf):
324 | if conf.model.family == "gpt2":
325 | return {
326 | (3, 2): "Transformer-xs",
327 | (6, 4): "Transformer-small",
328 | (12, 8): "Transformer",
329 | }[(conf.model.n_layer, conf.model.n_head)]
330 | else:
331 | return conf.wandb.name
332 |
333 |
334 | def baseline_names(name):
335 | if "OLS" in name:
336 | return "Least Squares"
337 | if name == "averaging":
338 | return "Averaging"
339 | if "NN" in name:
340 | k = name.split("_")[1].split("=")[1]
341 | return f"{k}-Nearest Neighbors"
342 | if "lasso" in name:
343 | alpha = name.split("_")[1].split("=")[1]
344 | return f"Lasso (alpha={alpha})"
345 | if "gd" in name:
346 | return "2-layer NN, GD"
347 | if "decision_tree" in name:
348 | return "Greedy Tree Learning"
349 | if "xgboost" in name:
350 | return "XGBoost"
351 | return name
352 |
353 |
354 | def read_run_dir(run_dir):
355 | all_runs = {}
356 | for task in os.listdir(run_dir):
357 | task_dir = os.path.join(run_dir, task)
358 | for run_id in os.listdir(task_dir):
359 | run_path = os.path.join(task_dir, run_id)
360 | _, conf = get_model_from_run(run_path, only_conf=True)
361 | params = {}
362 | params["run_id"] = run_id
363 | params["task"] = task
364 | params["model"] = conf_to_model_name(conf)
365 | params["kwargs"] = "_".join(
366 | f"{k}={v}" for k, v in conf.training.task_kwargs.items()
367 | )
368 | num_tasks = (
369 | conf.training.num_tasks if "num_tasks" in conf.training else None
370 | )
371 | params["num_tasks"] = num_tasks if num_tasks is not None else -1
372 | num_examples = (
373 | conf.training.num_training_examples
374 | if "num_training_examples" in conf.training
375 | else None
376 | )
377 | params["num_examples"] = num_examples if num_examples is not None else -1
378 | params["n_dims"] = conf.model.n_dims
379 | params["n_layer"] = conf.model.n_layer
380 | params["n_head"] = conf.model.n_head
381 | params["run_name"] = conf.wandb.name
382 |
383 | for k, v in params.items():
384 | if k not in all_runs:
385 | all_runs[k] = []
386 | all_runs[k].append(v)
387 |
388 | df = pd.DataFrame(all_runs).sort_values("run_name")
389 | assert len(df) == len(df.run_name.unique())
390 | return df
391 |
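# Example invocation (a sketch; runs are expected under <run_dir>/<task>/<run_id>,
# e.g. the extracted models/ directory):
#     python eval.py ../models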
392 | if __name__ == "__main__":
393 | run_dir = sys.argv[1]
394 | for task in os.listdir(run_dir):
395 | task_dir = os.path.join(run_dir, task)
396 | print(f"Evaluating task {task}")
397 | for run_id in tqdm(os.listdir(task_dir)):
398 | run_path = os.path.join(run_dir, task, run_id)
399 | metrics = get_run_metrics(run_path)
400 |
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import GPT2Model, GPT2Config
4 | from tqdm import tqdm
5 | from sklearn.svm import LinearSVC
6 | from sklearn.linear_model import LogisticRegression, Lasso
7 | import warnings
8 | from sklearn import tree
9 | import xgboost as xgb
10 |
11 | from base_models import NeuralNetwork, ParallelNetworks
12 |
13 |
14 | def build_model(conf):
15 | if conf.family == "gpt2":
16 | model = TransformerModel(
17 | n_dims=conf.n_dims,
18 | n_positions=conf.n_positions,
19 | n_embd=conf.n_embd,
20 | n_layer=conf.n_layer,
21 | n_head=conf.n_head,
22 | )
23 | else:
24 | raise NotImplementedError
25 |
26 | return model
27 |
28 |
29 | def get_relevant_baselines(task_name):
30 | task_to_baselines = {
31 | "linear_regression": [
32 | (LeastSquaresModel, {}),
33 | (NNModel, {"n_neighbors": 3}),
34 | (AveragingModel, {}),
35 | ],
36 | "linear_classification": [
37 | (NNModel, {"n_neighbors": 3}),
38 | (AveragingModel, {}),
39 | ],
40 | "sparse_linear_regression": [
41 | (LeastSquaresModel, {}),
42 | (NNModel, {"n_neighbors": 3}),
43 | (AveragingModel, {}),
44 | ]
45 | + [(LassoModel, {"alpha": alpha}) for alpha in [1, 0.1, 0.01, 0.001, 0.0001]],
46 | "relu_2nn_regression": [
47 | (LeastSquaresModel, {}),
48 | (NNModel, {"n_neighbors": 3}),
49 | (AveragingModel, {}),
50 | (
51 | GDModel,
52 | {
53 | "model_class": NeuralNetwork,
54 | "model_class_args": {
55 | "in_size": 20,
56 | "hidden_size": 100,
57 | "out_size": 1,
58 | },
59 | "opt_alg": "adam",
60 | "batch_size": 100,
61 | "lr": 5e-3,
62 | "num_steps": 100,
63 | },
64 | ),
65 | ],
66 | "decision_tree": [
67 | (LeastSquaresModel, {}),
68 | (NNModel, {"n_neighbors": 3}),
69 | (DecisionTreeModel, {"max_depth": 4}),
70 | (DecisionTreeModel, {"max_depth": None}),
71 | (XGBoostModel, {}),
72 | (AveragingModel, {}),
73 | ],
74 | }
75 |
76 | models = [model_cls(**kwargs) for model_cls, kwargs in task_to_baselines[task_name]]
77 | return models
78 |
79 |
80 | class TransformerModel(nn.Module):
81 | def __init__(self, n_dims, n_positions, n_embd=128, n_layer=12, n_head=4):
82 | super(TransformerModel, self).__init__()
83 | configuration = GPT2Config(
84 | n_positions=2 * n_positions,
85 | n_embd=n_embd,
86 | n_layer=n_layer,
87 | n_head=n_head,
88 | resid_pdrop=0.0,
89 | embd_pdrop=0.0,
90 | attn_pdrop=0.0,
91 | use_cache=False,
92 | )
93 | self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"
94 |
95 | self.n_positions = n_positions
96 | self.n_dims = n_dims
97 | self._read_in = nn.Linear(n_dims, n_embd)
98 | self._backbone = GPT2Model(configuration)
99 | self._read_out = nn.Linear(n_embd, 1)
100 |
101 | @staticmethod
102 | def _combine(xs_b, ys_b):
103 | """Interleaves the x's and the y's into a single sequence."""
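        # xs_b: (bsize, points, dim); ys_b: (bsize, points). Each y_i is zero-padded
        # to dimension `dim` and placed right after x_i, so zs has shape
        # (bsize, 2 * points, dim).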
104 | bsize, points, dim = xs_b.shape
105 | ys_b_wide = torch.cat(
106 | (
107 | ys_b.view(bsize, points, 1),
108 | torch.zeros(bsize, points, dim - 1, device=ys_b.device),
109 | ),
110 | axis=2,
111 | )
112 | zs = torch.stack((xs_b, ys_b_wide), dim=2)
113 | zs = zs.view(bsize, 2 * points, dim)
114 | return zs
115 |
116 | def forward(self, xs, ys, inds=None):
117 | if inds is None:
118 | inds = torch.arange(ys.shape[1])
119 | else:
120 | inds = torch.tensor(inds)
121 | if max(inds) >= ys.shape[1] or min(inds) < 0:
122 | raise ValueError("inds contain indices where xs and ys are not defined")
123 | zs = self._combine(xs, ys)
124 | embeds = self._read_in(zs)
125 | output = self._backbone(inputs_embeds=embeds).last_hidden_state
126 | prediction = self._read_out(output)
127 | return prediction[:, ::2, 0][:, inds] # predict only on xs
128 |
129 |
130 | class NNModel:
131 | def __init__(self, n_neighbors, weights="uniform"):
132 | # should we be picking k optimally
133 | self.n_neighbors = n_neighbors
134 | self.weights = weights
135 | self.name = f"NN_n={n_neighbors}_{weights}"
136 |
137 | def __call__(self, xs, ys, inds=None):
138 | if inds is None:
139 | inds = range(ys.shape[1])
140 | else:
141 | if max(inds) >= ys.shape[1] or min(inds) < 0:
142 | raise ValueError("inds contain indices where xs and ys are not defined")
143 |
144 | preds = []
145 |
146 | for i in inds:
147 | if i == 0:
148 | preds.append(torch.zeros_like(ys[:, 0])) # predict zero for first point
149 | continue
150 | train_xs, train_ys = xs[:, :i], ys[:, :i]
151 | test_x = xs[:, i : i + 1]
152 | dist = (train_xs - test_x).square().sum(dim=2).sqrt()
153 |
154 | if self.weights == "uniform":
155 | weights = torch.ones_like(dist)
156 | else:
157 | weights = 1.0 / dist
158 | inf_mask = torch.isinf(weights).float() # deal with exact match
159 | inf_row = torch.any(inf_mask, axis=1)
160 | weights[inf_row] = inf_mask[inf_row]
161 |
162 | pred = []
163 | k = min(i, self.n_neighbors)
164 | ranks = dist.argsort()[:, :k]
165 | for y, w, n in zip(train_ys, weights, ranks):
166 | y, w = y[n], w[n]
167 | pred.append((w * y).sum() / w.sum())
168 | preds.append(torch.stack(pred))
169 |
170 | return torch.stack(preds, dim=1)
171 |
172 |
173 | # xs and ys should be on cpu for this method. Otherwise the output may be off when train_xs is not full rank, due to the implementation of torch.linalg.lstsq.
174 | class LeastSquaresModel:
175 | def __init__(self, driver=None):
176 | self.driver = driver
177 | self.name = f"OLS_driver={driver}"
178 |
179 | def __call__(self, xs, ys, inds=None):
180 | xs, ys = xs.cpu(), ys.cpu()
181 | if inds is None:
182 | inds = range(ys.shape[1])
183 | else:
184 | if max(inds) >= ys.shape[1] or min(inds) < 0:
185 | raise ValueError("inds contain indices where xs and ys are not defined")
186 |
187 | preds = []
188 |
189 | for i in inds:
190 | if i == 0:
191 | preds.append(torch.zeros_like(ys[:, 0])) # predict zero for first point
192 | continue
193 | train_xs, train_ys = xs[:, :i], ys[:, :i]
194 | test_x = xs[:, i : i + 1]
195 |
196 | ws, _, _, _ = torch.linalg.lstsq(
197 | train_xs, train_ys.unsqueeze(2), driver=self.driver
198 | )
199 |
200 | pred = test_x @ ws
201 | preds.append(pred[:, 0, 0])
202 |
203 | return torch.stack(preds, dim=1)
204 |
205 |
206 | class AveragingModel:
207 | def __init__(self):
208 | self.name = "averaging"
209 |
210 | def __call__(self, xs, ys, inds=None):
211 | if inds is None:
212 | inds = range(ys.shape[1])
213 | else:
214 | if max(inds) >= ys.shape[1] or min(inds) < 0:
215 | raise ValueError("inds contain indices where xs and ys are not defined")
216 |
217 | preds = []
218 |
219 | for i in inds:
220 | if i == 0:
221 | preds.append(torch.zeros_like(ys[:, 0])) # predict zero for first point
222 | continue
223 | train_xs, train_ys = xs[:, :i], ys[:, :i]
224 | test_x = xs[:, i : i + 1]
225 |
226 | train_zs = train_xs * train_ys.unsqueeze(dim=-1)
227 | w_p = train_zs.mean(dim=1).unsqueeze(dim=-1)
228 | pred = test_x @ w_p
229 | preds.append(pred[:, 0, 0])
230 |
231 | return torch.stack(preds, dim=1)
232 |
233 |
234 | # Lasso regression (for sparse linear regression).
235 | # Seems to take more time as we decrease alpha.
236 | class LassoModel:
237 | def __init__(self, alpha, max_iter=100000):
238 | # the l1 regularizer gets multiplied by alpha.
239 | self.alpha = alpha
240 | self.max_iter = max_iter
241 | self.name = f"lasso_alpha={alpha}_max_iter={max_iter}"
242 |
243 | # inds is a list containing indices where we want the prediction.
244 | # prediction made at all indices by default.
245 | def __call__(self, xs, ys, inds=None):
246 | xs, ys = xs.cpu(), ys.cpu()
247 |
248 | if inds is None:
249 | inds = range(ys.shape[1])
250 | else:
251 | if max(inds) >= ys.shape[1] or min(inds) < 0:
252 | raise ValueError("inds contain indices where xs and ys are not defined")
253 |
254 |         preds = []  # predict zero for first point (i == 0)
255 |
256 | # i: loop over num_points
257 | # j: loop over bsize
258 | for i in inds:
259 | pred = torch.zeros_like(ys[:, 0])
260 |
261 | if i > 0:
262 | pred = torch.zeros_like(ys[:, 0])
263 | for j in range(ys.shape[0]):
264 | train_xs, train_ys = xs[j, :i], ys[j, :i]
265 |
266 | # If all points till now have the same label, predict that label.
267 |
268 | clf = Lasso(
269 | alpha=self.alpha, fit_intercept=False, max_iter=self.max_iter
270 | )
271 |
272 | # Check for convergence.
273 | with warnings.catch_warnings():
274 | warnings.filterwarnings("error")
275 | try:
276 | clf.fit(train_xs, train_ys)
277 | except Warning:
278 | print(f"lasso convergence warning at i={i}, j={j}.")
279 | raise
280 |
281 | w_pred = torch.from_numpy(clf.coef_).unsqueeze(1)
282 |
283 | test_x = xs[j, i : i + 1]
284 | y_pred = (test_x @ w_pred.float()).squeeze(1)
285 | pred[j] = y_pred[0]
286 |
287 | preds.append(pred)
288 |
289 | return torch.stack(preds, dim=1)
290 |
291 |
292 | # Gradient Descent and variants.
293 | # Example usage: gd_model = GDModel(NeuralNetwork, {'in_size': 50, 'hidden_size':400, 'out_size' :1}, opt_alg = 'adam', batch_size = 100, lr = 5e-3, num_steps = 200)
294 | class GDModel:
295 | def __init__(
296 | self,
297 | model_class,
298 | model_class_args,
299 | opt_alg="sgd",
300 | batch_size=1,
301 | num_steps=1000,
302 | lr=1e-3,
303 | loss_name="squared",
304 | ):
305 | # model_class: torch.nn model class
306 | # model_class_args: a dict containing arguments for model_class
307 | # opt_alg can be 'sgd' or 'adam'
308 | # verbose: whether to print the progress or not
309 | # batch_size: batch size for sgd
310 | self.model_class = model_class
311 | self.model_class_args = model_class_args
312 | self.opt_alg = opt_alg
313 | self.lr = lr
314 | self.batch_size = batch_size
315 | self.num_steps = num_steps
316 | self.loss_name = loss_name
317 |
318 | self.name = f"gd_model_class={model_class}_model_class_args={model_class_args}_opt_alg={opt_alg}_lr={lr}_batch_size={batch_size}_num_steps={num_steps}_loss_name={loss_name}"
319 |
320 | def __call__(self, xs, ys, inds=None, verbose=False, print_step=100):
321 | # inds is a list containing indices where we want the prediction.
322 | # prediction made at all indices by default.
323 | # xs: bsize X npoints X ndim.
324 | # ys: bsize X npoints.
325 | xs, ys = xs.cuda(), ys.cuda()
326 |
327 | if inds is None:
328 | inds = range(ys.shape[1])
329 | else:
330 | if max(inds) >= ys.shape[1] or min(inds) < 0:
331 | raise ValueError("inds contain indices where xs and ys are not defined")
332 |
333 |         preds = []  # predict zero for first point (i == 0)
334 |
335 | # i: loop over num_points
336 | for i in tqdm(inds):
337 | pred = torch.zeros_like(ys[:, 0])
338 | model = ParallelNetworks(
339 | ys.shape[0], self.model_class, **self.model_class_args
340 | )
341 | model.cuda()
342 | if i > 0:
343 | pred = torch.zeros_like(ys[:, 0])
344 |
345 | train_xs, train_ys = xs[:, :i], ys[:, :i]
346 | test_xs, test_ys = xs[:, i : i + 1], ys[:, i : i + 1]
347 |
348 | if self.opt_alg == "sgd":
349 | optimizer = torch.optim.SGD(model.parameters(), lr=self.lr)
350 | elif self.opt_alg == "adam":
351 | optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
352 | else:
353 | raise NotImplementedError(f"{self.opt_alg} not implemented.")
354 |
355 | if self.loss_name == "squared":
356 | loss_criterion = nn.MSELoss()
357 | else:
358 | raise NotImplementedError(f"{self.loss_name} not implemented.")
359 |
360 | # Training loop
361 | for j in range(self.num_steps):
362 |
363 | # Prepare batch
364 | mask = torch.zeros(i).bool()
365 | perm = torch.randperm(i)
366 | mask[perm[: self.batch_size]] = True
367 | train_xs_cur, train_ys_cur = train_xs[:, mask, :], train_ys[:, mask]
368 |
369 | if verbose and j % print_step == 0:
370 | model.eval()
371 | with torch.no_grad():
372 | outputs = model(train_xs_cur)
373 | loss = loss_criterion(
374 | outputs[:, :, 0], train_ys_cur
375 | ).detach()
376 | outputs_test = model(test_xs)
377 | test_loss = loss_criterion(
378 | outputs_test[:, :, 0], test_ys
379 | ).detach()
380 | print(
381 | f"ind:{i},step:{j}, train_loss:{loss.item()}, test_loss:{test_loss.item()}"
382 | )
383 |
384 | optimizer.zero_grad()
385 |
386 | model.train()
387 | outputs = model(train_xs_cur)
388 | loss = loss_criterion(outputs[:, :, 0], train_ys_cur)
389 | loss.backward()
390 | optimizer.step()
391 |
392 | model.eval()
393 | pred = model(test_xs).detach()
394 |
395 | assert pred.shape[1] == 1 and pred.shape[2] == 1
396 | pred = pred[:, 0, 0]
397 |
398 | preds.append(pred)
399 |
400 | return torch.stack(preds, dim=1)
401 |
402 |
403 | class DecisionTreeModel:
404 | def __init__(self, max_depth=None):
405 | self.max_depth = max_depth
406 | self.name = f"decision_tree_max_depth={max_depth}"
407 |
408 | # inds is a list containing indices where we want the prediction.
409 | # prediction made at all indices by default.
410 | def __call__(self, xs, ys, inds=None):
411 | xs, ys = xs.cpu(), ys.cpu()
412 |
413 | if inds is None:
414 | inds = range(ys.shape[1])
415 | else:
416 | if max(inds) >= ys.shape[1] or min(inds) < 0:
417 | raise ValueError("inds contain indices where xs and ys are not defined")
418 |
419 | preds = []
420 |
421 | # i: loop over num_points
422 | # j: loop over bsize
423 | for i in inds:
424 | pred = torch.zeros_like(ys[:, 0])
425 |
426 | if i > 0:
427 | pred = torch.zeros_like(ys[:, 0])
428 | for j in range(ys.shape[0]):
429 | train_xs, train_ys = xs[j, :i], ys[j, :i]
430 |
431 | clf = tree.DecisionTreeRegressor(max_depth=self.max_depth)
432 | clf = clf.fit(train_xs, train_ys)
433 | test_x = xs[j, i : i + 1]
434 | y_pred = clf.predict(test_x)
435 | pred[j] = y_pred[0]
436 |
437 | preds.append(pred)
438 |
439 | return torch.stack(preds, dim=1)
440 |
441 |
442 | class XGBoostModel:
443 | def __init__(self):
444 | self.name = "xgboost"
445 |
446 | # inds is a list containing indices where we want the prediction.
447 | # prediction made at all indices by default.
448 | def __call__(self, xs, ys, inds=None):
449 | xs, ys = xs.cpu(), ys.cpu()
450 |
451 | if inds is None:
452 | inds = range(ys.shape[1])
453 | else:
454 | if max(inds) >= ys.shape[1] or min(inds) < 0:
455 | raise ValueError("inds contain indices where xs and ys are not defined")
456 |
457 | preds = []
458 |
459 | # i: loop over num_points
460 | # j: loop over bsize
461 | for i in tqdm(inds):
462 | pred = torch.zeros_like(ys[:, 0])
463 | if i > 0:
464 | pred = torch.zeros_like(ys[:, 0])
465 | for j in range(ys.shape[0]):
466 | train_xs, train_ys = xs[j, :i], ys[j, :i]
467 |
468 | clf = xgb.XGBRegressor()
469 |
470 | clf = clf.fit(train_xs, train_ys)
471 | test_x = xs[j, i : i + 1]
472 | y_pred = clf.predict(test_x)
473 | pred[j] = y_pred[0].item()
474 |
475 | preds.append(pred)
476 |
477 | return torch.stack(preds, dim=1)
478 |
--------------------------------------------------------------------------------
/src/plot_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib.pyplot as plt
4 | import seaborn as sns
5 |
6 | from eval import get_run_metrics, baseline_names, get_model_from_run
7 | from models import build_model
8 |
9 | sns.set_theme("notebook", "darkgrid")
10 | palette = sns.color_palette("colorblind")
11 |
12 |
13 | relevant_model_names = {
14 | "linear_regression": [
15 | "Transformer",
16 | "Least Squares",
17 | "3-Nearest Neighbors",
18 | "Averaging",
19 | ],
20 | "sparse_linear_regression": [
21 | "Transformer",
22 | "Least Squares",
23 | "3-Nearest Neighbors",
24 | "Averaging",
25 | "Lasso (alpha=0.01)",
26 | ],
27 | "decision_tree": [
28 | "Transformer",
29 | "3-Nearest Neighbors",
30 | "2-layer NN, GD",
31 | "Greedy Tree Learning",
32 | "XGBoost",
33 | ],
34 | "relu_2nn_regression": [
35 | "Transformer",
36 | "Least Squares",
37 | "3-Nearest Neighbors",
38 | "2-layer NN, GD",
39 | ],
40 | }
41 |
42 |
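# One plausible way to combine these helpers (a sketch; `run_dir` and `task` are
# placeholders, and read_run_dir comes from eval.py):
#     df = read_run_dir(run_dir)
#     metrics = collect_results(run_dir, df)
#     fig, ax = basic_plot(metrics["standard"], models=relevant_model_names[task])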
43 | def basic_plot(metrics, models=None, trivial=1.0):
44 | fig, ax = plt.subplots(1, 1)
45 |
46 | if models is not None:
47 | metrics = {k: metrics[k] for k in models}
48 |
49 | color = 0
50 | ax.axhline(trivial, ls="--", color="gray")
51 | for name, vs in metrics.items():
52 | ax.plot(vs["mean"], "-", label=name, color=palette[color % 10], lw=2)
53 | low = vs["bootstrap_low"]
54 | high = vs["bootstrap_high"]
55 | ax.fill_between(range(len(low)), low, high, alpha=0.3)
56 | color += 1
57 | ax.set_xlabel("in-context examples")
58 | ax.set_ylabel("squared error")
59 | ax.set_xlim(-1, len(low) + 0.1)
60 | ax.set_ylim(-0.1, 1.25)
61 |
62 | legend = ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
63 | fig.set_size_inches(4, 3)
64 | for line in legend.get_lines():
65 | line.set_linewidth(3)
66 |
67 | return fig, ax
68 |
69 |
70 | def collect_results(run_dir, df, valid_row=None, rename_eval=None, rename_model=None):
71 | all_metrics = {}
72 | for _, r in df.iterrows():
73 | if valid_row is not None and not valid_row(r):
74 | continue
75 |
76 | run_path = os.path.join(run_dir, r.task, r.run_id)
77 | _, conf = get_model_from_run(run_path, only_conf=True)
78 |
79 | print(r.run_name, r.run_id)
80 | metrics = get_run_metrics(run_path, skip_model_load=True)
81 |
82 | for eval_name, results in sorted(metrics.items()):
83 | processed_results = {}
84 | for model_name, m in results.items():
85 |                 if "gpt2" in model_name:
86 | model_name = r.model
87 | if rename_model is not None:
88 | model_name = rename_model(model_name, r)
89 | else:
90 | model_name = baseline_names(model_name)
91 | m_processed = {}
92 | n_dims = conf.model.n_dims
93 |
94 | xlim = 2 * n_dims + 1
95 | if r.task in ["relu_2nn_regression", "decision_tree"]:
96 | xlim = 200
97 |
98 | normalization = n_dims
99 | if r.task == "sparse_linear_regression":
100 | normalization = int(r.kwargs.split("=")[-1])
101 | if r.task == "decision_tree":
102 | normalization = 1
103 |
104 | for k, v in m.items():
105 | v = v[:xlim]
106 | v = [vv / normalization for vv in v]
107 | m_processed[k] = v
108 | processed_results[model_name] = m_processed
109 | if rename_eval is not None:
110 | eval_name = rename_eval(eval_name, r)
111 | if eval_name not in all_metrics:
112 | all_metrics[eval_name] = {}
113 | all_metrics[eval_name].update(processed_results)
114 | return all_metrics
115 |
--------------------------------------------------------------------------------
/src/samplers.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 |
6 | class DataSampler:
7 | def __init__(self, n_dims):
8 | self.n_dims = n_dims
9 |
10 | def sample_xs(self):
11 | raise NotImplementedError
12 |
13 |
14 | def get_data_sampler(data_name, n_dims, **kwargs):
15 | names_to_classes = {
16 | "gaussian": GaussianSampler,
17 | }
18 | if data_name in names_to_classes:
19 | sampler_cls = names_to_classes[data_name]
20 | return sampler_cls(n_dims, **kwargs)
21 | else:
22 | print("Unknown sampler")
23 | raise NotImplementedError
24 |
25 |
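# Returns the symmetric matrix U @ diag(eigenvalues) @ U.T for a random
# orthogonal U. With normalize=True the matrix is rescaled so that, for
# x ~ N(0, I), the transformed data has expected squared norm n_dims
# (matching untransformed Gaussian data).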
26 | def sample_transformation(eigenvalues, normalize=False):
27 | n_dims = len(eigenvalues)
28 | U, _, _ = torch.linalg.svd(torch.randn(n_dims, n_dims))
29 | t = U @ torch.diag(eigenvalues) @ torch.transpose(U, 0, 1)
30 | if normalize:
31 | norm_subspace = torch.sum(eigenvalues**2)
32 | t *= math.sqrt(n_dims / norm_subspace)
33 | return t
34 |
35 |
36 | class GaussianSampler(DataSampler):
37 | def __init__(self, n_dims, bias=None, scale=None):
38 | super().__init__(n_dims)
39 | self.bias = bias
40 | self.scale = scale
41 |
42 | def sample_xs(self, n_points, b_size, n_dims_truncated=None, seeds=None):
43 | if seeds is None:
44 | xs_b = torch.randn(b_size, n_points, self.n_dims)
45 | else:
46 | xs_b = torch.zeros(b_size, n_points, self.n_dims)
47 | generator = torch.Generator()
48 | assert len(seeds) == b_size
49 | for i, seed in enumerate(seeds):
50 | generator.manual_seed(seed)
51 | xs_b[i] = torch.randn(n_points, self.n_dims, generator=generator)
52 | if self.scale is not None:
53 | xs_b = xs_b @ self.scale
54 | if self.bias is not None:
55 | xs_b += self.bias
56 | if n_dims_truncated is not None:
57 | xs_b[:, :, n_dims_truncated:] = 0
58 | return xs_b
59 |
--------------------------------------------------------------------------------
/src/schema.py:
--------------------------------------------------------------------------------
1 | from quinine import (
2 | tstring,
3 | tinteger,
4 | tfloat,
5 | tboolean,
6 | stdict,
7 | tdict,
8 | default,
9 | required,
10 | allowed,
11 | nullable,
12 | )
13 | from funcy import merge
14 |
15 |
16 | model_schema = {
17 | "family": merge(tstring, allowed(["gpt2", "lstm"])),
18 | "n_positions": merge(tinteger, required), # maximum context length
19 | "n_dims": merge(tinteger, required), # latent dimension
20 | "n_embd": merge(tinteger, required),
21 | "n_layer": merge(tinteger, required),
22 | "n_head": merge(tinteger, required),
23 | }
24 |
25 | curriculum_base_schema = {
26 | "start": merge(tinteger, required), # initial parameter
27 | "end": merge(tinteger, required), # limit of final value
28 | "inc": merge(tinteger, required), # how much to increment each time
29 | "interval": merge(tinteger, required), # increment every how many steps
30 | }
31 |
32 | curriculum_schema = {
33 | "dims": stdict(curriculum_base_schema),
34 | "points": stdict(curriculum_base_schema),
35 | }
36 |
37 | TASK_LIST = [
38 | "linear_regression",
39 | "sparse_linear_regression",
40 | "linear_classification",
41 | "relu_2nn_regression",
42 | "decision_tree",
43 | ]
44 |
45 | training_schema = {
46 | "task": merge(tstring, allowed(TASK_LIST)),
47 | "task_kwargs": merge(tdict, required),
48 | "num_tasks": merge(tinteger, nullable, default(None)),
49 | "num_training_examples": merge(tinteger, nullable, default(None)),
50 | "data": merge(tstring, allowed(["gaussian"])),
51 | "batch_size": merge(tinteger, default(64)),
52 | "learning_rate": merge(tfloat, default(3e-4)),
53 | "train_steps": merge(tinteger, default(1000)),
54 | "save_every_steps": merge(tinteger, default(1000)), # how often to checkpoint
55 | "keep_every_steps": merge(tinteger, default(-1)), # permanent checkpoints
56 | "resume_id": merge(tstring, nullable, default(None)), # run uuid64
57 | "curriculum": stdict(curriculum_schema),
58 | }
59 |
60 | wandb_schema = {
61 | "project": merge(tstring, default("in-context-training")),
62 | "entity": merge(tstring, default("in-context")),
63 | "notes": merge(tstring, default("")),
64 | "name": merge(tstring, nullable, default(None)),
65 | "log_every_steps": merge(tinteger, default(10)),
66 | }
67 |
68 | schema = {
69 | "out_dir": merge(tstring, required),
70 | "model": stdict(model_schema),
71 | "training": stdict(training_schema),
72 | "wandb": stdict(wandb_schema),
73 | "test_run": merge(tboolean, default(False)),
74 | }
75 |
--------------------------------------------------------------------------------
/src/tasks.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 |
6 | def squared_error(ys_pred, ys):
7 | return (ys - ys_pred).square()
8 |
9 |
10 | def mean_squared_error(ys_pred, ys):
11 | return (ys - ys_pred).square().mean()
12 |
13 |
14 | def accuracy(ys_pred, ys):
15 | return (ys == ys_pred.sign()).float()
16 |
17 |
18 | sigmoid = torch.nn.Sigmoid()
19 | bce_loss = torch.nn.BCELoss()
20 |
21 |
22 | def cross_entropy(ys_pred, ys):
23 | output = sigmoid(ys_pred)
24 | target = (ys + 1) / 2
25 | return bce_loss(output, target)
26 |
27 |
28 | class Task:
29 | def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None):
30 | self.n_dims = n_dims
31 | self.b_size = batch_size
32 | self.pool_dict = pool_dict
33 | self.seeds = seeds
34 | assert pool_dict is None or seeds is None
35 |
36 | def evaluate(self, xs):
37 | raise NotImplementedError
38 |
39 | @staticmethod
40 | def generate_pool_dict(n_dims, num_tasks):
41 | raise NotImplementedError
42 |
43 | @staticmethod
44 | def get_metric():
45 | raise NotImplementedError
46 |
47 | @staticmethod
48 | def get_training_metric():
49 | raise NotImplementedError
50 |
51 |
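# Typical usage (as in eval.py): the returned callable constructs a fresh batch
# of random task instances on each call (drawn from pool_dict when num_tasks is
# given), e.g.
#     task_sampler = get_task_sampler("linear_regression", n_dims, batch_size)
#     task = task_sampler()
#     ys = task.evaluate(xs)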
52 | def get_task_sampler(
53 | task_name, n_dims, batch_size, pool_dict=None, num_tasks=None, **kwargs
54 | ):
55 | task_names_to_classes = {
56 | "linear_regression": LinearRegression,
57 | "sparse_linear_regression": SparseLinearRegression,
58 | "linear_classification": LinearClassification,
59 | "noisy_linear_regression": NoisyLinearRegression,
60 | "quadratic_regression": QuadraticRegression,
61 | "relu_2nn_regression": Relu2nnRegression,
62 | "decision_tree": DecisionTree,
63 | }
64 | if task_name in task_names_to_classes:
65 | task_cls = task_names_to_classes[task_name]
66 | if num_tasks is not None:
67 | if pool_dict is not None:
68 | raise ValueError("Either pool_dict or num_tasks should be None.")
69 | pool_dict = task_cls.generate_pool_dict(n_dims, num_tasks, **kwargs)
70 | return lambda **args: task_cls(n_dims, batch_size, pool_dict, **args, **kwargs)
71 | else:
72 | print("Unknown task")
73 | raise NotImplementedError
74 |
75 |
76 | class LinearRegression(Task):
77 | def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None, scale=1):
78 | """scale: a constant by which to scale the randomly sampled weights."""
79 | super(LinearRegression, self).__init__(n_dims, batch_size, pool_dict, seeds)
80 | self.scale = scale
81 |
82 | if pool_dict is None and seeds is None:
83 | self.w_b = torch.randn(self.b_size, self.n_dims, 1)
84 | elif seeds is not None:
85 | self.w_b = torch.zeros(self.b_size, self.n_dims, 1)
86 | generator = torch.Generator()
87 | assert len(seeds) == self.b_size
88 | for i, seed in enumerate(seeds):
89 | generator.manual_seed(seed)
90 | self.w_b[i] = torch.randn(self.n_dims, 1, generator=generator)
91 | else:
92 | assert "w" in pool_dict
93 | indices = torch.randperm(len(pool_dict["w"]))[:batch_size]
94 | self.w_b = pool_dict["w"][indices]
95 |
96 | def evaluate(self, xs_b):
97 | w_b = self.w_b.to(xs_b.device)
98 | ys_b = self.scale * (xs_b @ w_b)[:, :, 0]
99 | return ys_b
100 |
101 | @staticmethod
102 | def generate_pool_dict(n_dims, num_tasks, **kwargs): # ignore extra args
103 | return {"w": torch.randn(num_tasks, n_dims, 1)}
104 |
105 | @staticmethod
106 | def get_metric():
107 | return squared_error
108 |
109 | @staticmethod
110 | def get_training_metric():
111 | return mean_squared_error
112 |
113 |
114 | class SparseLinearRegression(LinearRegression):
115 | def __init__(
116 | self,
117 | n_dims,
118 | batch_size,
119 | pool_dict=None,
120 | seeds=None,
121 | scale=1,
122 | sparsity=3,
123 | valid_coords=None,
124 | ):
125 | """scale: a constant by which to scale the randomly sampled weights."""
126 | super(SparseLinearRegression, self).__init__(
127 | n_dims, batch_size, pool_dict, seeds, scale
128 | )
129 | self.sparsity = sparsity
130 | if valid_coords is None:
131 | valid_coords = n_dims
132 | assert valid_coords <= n_dims
133 |
134 | for i, w in enumerate(self.w_b):
135 | mask = torch.ones(n_dims).bool()
136 | if seeds is None:
137 | perm = torch.randperm(valid_coords)
138 | else:
139 | generator = torch.Generator()
140 | generator.manual_seed(seeds[i])
141 | perm = torch.randperm(valid_coords, generator=generator)
142 | mask[perm[:sparsity]] = False
143 | w[mask] = 0
144 |
145 | def evaluate(self, xs_b):
146 | w_b = self.w_b.to(xs_b.device)
147 | ys_b = self.scale * (xs_b @ w_b)[:, :, 0]
148 | return ys_b
149 |
150 | @staticmethod
151 | def get_metric():
152 | return squared_error
153 |
154 | @staticmethod
155 | def get_training_metric():
156 | return mean_squared_error
157 |
158 |
159 | class LinearClassification(LinearRegression):
160 | def evaluate(self, xs_b):
161 | ys_b = super().evaluate(xs_b)
162 | return ys_b.sign()
163 |
164 | @staticmethod
165 | def get_metric():
166 | return accuracy
167 |
168 | @staticmethod
169 | def get_training_metric():
170 | return cross_entropy
171 |
172 |
173 | class NoisyLinearRegression(LinearRegression):
174 | def __init__(
175 | self,
176 | n_dims,
177 | batch_size,
178 | pool_dict=None,
179 | seeds=None,
180 | scale=1,
181 | noise_std=0,
182 | renormalize_ys=False,
183 | ):
184 | """noise_std: standard deviation of noise added to the prediction."""
185 | super(NoisyLinearRegression, self).__init__(
186 | n_dims, batch_size, pool_dict, seeds, scale
187 | )
188 | self.noise_std = noise_std
189 | self.renormalize_ys = renormalize_ys
190 |
191 | def evaluate(self, xs_b):
192 | ys_b = super().evaluate(xs_b)
193 | ys_b_noisy = ys_b + torch.randn_like(ys_b) * self.noise_std
194 | if self.renormalize_ys:
195 | ys_b_noisy = ys_b_noisy * math.sqrt(self.n_dims) / ys_b_noisy.std()
196 |
197 | return ys_b_noisy
198 |
199 |
200 | class QuadraticRegression(LinearRegression):
201 | def evaluate(self, xs_b):
202 | w_b = self.w_b.to(xs_b.device)
203 | ys_b_quad = ((xs_b**2) @ w_b)[:, :, 0]
204 | # ys_b_quad = ys_b_quad * math.sqrt(self.n_dims) / ys_b_quad.std()
205 | # Renormalize to Linear Regression Scale
206 | ys_b_quad = ys_b_quad / math.sqrt(3)
207 | ys_b_quad = self.scale * ys_b_quad
208 | return ys_b_quad
209 |
210 |
211 | class Relu2nnRegression(Task):
212 | def __init__(
213 | self,
214 | n_dims,
215 | batch_size,
216 | pool_dict=None,
217 | seeds=None,
218 | scale=1,
219 | hidden_layer_size=100,
220 | ):
221 | """scale: a constant by which to scale the randomly sampled weights."""
222 | super(Relu2nnRegression, self).__init__(n_dims, batch_size, pool_dict, seeds)
223 | self.scale = scale
224 | self.hidden_layer_size = hidden_layer_size
225 |
226 | if pool_dict is None and seeds is None:
227 | self.W1 = torch.randn(self.b_size, self.n_dims, hidden_layer_size)
228 | self.W2 = torch.randn(self.b_size, hidden_layer_size, 1)
229 | elif seeds is not None:
230 | self.W1 = torch.zeros(self.b_size, self.n_dims, hidden_layer_size)
231 | self.W2 = torch.zeros(self.b_size, hidden_layer_size, 1)
232 | generator = torch.Generator()
233 | assert len(seeds) == self.b_size
234 | for i, seed in enumerate(seeds):
235 | generator.manual_seed(seed)
236 | self.W1[i] = torch.randn(
237 | self.n_dims, hidden_layer_size, generator=generator
238 | )
239 | self.W2[i] = torch.randn(hidden_layer_size, 1, generator=generator)
240 | else:
241 | assert "W1" in pool_dict and "W2" in pool_dict
242 | assert len(pool_dict["W1"]) == len(pool_dict["W2"])
243 | indices = torch.randperm(len(pool_dict["W1"]))[:batch_size]
244 | self.W1 = pool_dict["W1"][indices]
245 | self.W2 = pool_dict["W2"][indices]
246 |
247 | def evaluate(self, xs_b):
248 | W1 = self.W1.to(xs_b.device)
249 | W2 = self.W2.to(xs_b.device)
250 | # Renormalize to Linear Regression Scale
251 | ys_b_nn = (torch.nn.functional.relu(xs_b @ W1) @ W2)[:, :, 0]
252 | ys_b_nn = ys_b_nn * math.sqrt(2 / self.hidden_layer_size)
253 | ys_b_nn = self.scale * ys_b_nn
254 | # ys_b_nn = ys_b_nn * math.sqrt(self.n_dims) / ys_b_nn.std()
255 | return ys_b_nn
256 |
257 | @staticmethod
258 | def generate_pool_dict(n_dims, num_tasks, hidden_layer_size=4, **kwargs):
259 | return {
260 | "W1": torch.randn(num_tasks, n_dims, hidden_layer_size),
261 | "W2": torch.randn(num_tasks, hidden_layer_size, 1),
262 | }
263 |
264 | @staticmethod
265 | def get_metric():
266 | return squared_error
267 |
268 | @staticmethod
269 | def get_training_metric():
270 | return mean_squared_error
271 |
272 |
273 | class DecisionTree(Task):
274 | def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None, depth=4):
275 |
276 | super(DecisionTree, self).__init__(n_dims, batch_size, pool_dict, seeds)
277 | self.depth = depth
278 |
279 | if pool_dict is None:
280 |
281 | # We represent the tree using an array (tensor). Root node is at index 0, its 2 children at index 1 and 2...
282 | # dt_tensor stores the coordinate used at each node of the decision tree.
283 | # Only indices corresponding to non-leaf nodes are relevant
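            # Node j's children sit at indices 2*j + 1 (left) and 2*j + 2 (right);
            # evaluate() walks `depth` levels down from the root, going right when
            # the selected coordinate of x is positive.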
284 | self.dt_tensor = torch.randint(
285 | low=0, high=n_dims, size=(batch_size, 2 ** (depth + 1) - 1)
286 | )
287 |
288 | # Target value at the leaf nodes.
289 | # Only indices corresponding to leaf nodes are relevant.
290 | self.target_tensor = torch.randn(self.dt_tensor.shape)
291 | elif seeds is not None:
292 |             self.dt_tensor = torch.zeros(batch_size, 2 ** (depth + 1) - 1, dtype=torch.long)
293 |             self.target_tensor = torch.zeros_like(self.dt_tensor, dtype=torch.float)
294 | generator = torch.Generator()
295 | assert len(seeds) == self.b_size
296 | for i, seed in enumerate(seeds):
297 | generator.manual_seed(seed)
298 |                 self.dt_tensor[i] = torch.randint(
299 |                     low=0,
300 |                     high=n_dims,
301 |                     size=(2 ** (depth + 1) - 1,),
302 |                     generator=generator,
303 |                 )
304 | self.target_tensor[i] = torch.randn(
305 | self.dt_tensor[i].shape, generator=generator
306 | )
307 | else:
308 | raise NotImplementedError
309 |
310 | def evaluate(self, xs_b):
311 | dt_tensor = self.dt_tensor.to(xs_b.device)
312 | target_tensor = self.target_tensor.to(xs_b.device)
313 | ys_b = torch.zeros(xs_b.shape[0], xs_b.shape[1], device=xs_b.device)
314 | for i in range(xs_b.shape[0]):
315 | xs_bool = xs_b[i] > 0
316 |             # If only a single decision tree is present, use it for all the xs in the batch.
317 | if self.b_size == 1:
318 | dt = dt_tensor[0]
319 | target = target_tensor[0]
320 | else:
321 | dt = dt_tensor[i]
322 | target = target_tensor[i]
323 |
324 | cur_nodes = torch.zeros(xs_b.shape[1], device=xs_b.device).long()
325 | for j in range(self.depth):
326 | cur_coords = dt[cur_nodes]
327 | cur_decisions = xs_bool[torch.arange(xs_bool.shape[0]), cur_coords]
328 | cur_nodes = 2 * cur_nodes + 1 + cur_decisions
329 |
330 | ys_b[i] = target[cur_nodes]
331 |
332 | return ys_b
333 |
334 | @staticmethod
335 | def generate_pool_dict(n_dims, num_tasks, hidden_layer_size=4, **kwargs):
336 | raise NotImplementedError
337 |
338 | @staticmethod
339 | def get_metric():
340 | return squared_error
341 |
342 | @staticmethod
343 | def get_training_metric():
344 | return mean_squared_error
345 |
--------------------------------------------------------------------------------
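
A minimal usage sketch (not part of the repository): it mirrors how `train.py` below wires a data sampler and a task sampler together to produce in-context examples. The task name `"decision_tree"` and data name `"gaussian"` are assumptions based on the config file names; the argument order follows the calls made in `train.py`.

```python
# Hypothetical snippet: sample one batch of (xs, ys) the way train.py does.
from samplers import get_data_sampler
from tasks import get_task_sampler

n_dims, b_size, n_points = 20, 64, 101

# "gaussian" and "decision_tree" are assumed names (see the files under conf/).
data_sampler = get_data_sampler("gaussian", n_dims=n_dims)
task_sampler = get_task_sampler("decision_tree", n_dims, b_size, num_tasks=None)

xs = data_sampler.sample_xs(n_points, b_size, n_dims)  # (b_size, n_points, n_dims)
task = task_sampler()
ys = task.evaluate(xs)                                  # (b_size, n_points)
loss_fn = task.get_training_metric()                    # mean squared error
```
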
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from random import randint
3 | import uuid
4 |
5 | from quinine import QuinineArgumentParser
6 | from tqdm import tqdm
7 | import torch
8 | import yaml
9 |
10 | from eval import get_run_metrics
11 | from tasks import get_task_sampler
12 | from samplers import get_data_sampler
13 | from curriculum import Curriculum
14 | from schema import schema
15 | from models import build_model
16 |
17 | import wandb
18 |
19 | torch.backends.cudnn.benchmark = True
20 |
21 |
22 | def train_step(model, xs, ys, optimizer, loss_func):
23 | optimizer.zero_grad()
24 | output = model(xs, ys)
25 | loss = loss_func(output, ys)
26 | loss.backward()
27 | optimizer.step()
28 | return loss.detach().item(), output.detach()
29 |
30 |
31 | def sample_seeds(total_seeds, count):
32 | seeds = set()
33 | while len(seeds) < count:
34 | seeds.add(randint(0, total_seeds - 1))
35 | return seeds
36 |
37 |
38 | def train(model, args):
39 | optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
40 | curriculum = Curriculum(args.training.curriculum)
41 |
42 | starting_step = 0
43 | state_path = os.path.join(args.out_dir, "state.pt")
44 | if os.path.exists(state_path):
45 | state = torch.load(state_path)
46 | model.load_state_dict(state["model_state_dict"])
47 | optimizer.load_state_dict(state["optimizer_state_dict"])
48 | starting_step = state["train_step"]
49 | for i in range(state["train_step"] + 1):
50 | curriculum.update()
51 |
52 | n_dims = model.n_dims
53 | bsize = args.training.batch_size
54 | data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
55 | task_sampler = get_task_sampler(
56 | args.training.task,
57 | n_dims,
58 | bsize,
59 | num_tasks=args.training.num_tasks,
60 | **args.training.task_kwargs,
61 | )
62 | pbar = tqdm(range(starting_step, args.training.train_steps))
63 |
64 | num_training_examples = args.training.num_training_examples
65 |
66 | for i in pbar:
67 | data_sampler_args = {}
68 | task_sampler_args = {}
69 |
70 | if "sparse" in args.training.task:
71 | task_sampler_args["valid_coords"] = curriculum.n_dims_truncated
72 | if num_training_examples is not None:
73 | assert num_training_examples >= bsize
74 | seeds = sample_seeds(num_training_examples, bsize)
75 | data_sampler_args["seeds"] = seeds
76 | task_sampler_args["seeds"] = [s + 1 for s in seeds]
77 |
78 | xs = data_sampler.sample_xs(
79 | curriculum.n_points,
80 | bsize,
81 | curriculum.n_dims_truncated,
82 | **data_sampler_args,
83 | )
84 | task = task_sampler(**task_sampler_args)
85 | ys = task.evaluate(xs)
86 |
87 | loss_func = task.get_training_metric()
88 |
89 | loss, output = train_step(model, xs.cuda(), ys.cuda(), optimizer, loss_func)
90 |
91 | point_wise_tags = list(range(curriculum.n_points))
92 | point_wise_loss_func = task.get_metric()
93 | point_wise_loss = point_wise_loss_func(output, ys.cuda()).mean(dim=0)
94 |
95 | baseline_loss = (
96 | sum(
97 | max(curriculum.n_dims_truncated - ii, 0)
98 | for ii in range(curriculum.n_points)
99 | )
100 | / curriculum.n_points
101 | )
102 |
103 | if i % args.wandb.log_every_steps == 0 and not args.test_run:
104 | wandb.log(
105 | {
106 | "overall_loss": loss,
107 | "excess_loss": loss / baseline_loss,
108 | "pointwise/loss": dict(
109 | zip(point_wise_tags, point_wise_loss.cpu().numpy())
110 | ),
111 | "n_points": curriculum.n_points,
112 | "n_dims": curriculum.n_dims_truncated,
113 | },
114 | step=i,
115 | )
116 |
117 | curriculum.update()
118 |
119 | pbar.set_description(f"loss {loss}")
120 | if i % args.training.save_every_steps == 0 and not args.test_run:
121 | training_state = {
122 | "model_state_dict": model.state_dict(),
123 | "optimizer_state_dict": optimizer.state_dict(),
124 | "train_step": i,
125 | }
126 | torch.save(training_state, state_path)
127 |
128 | if (
129 | args.training.keep_every_steps > 0
130 | and i % args.training.keep_every_steps == 0
131 | and not args.test_run
132 | and i > 0
133 | ):
134 | torch.save(model.state_dict(), os.path.join(args.out_dir, f"model_{i}.pt"))
135 |
136 |
137 | def main(args):
138 | if args.test_run:
139 | curriculum_args = args.training.curriculum
140 | curriculum_args.points.start = curriculum_args.points.end
141 | curriculum_args.dims.start = curriculum_args.dims.end
142 | args.training.train_steps = 100
143 | else:
144 | wandb.init(
145 | dir=args.out_dir,
146 | project=args.wandb.project,
147 | entity=args.wandb.entity,
148 | config=args.__dict__,
149 | notes=args.wandb.notes,
150 | name=args.wandb.name,
151 | resume=True,
152 | )
153 |
154 | model = build_model(args.model)
155 | model.cuda()
156 | model.train()
157 |
158 | train(model, args)
159 |
160 | if not args.test_run:
161 | _ = get_run_metrics(args.out_dir) # precompute metrics for eval
162 |
163 |
164 | if __name__ == "__main__":
165 | parser = QuinineArgumentParser(schema=schema)
166 | args = parser.parse_quinfig()
167 | assert args.model.family in ["gpt2", "lstm"]
168 | print(f"Running with: {args}")
169 |
170 | if not args.test_run:
171 | run_id = args.training.resume_id
172 | if run_id is None:
173 | run_id = str(uuid.uuid4())
174 |
175 | out_dir = os.path.join(args.out_dir, run_id)
176 | if not os.path.exists(out_dir):
177 | os.makedirs(out_dir)
178 | args.out_dir = out_dir
179 |
180 | with open(os.path.join(out_dir, "config.yaml"), "w") as yaml_file:
181 | yaml.dump(args.__dict__, yaml_file, default_flow_style=False)
182 |
183 | main(args)
184 |
--------------------------------------------------------------------------------
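
A small sketch (an assumption, not code from the repository) showing how one might reload the `state.pt` checkpoint that `train()` above writes every `save_every_steps` steps; the dictionary keys match those saved in `train.py`.

```python
# Hypothetical helper: restore the model/optimizer state written by train().
import os
import torch

def load_training_state(model, optimizer, out_dir):
    """Load the latest checkpoint saved by train(); return the step to resume from."""
    state_path = os.path.join(out_dir, "state.pt")
    state = torch.load(state_path, map_location="cpu")
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    return state["train_step"]
```
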