├── .gitignore
├── LICENSE
├── README.md
├── code
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── cifar.py
│   │   ├── miniimagenet.py
│   │   └── mnist.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── mlp.py
│   │   └── resnet18.py
│   ├── scripts
│   │   ├── ADI.py
│   │   ├── ARM.py
│   │   ├── __init__.py
│   │   ├── distill.py
│   │   ├── naive.py
│   │   ├── print_results.py
│   │   └── print_table.py
│   └── util
│       ├── __init__.py
│       ├── data.py
│       ├── eval.py
│       ├── general.py
│       ├── load.py
│       ├── losses.py
│       └── render.py
├── commands.txt
└── summary.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Xu Ji
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Automatic Recall Machines
2 |
3 | This repository contains the code for Automatic Recall Machines: Internal Replay, Continual Learning and the Brain.
4 |
5 |
6 |
7 | As well as ARM, we include implementations of Adaptive DeepInversion and LwF-style distillation.
8 |
9 | # Dependencies
10 |
11 | We used the following environment:
12 | - python 3.6.8
13 | - pytorch 1.4.0
14 | - torchvision 0.5.0
15 | - numpy 1.18.4
16 |
17 | # Run the code
18 |
19 | Commands for all our results on CIFAR10, MiniImageNet and MNIST are given in `commands.txt`.
20 | For example, to run recall on CIFAR10:
21 | ```
22 | python -m code.scripts.ARM --model_ind_start 3717 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/ARM --data_path /scratch/CIFAR
23 | ```
24 | Print results:
25 | ```
26 | python -m code.scripts.print_results --root /scratch/ARM --start 3717
27 |
28 | average val: acc 0.2586 +- 0.0145, forgetting 0.1046 +- 0.0330
29 | average test: acc 0.2687 +- 0.0107, forgetting 0.0959 +- 0.0371
30 | ```
31 |
--------------------------------------------------------------------------------
/code/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xu-ji/ARM/704f69d73765f13ca5e2d8aee11f399b53c635e6/code/__init__.py
--------------------------------------------------------------------------------
/code/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .cifar import cifar10, cifar10val
2 | from .miniimagenet import miniimagenet, miniimagenetval
3 | from .mnist import mnist5k
--------------------------------------------------------------------------------
/code/data/cifar.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import pickle
4 | import sys
5 |
6 | import numpy as np
7 | from PIL import Image
8 | from torchvision.datasets.vision import VisionDataset
9 |
10 | from code.util.general import np_rand_seed
11 |
12 | # Reference: https://github.com/optimass/Maximally_Interfered_Retrieval/blob/master/data.py
13 | # We use 1 dataloader rather than one per task
14 |
15 | __all__ = ["cifar10", "cifar10val"]
16 |
17 |
18 | class cifar(VisionDataset):
19 |     train_pc = 0.95
20 |
21 |     def __init__(self, root, num_classes, train=True, transform=None, target_transform=None,
22 |                  non_stat=False,
23 |                  two_classes_per_block=False, is_val=False, shuffle_classes=False,
24 |                  num_iterations=None):
25 |         print("initialising cifar%d, is val: %s..." % (num_classes, is_val))
26 |         super(cifar, self).__init__(root, transform=transform, target_transform=target_transform)
27 |
28 |         self.num_classes = num_classes
29 |
30 |         self.train = train
31 |         self.is_val = is_val
32 |         self.non_stat = non_stat
33 |         self.two_classes_per_block = two_classes_per_block
34 |         self.shuffle_classes = shuffle_classes
35 |
36 |         self.num_iterations = num_iterations
37 |         assert (num_iterations is not None)
38 |
39 |         if self.is_val:
40 |             assert (self.train)
41 |
42 |         if self.train:
43 |             downloaded_list = self.train_list
44 |         else:
45 |             downloaded_list = self.test_list
46 |
47 |         self.data = []
48 |         self.targets = []
49 |
50 |         # now load the pickled numpy arrays
51 |         for file_name, checksum in downloaded_list:
52 |             file_path = os.path.join(self.root, self.base_folder, file_name)
53 |             with open(file_path, 'rb') as f:
54 |                 if sys.version_info[0] == 2:
55 |                     entry = pickle.load(f)
56 |                 else:
57 |                     entry = pickle.load(f, encoding='latin1')
58 |                 self.data.append(entry['data'])
59 |                 if 'labels' in entry:
60 |                     self.targets.extend(entry['labels'])
61 |                 else:
62 |                     self.targets.extend(entry['fine_labels'])
63 |
64 |         self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
65 |         self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
66 |
67 |         sz = self.data.shape[0]
68 |         assert (len(self.targets) == sz)
69 |
70 |         # Split train from val deterministically -------------------------------------------------------
71 |
72 |         if self.train:
73 |             if not self.is_val:
74 |                 per_class_sz = int((sz / self.num_classes) * cifar.train_pc)
75 |                 data_inds = range(sz)  # from start
76 |             else:
77 |                 per_class_sz = int((sz / self.num_classes) * (1 - cifar.train_pc))
78 |                 data_inds = reversed(range(sz))  # from back
79 |
80 |             class_counts = [0 for _ in range(self.num_classes)]
81 |             chosen_inds = []
82 |             for i in data_inds:
83 |                 c = self.targets[i]
84 |                 if class_counts[c] < per_class_sz:
85 |                     chosen_inds.append(i)
86 |                     class_counts[c] += 1
87 |             assert (len(chosen_inds) == per_class_sz * self.num_classes)
88 |
89 |         else:
90 |             chosen_inds = list(range(sz))
91 |
92 |         # Deterministic shuffle
93 |         np.random.seed(np_rand_seed())
94 |         chosen_inds = np.array(chosen_inds)
95 |         np.random.shuffle(chosen_inds)
96 |         self.data = [self.data[i] for i in chosen_inds]
97 |         self.targets = [self.targets[i] for i in chosen_inds]
98 |
99 |         # Rearrange into contiguous format for non-stationary training ---------------------------------
100 |         self.task_dict_classes = {}
101 |
102 |         if non_stat:
103 |             # organise data and targets by targets
104 |             per_class = [[] for _ in range(self.num_classes)]
105 |             for i, label in enumerate(self.targets):
106 |                 per_class[label].append(self.data[i])
107 |
108 |             new_data = []
109 |             new_targets = []
110 |             if not two_classes_per_block:  # classes contiguous
111 |                 for c in range(self.num_classes):
112 |                     new_data += per_class[c]
113 |                     new_targets += [c] * len(per_class[c])
114 |                     self.task_dict_classes[c] = [c]
115 |             else:
116 |                 classes = np.arange(self.num_classes)
117 |                 assert (not self.shuffle_classes)  # sanity
118 |                 if self.shuffle_classes:
119 |                     np.random.seed(np_rand_seed())
120 |                     np.random.shuffle(classes)
121 |
122 |                 num_tasks = int(np.ceil(self.num_classes / 2.))
123 |                 print("two_classes_per_block: num tasks %d" % num_tasks)
124 |                 for t in range(num_tasks):
125 |                     inds = [t * 2, t * 2 + 1]
126 |                     if (t * 2 + 1 >= self.num_classes):  # odd number of classes
127 |                         assert (t == num_tasks - 1 and self.num_classes % 2 == 1)
128 |                         inds = [t * 2]
129 |
130 |                     self.task_dict_classes[t] = [classes[i] for i in inds]
131 |
132 |                     t_data = []
133 |                     t_targets = []
134 |                     for i in inds:
135 |                         c = classes[i]
136 |                         t_data += per_class[c]
137 |                         t_targets += [c] * len(per_class[c])
138 |
139 |                     order = np.arange(len(t_data))
140 |                     np.random.shuffle(order)
141 |                     new_data += [t_data[i] for i in order]
142 |                     new_targets += [t_targets[i] for i in order]
143 |
144 |             self.data = new_data
145 |             self.targets = new_targets
146 |
147 |         class_lengths = []
148 |         targets_np = np.array(self.targets)
149 |         for c in range(num_classes):
150 |             class_lengths.append((targets_np == c).sum())
151 |
152 |         print(
153 |             "... finished initialising cifar train %s val %s non stat %s classes %d "
154 |             "two_classes_per_block %s iterations %d shuffle %s" %
155 |             (self.train, self.is_val, self.non_stat, self.num_classes, self.two_classes_per_block,
156 |              self.num_iterations, self.shuffle_classes))
157 |
158 |         self.orig_len = len(self.data)
159 |         self.actual_len = self.orig_len * self.num_iterations
160 |
161 |         if self.non_stat:  # we need to care about looping over in task order
162 |             assert (self.orig_len % self.num_classes == 0)
163 |
164 |             self.orig_samples_per_task = int(self.orig_len / self.num_classes)
165 |
166 |             if self.two_classes_per_block:
167 |                 self.orig_samples_per_task *= 2
168 |
169 |             self.actual_samples_per_task = self.orig_samples_per_task * self.num_iterations
170 |             print("orig samples per task: %d, actual samples per task: %d, orig len %d actual len %d" % (
171 |                 self.orig_samples_per_task, self.actual_samples_per_task, self.orig_len, self.actual_len))
172 |
173 |     def __getitem__(self, index):
174 |         assert (index < self.actual_len)
175 |
176 |         if not self.non_stat:
177 |             index = index % self.orig_len  # looping over stationary data is arbitrary
178 |         else:
179 |             task_i, actual_offset = divmod(index, self.actual_samples_per_task)
180 |             _, orig_offset = divmod(actual_offset, self.orig_samples_per_task)
181 |             index = task_i * self.orig_samples_per_task + orig_offset
182 |
183 |         img, target = self.data[index], self.targets[index]
184 |
185 |         # doing this so that it is consistent with all other datasets
186 |         # to return a PIL Image
187 |         img = Image.fromarray(img)
188 |
189 |         if self.transform is not None:
190 |             img = self.transform(img)
191 |
192 |         if self.target_transform is not None:
193 |             target = self.target_transform(target)
194 |
195 |         return img, target
196 |
197 |     def __len__(self):
198 |         return self.actual_len
199 |
200 |
201 | class cifar10(cifar):
202 |     base_folder = 'cifar-10-batches-py'
203 |
204 |     train_list = [
205 |         ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
206 |         ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
207 |         ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
208 |         ['data_batch_4', '634d18415352ddfa80567beed471001a'],
209 |         ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
210 |     ]
211 |
212 |     test_list = [
213 |         ['test_batch', '40351d587109b95175f43aff81a1287e'],
214 |     ]
215 |
216 |     def __init__(self, root, train=True, transform=None, target_transform=None, non_stat=None,
217 |                  two_classes_per_block=False, is_val=False, num_iterations=None):
218 |         assert (non_stat is not None)
219 |         if not train:
220 |             assert (num_iterations == 1)
221 |         super(cifar10, self).__init__(root, num_classes=10, train=train, transform=transform,
222 |                                       target_transform=target_transform,
223 |                                       non_stat=non_stat, two_classes_per_block=two_classes_per_block,
224 |                                       is_val=is_val, shuffle_classes=False,
225 |                                       num_iterations=num_iterations)
226 |
227 |
228 | class cifar10val(cifar10):
229 |     # 5% validation set
230 |     def __init__(self, root, transform=None, non_stat=False, two_classes_per_block=False):
231 |         super(cifar10val, self).__init__(root, train=True, transform=transform,
232 |                                          non_stat=non_stat, two_classes_per_block=two_classes_per_block,
233 |                                          is_val=True, num_iterations=1)
234 |
--------------------------------------------------------------------------------
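Note on `cifar.py`: the train/validation split above is deterministic because it keeps a fixed number of samples per class, scanning from the front of the data for training and from the back for validation. The following is a minimal standalone sketch of that idea, not the repository's code. The function name `split_indices` and the toy labels are hypothetical, and `train_pc=0.75` is chosen only to keep the toy arithmetic exact.

```python
def split_indices(targets, num_classes, train_pc=0.95, is_val=False):
    """Pick a fixed number of indices per class.

    Training indices are collected scanning from the front of `targets`,
    validation indices scanning from the back.
    """
    sz = len(targets)
    frac = (1 - train_pc) if is_val else train_pc
    per_class_sz = int((sz / num_classes) * frac)
    order = reversed(range(sz)) if is_val else range(sz)

    counts = [0] * num_classes
    chosen = []
    for i in order:
        c = targets[i]
        if counts[c] < per_class_sz:
            chosen.append(i)
            counts[c] += 1
    return chosen


# Hypothetical toy data: 2 classes, 20 samples each, interleaved labels.
toy_targets = [i % 2 for i in range(40)]
train_inds = split_indices(toy_targets, num_classes=2, train_pc=0.75)
val_inds = split_indices(toy_targets, num_classes=2, train_pc=0.75, is_val=True)
print(len(train_inds), len(val_inds))
```

With class-balanced data the two index sets do not overlap, since the per-class quotas for training and validation sum to at most the number of samples per class.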
/code/data/miniimagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | from PIL import Image
7 | from torchvision.datasets.vision import VisionDataset
8 |
9 | from code.util.general import make_valid_from_train
10 |
11 |
12 | # Reference: https://github.com/optimass/Maximally_Interfered_Retrieval/blob/master/data.py
13 | # We use 1 dataloader rather than one per task
14 |
15 | # Can download from https://www.dropbox.com/s/ed1s1dgei9kxd2p/mini-imagenet.zip?dl=0
16 |
17 | def get_data(setname, root_csv, root_images):
18 |     csv_path = os.path.join(root_csv, setname + '.csv')
19 |     lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]
20 |
21 |     data = []
22 |     label = []
23 |     lb = -1
24 |
25 |     wnids = []
26 |
27 |     for l in lines:
28 |         name, wnid = l.split(',')
29 |         path = os.path.join(root_images, name)
30 |         if wnid not in wnids:
31 |             wnids.append(wnid)
32 |             lb += 1
33 |         data.append(path)
34 |         label.append(lb)
35 |
36 |     return data, label
37 |
38 |
39 | class MiniImagenetDatasetFolder(VisionDataset):
40 |     train_val_pc = 0.95
41 |     train_test_pc = 0.8
42 |
43 |     def __init__(self, root, data_type=None, transform=None, target_transform=None,
44 |                  non_stat=False, classes_per_task=None, num_iterations=None):
45 |         super(MiniImagenetDatasetFolder, self).__init__(root, transform=transform,
46 |                                                         target_transform=target_transform)
47 |
48 |         self.data_type = data_type
49 |         self.non_stat = non_stat
50 |         self.classes_per_task = classes_per_task
51 |
52 |         self.num_classes = 100
53 |
54 |         self.num_iterations = num_iterations
55 |         assert (num_iterations is not None)
56 |
57 |         # Load data ------------------------------------------------------------------------------------
58 |         # splits are deterministic
59 |
60 |         images_path = os.path.join(root, "images")
61 |         train_data, train_label = get_data("train", root,
62 |                                            images_path)  # zero indexed labels for all calls
63 |         valid_data, valid_label = get_data("val", root, images_path)
64 |         test_data, test_label = get_data("test", root, images_path)
65 |
66 |         train_amt = np.unique(train_label).shape[0]
67 |         valid_amt = np.unique(valid_label).shape[0]
68 |         test_amt = np.unique(test_label).shape[0]
69 |
70 |         assert (train_amt + valid_amt + test_amt == self.num_classes)
71 |         valid_label = [x + train_amt for x in valid_label]
72 |         test_label = [x + train_amt + valid_amt for x in test_label]
73 |
74 |         all_data = np.array(train_data + valid_data + test_data)  # np array of strings
75 |         all_label = np.array(train_label + valid_label + test_label)
76 |
77 |         train_ds, test_ds = [], []
78 |         current_train, current_test = None, None
79 |
80 |         cat = lambda x, y: np.concatenate((x, y), axis=0)
81 |
82 |         self.task_dict_classes = defaultdict(list)
83 |         task_i = 0
84 |         for i in range(self.num_classes):
85 |             self.task_dict_classes[task_i].append(i)
86 |
87 |             class_indices = np.argwhere(all_label == i).reshape(-1)
88 |             class_data = all_data[class_indices]
89 |             class_label = all_label[class_indices]
90 |             split = int(MiniImagenetDatasetFolder.train_test_pc * class_data.shape[0])  # train/test
91 |
92 |             data_train, data_test = class_data[:split], class_data[split:]
93 |             label_train, label_test = class_label[:split], class_label[split:]
94 |
95 |             if current_train is None:
96 |                 current_train, current_test = (data_train, label_train), (
97 |                     data_test, label_test)  # multiple samples here
98 |             else:
99 |                 current_train = cat(current_train[0], data_train), cat(current_train[1], label_train)
100 |                 current_test = cat(current_test[0], data_test), cat(current_test[1], label_test)
101 |
102 |             if i % self.classes_per_task == (self.classes_per_task - 1):
103 |                 train_ds += [current_train]
104 |                 test_ds += [current_test]
105 |                 current_train, current_test = None, None
106 |                 task_i += 1
107 |
108 |         train_ds, val_ds = make_valid_from_train(train_ds,
109 |                                                  cut=MiniImagenetDatasetFolder.train_val_pc)  # uses
110 |         # random split, but seed set in main script
111 |
112 |         # now we have list of list of (path, label), one list per task
113 |         # pick the right source, flatten into one list and load images
114 |
115 |         data_summary = {"train": train_ds, "val": val_ds, "test": test_ds}[self.data_type]
116 |
117 |         self.data = []
118 |         self.targets = []
119 |         task_lengths = []
120 |         for task_ds in data_summary:
121 |             num_samples_task = len(task_ds[0])
122 |             assert (len(task_ds[1]) == num_samples_task)
123 |             task_lengths.append(num_samples_task)
124 |             for i in range(num_samples_task):
125 |                 img_path = task_ds[0][i]
126 |                 label = task_ds[1][i]
127 |                 self.data.append(img_path)
128 |                 self.targets.append(label)
129 |
130 |         print(self.task_dict_classes)
131 |
132 |         # if stationary, shuffle
133 |         if not self.non_stat:
134 |             perm = np.random.permutation(len(self.data))
135 |             self.data, self.targets = [self.data[perm_i] for perm_i in perm], [self.targets[perm_i] for
136 |                                                                                perm_i in perm]
137 |
138 |         self.orig_len = len(self.data)
139 |         self.actual_len = self.orig_len * self.num_iterations
140 |
141 |         if self.non_stat:  # we need to care about looping over in task order
142 |             assert (self.orig_len % self.num_classes == 0)
143 |             self.orig_samples_per_task = int(
144 |                 self.orig_len / self.num_classes) * self.classes_per_task  # equally split among tasks
145 |
146 |             self.actual_samples_per_task = self.orig_samples_per_task * self.num_iterations
147 |
148 |             # sanity
149 |             if self.data_type == "train":
150 |                 assert (self.orig_samples_per_task == (int(
151 |                     600 * MiniImagenetDatasetFolder.train_test_pc * MiniImagenetDatasetFolder.train_val_pc)
152 |                     * self.classes_per_task))
153 |
154 |             if self.data_type == "val":
155 |                 assert (self.orig_samples_per_task == (int(600 * MiniImagenetDatasetFolder.train_test_pc * (
156 |                     1. - MiniImagenetDatasetFolder.train_val_pc)) * self.classes_per_task))
157 |
158 |             if self.data_type == "test":  # remainder of the per-class split; int(600 * (1. - 0.8)) is 119
159 |                 assert (self.orig_samples_per_task == (
160 |                     (600 - int(600 * MiniImagenetDatasetFolder.train_test_pc)) * self.classes_per_task))
161 |
162 |             print("orig samples per task: %d, actual samples per task: %d" % (
163 |                 self.orig_samples_per_task, self.actual_samples_per_task))
164 |
165 |     def __len__(self):
166 |         return self.actual_len
167 |
168 |     def __getitem__(self, index):
169 |         assert (index < self.actual_len)
170 |
171 |         if not self.non_stat:
172 |             index = index % self.orig_len  # looping over stationary data is arbitrary
173 |         else:
174 |             task_i, actual_offset = divmod(index, self.actual_samples_per_task)
175 |             _, orig_offset = divmod(actual_offset, self.orig_samples_per_task)
176 |             index = task_i * self.orig_samples_per_task + orig_offset
177 |
178 |         sample_path, target = self.data[index], self.targets[index]
179 |
180 |         with open(sample_path, "rb") as f:
181 |             sample = Image.open(f).convert('RGB')
182 |
183 |         if self.transform is not None:
184 |             sample = self.transform(sample)
185 |         if self.target_transform is not None:
186 |             target = self.target_transform(target)
187 |
188 |         return sample, target
189 |
190 |
191 | class miniimagenet(MiniImagenetDatasetFolder):
192 |     def __init__(self, root, data_type, transform=None, target_transform=None, non_stat=None,
193 |                  classes_per_task=None, num_iterations=None):
194 |         assert (non_stat is not None)
195 |         if data_type == "val" or data_type == "test":
196 |             assert (num_iterations == 1)
197 |
198 |         super(miniimagenet, self).__init__(root,
199 |                                            data_type=data_type,
200 |                                            transform=transform, target_transform=target_transform,
201 |                                            non_stat=non_stat, classes_per_task=classes_per_task,
202 |                                            num_iterations=num_iterations)
203 |
204 |
205 | class miniimagenetval(miniimagenet):
206 |     def __init__(self, root, transform=None, non_stat=False, classes_per_task=None):
207 |         super(miniimagenetval, self).__init__(root, data_type="val",
208 |                                               non_stat=non_stat, classes_per_task=classes_per_task,
209 |                                               transform=transform, num_iterations=1)
210 |
--------------------------------------------------------------------------------
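Note on `miniimagenet.py`: the loop over classes above assigns consecutive class indices to tasks, closing a task every `classes_per_task` classes and recording the assignment in `task_dict_classes`. A minimal sketch of just that grouping step follows. The helper name `group_classes_into_tasks` is hypothetical and does not exist in the repository.

```python
from collections import defaultdict


def group_classes_into_tasks(num_classes, classes_per_task):
    """Map task index -> list of consecutive class indices."""
    task_dict_classes = defaultdict(list)
    task_i = 0
    for c in range(num_classes):
        task_dict_classes[task_i].append(c)
        # Close the current task once it holds classes_per_task classes.
        if c % classes_per_task == classes_per_task - 1:
            task_i += 1
    return dict(task_dict_classes)


tasks = group_classes_into_tasks(num_classes=10, classes_per_task=2)
print(tasks)
```

If `num_classes` is not a multiple of `classes_per_task`, the last task in this sketch simply holds fewer classes. In the dataset class above, such a partial final task would never be appended to `train_ds`, so a divisible setting appears to be assumed there.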
/code/data/mnist.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import numpy as np
4 | import torch
5 | from torchvision import datasets
6 | from torchvision.datasets.vision import VisionDataset
7 |
8 | from code.util.general import make_valid_from_train
9 |
10 |
11 | # Reference: https://github.com/optimass/Maximally_Interfered_Retrieval/blob/master/data.py
12 | # We use 1 dataloader rather than one per task
13 |
14 | class mnist5k(VisionDataset):
15 |     train_val_pc = 0.95
16 |
17 |     def __init__(self, root, data_type=None, non_stat=False, num_iterations=None, classes_per_task=None):
18 |         super(mnist5k, self).__init__(root, transform=None, target_transform=None)
19 |
20 |         self.data_type = data_type
21 |         self.non_stat = non_stat
22 |         self.classes_per_task = classes_per_task
23 |         assert (self.classes_per_task == 2)
24 |         self.num_classes = 10
25 |         self.orig_train_samples_per_class = 500
26 |
27 |         self.num_iterations = num_iterations
28 |         assert (num_iterations is not None)
29 |
30 |         # Load data ------------------------------------------------------------------------------------
31 |         # splits are deterministic
32 |
33 |         # follows https://github.com/optimass/Maximally_Interfered_Retrieval/
34 |
35 |         train = datasets.MNIST(root, train=True, download=False)
36 |         test = datasets.MNIST(root, train=False, download=False)
37 |
38 |         train_x, train_y = train.data, train.targets  # 60000, 28, 28; 60000
39 |         test_x, test_y = test.data, test.targets
40 |
41 |         # sort by label
42 |         train_ds, test_ds = [], []  # doesn't really matter for test_ds because of batchnorm tracking
43 |         # stats
44 |         task_i = 0
45 |         current_train, current_test = None, None
46 |         self.task_dict_classes = defaultdict(list)
47 |         for i in range(self.num_classes):
48 |             self.task_dict_classes[task_i].append(i)
49 |             train_i = train_y == i
50 |             test_i = test_y == i
51 |
52 |             if current_train is None:
53 |                 current_train, current_test = (train_x[train_i], train_y[train_i]), (
54 |                     test_x[test_i], test_y[test_i])
55 |             else:
56 |                 current_train = (torch.cat((current_train[0], train_x[train_i]), dim=0),
57 |                                  torch.cat((current_train[1], train_y[train_i]), dim=0))
58 |                 current_test = (torch.cat((current_test[0], test_x[test_i]), dim=0),
59 |                                 torch.cat((current_test[1], test_y[test_i]), dim=0))
60 |
61 |             if i % self.classes_per_task == (self.classes_per_task - 1):
62 |                 train_ds += [current_train]
63 |                 test_ds += [current_test]
64 |                 current_train, current_test = None, None
65 |                 task_i += 1
66 |
67 |         # separate validation set (randomised)
68 |         train_ds, val_ds = make_valid_from_train(train_ds, cut=mnist5k.train_val_pc)
69 |
70 |         # flatten into single list, and truncate training data into 500 per class
71 |         data_summary = {"train": train_ds, "val": val_ds, "test": test_ds}[self.data_type]
72 |         self.data = []  # list of tensors
73 |         self.targets = []
74 |         counts_per_class = torch.zeros(self.num_classes, dtype=torch.int)
75 |         task_lengths = []
76 |         for task_ds in data_summary:
77 |             assert (len(task_ds[1]) == len(task_ds[0]))
78 |
79 |             num_samples_task = 0
80 |             for i in range(len(task_ds[1])):
81 |                 target = task_ds[1][i]
82 |                 if self.data_type == "train" and counts_per_class[
83 |                     target] == self.orig_train_samples_per_class:
84 |                     continue
85 |                 else:
86 |                     self.data.append(task_ds[0][i])
87 |                     self.targets.append(target)
88 |                     counts_per_class[target] += 1
89 |                     num_samples_task += 1
90 |
91 |             task_lengths.append(num_samples_task)
92 |
93 |         print(self.task_dict_classes)
94 |
95 |         # if stationary, shuffle
96 |         if not self.non_stat:
97 |             perm = np.random.permutation(len(self.data))
98 |             self.data, self.targets = [self.data[perm_i] for perm_i in perm], [self.targets[perm_i] for
99 |                                                                                perm_i in perm]
100 |
101 |         self.orig_len = len(self.data)
102 |         self.actual_len = self.orig_len * self.num_iterations
103 |
104 |         if self.non_stat:  # we need to care about looping over in task order
105 |             assert (self.orig_len % self.num_classes == 0)
106 |             self.orig_samples_per_task = int(
107 |                 self.orig_len / self.num_classes) * self.classes_per_task  # equally split among tasks
108 |             self.actual_samples_per_task = self.orig_samples_per_task * self.num_iterations
109 |
110 |             # sanity
111 |             if self.data_type == "train": assert (self.orig_samples_per_task == 1000)
112 |
113 |             print("orig samples per task: %d, actual samples per task: %d" % (
114 |                 self.orig_samples_per_task, self.actual_samples_per_task))
115 |
116 |     def __len__(self):
117 |         return self.actual_len
118 |
119 |     def __getitem__(self, index):
120 |         assert (index < self.actual_len)
121 |
122 |         if not self.non_stat:
123 |             index = index % self.orig_len  # looping over stationary data is arbitrary
124 |         else:
125 |             task_i, actual_offset = divmod(index, self.actual_samples_per_task)
126 |             _, orig_offset = divmod(actual_offset, self.orig_samples_per_task)
127 |             index = task_i * self.orig_samples_per_task + orig_offset
128 |
129 |         sample, target = self.data[index], self.targets[index]
130 |         sample = sample.view(-1).float() / 255.  # flatten and turn from uint8 (255) -> [0., 1.]
131 |
132 |         assert (self.transform is None)
133 |         assert (self.target_transform is None)
134 |
135 |         return sample, target
136 |
--------------------------------------------------------------------------------
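Note on the dataset classes: all three `__getitem__` methods above share the same index remapping for non-stationary data. Each task's block of samples is looped over `num_iterations` times before the next task starts, so an index into the lengthened stream has to be mapped back to an index into the stored data. The sketch below isolates that arithmetic; the function name `remap_index` is hypothetical.

```python
def remap_index(index, orig_samples_per_task, num_iterations):
    """Map an index in the looped, task-ordered stream to a stored-data index."""
    actual_samples_per_task = orig_samples_per_task * num_iterations
    # Which task the index falls in, and how far into that task's looped block.
    task_i, actual_offset = divmod(index, actual_samples_per_task)
    # Wrap the offset back into the task's original (unlooped) block.
    _, orig_offset = divmod(actual_offset, orig_samples_per_task)
    return task_i * orig_samples_per_task + orig_offset


# Toy setting: 2 tasks of 3 samples each, every task seen twice in a row.
mapped = [remap_index(i, orig_samples_per_task=3, num_iterations=2) for i in range(12)]
print(mapped)  # [0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5]
```

With `num_iterations=1` the mapping is the identity, which matches the requirement above that validation and test sets are built with `num_iterations == 1`.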
/code/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .mlp import mlp
2 | from .resnet18 import *
--------------------------------------------------------------------------------
/code/models/mlp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 |
4 |
5 | # Consistent with https://github.com/optimass/Maximally_Interfered_Retrieval
6 |
7 | class mlp(nn.Module):
8 |     def __init__(self, config):
9 |         super(mlp, self).__init__()
10 |
11 |         self.nf = 400
12 |
13 |         self.input_size = np.prod(config.task_in_dims)
14 |         self.hidden = nn.Sequential(nn.Linear(self.input_size, self.nf),
15 |                                     nn.ReLU(True),
16 |                                     nn.Linear(self.nf, self.nf),
17 |                                     nn.ReLU(True))
18 |
19 |         self.linear = nn.Linear(self.nf, np.prod(config.task_out_dims))
20 |
21 |     def forward(self, x):
22 |         x = x.view(-1, self.input_size)
23 |         x = self.hidden(x)
24 |         return self.linear(x)
25 |
--------------------------------------------------------------------------------
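Note on `mlp.py`: the network is two hidden layers of width `nf = 400` followed by a linear output layer. As a rough size check, the sketch below counts its weights and biases in plain Python. The input size of 784 and the 10 outputs are assumptions based on the flattened 28x28 MNIST samples and 10 classes in `mnist.py`; the actual values come from `config.task_in_dims` and `config.task_out_dims`, which are not shown here. Both helper names are hypothetical.

```python
def linear_params(n_in, n_out):
    """Parameters of a fully connected layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out


def mlp_param_count(input_size, num_outputs, nf=400):
    """Parameter count for input -> nf -> nf -> num_outputs."""
    return (linear_params(input_size, nf)
            + linear_params(nf, nf)
            + linear_params(nf, num_outputs))


# Assumed sizes: flattened 28x28 input, 10 output classes.
total = mlp_param_count(input_size=28 * 28, num_outputs=10)
print(total)  # 478410
```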
/code/models/resnet18.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | # Consistent with https://github.com/optimass/Maximally_Interfered_Retrieval
5 |
6 |
7 | class BasicBlock(nn.Module):
8 |     """Basic Block for resnet 18 and resnet 34
9 |     """
10 |
11 |     # BasicBlock and BottleNeck block
12 |     # have different output size
13 |     # we use class attribute expansion
14 |     # to distinguish them
15 |     expansion = 1
16 |
17 |     def __init__(self, in_channels, out_channels, stride=1, use_batchnorm=True, batchnorm_mom=None,
18 |                  batchnorm_dont_track=False):
19 |         super().__init__()
20 |
21 |         # residual function
22 |         seq = [
23 |             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)]
24 |         if use_batchnorm:
25 |             if batchnorm_mom is None:
26 |                 seq += [nn.BatchNorm2d(out_channels, track_running_stats=(not batchnorm_dont_track))]
27 |             else:
28 |                 seq += [nn.BatchNorm2d(out_channels, momentum=batchnorm_mom,
29 |                                        track_running_stats=(not batchnorm_dont_track))]
30 |
31 |         seq += [
32 |             nn.ReLU(inplace=True),
33 |             nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1,
34 |                       bias=False),
35 |         ]
36 |         if use_batchnorm:
37 |             if batchnorm_mom is None:
38 |                 seq += [nn.BatchNorm2d(out_channels * BasicBlock.expansion,
39 |                                        track_running_stats=(not batchnorm_dont_track))]
40 |             else:
41 |                 seq += [nn.BatchNorm2d(out_channels * BasicBlock.expansion, momentum=batchnorm_mom,
42 |                                        track_running_stats=(not batchnorm_dont_track))]
43 |
44 |         self.residual_function = nn.Sequential(*seq)
45 |
46 |         # shortcut
47 |         self.shortcut = nn.Sequential()
48 |
49 |         # the shortcut output dimension is not the same as the residual function's
50 |         # use 1*1 convolution to match the dimension
51 |         if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
52 |             shortcut_seq = [
53 |                 nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride,
54 |                           bias=False)]
55 |             if use_batchnorm:
56 |                 if batchnorm_mom is None:
57 |                     shortcut_seq += [nn.BatchNorm2d(out_channels * BasicBlock.expansion,
58 |                                                     track_running_stats=(not batchnorm_dont_track))]
59 |                 else:
60 |                     shortcut_seq += [
61 |                         nn.BatchNorm2d(out_channels * BasicBlock.expansion, momentum=batchnorm_mom,
62 |                                        track_running_stats=(not batchnorm_dont_track))]
63 |
64 |             self.shortcut = nn.Sequential(*shortcut_seq)
65 |
66 |     def forward(self, x):
67 |         return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
68 |
69 |
70 | class ResNet(nn.Module):
71 | def __init__(self, block, num_block, num_classes=100, use_batchnorm=True, batchnorm_mom=None,
72 | batchnorm_dont_track=False, in_channels=64, init=False, batchnorm_init=False,
73 | linear_sz=None):
74 | super().__init__()
75 |
76 | self.num_classes = num_classes
77 | self.use_batchnorm = use_batchnorm
78 | self.batchnorm_mom = batchnorm_mom
79 |
80 | self.in_channels = in_channels
81 |
82 | self.conv1 = nn.Sequential(
83 | nn.Conv2d(3, in_channels, kernel_size=3, padding=1, bias=False),
84 | nn.BatchNorm2d(in_channels, momentum=self.batchnorm_mom,
85 | track_running_stats=(not batchnorm_dont_track)),
86 | nn.ReLU(inplace=True))
87 |     # we use a different input size than the original paper,
88 |     # so conv2_x's stride is 1
89 |
90 |     # _make_layer updates self.in_channels to the layer's output channels
91 | self.conv2_x = self._make_layer(block, in_channels, num_block[0], 1,
92 | batchnorm_mom=self.batchnorm_mom,
93 | batchnorm_dont_track=batchnorm_dont_track)
94 | self.conv3_x = self._make_layer(block, in_channels * 2, num_block[1], 2,
95 | batchnorm_mom=self.batchnorm_mom,
96 | batchnorm_dont_track=batchnorm_dont_track)
97 | self.conv4_x = self._make_layer(block, in_channels * 4, num_block[2], 2,
98 | batchnorm_mom=self.batchnorm_mom,
99 | batchnorm_dont_track=batchnorm_dont_track)
100 | self.conv5_x = self._make_layer(block, in_channels * 8, num_block[3], 2,
101 | batchnorm_mom=self.batchnorm_mom,
102 | batchnorm_dont_track=batchnorm_dont_track)
103 |
104 | if self.num_classes == 10:
105 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
106 |
107 | self.fc1 = nn.Linear(linear_sz, self.num_classes)
108 |
109 | """
110 | if init: # else default, which is uniform
111 | print("calling _initialise")
112 | self._initialise()
113 |
114 | if batchnorm_init: # else default, which is all 1s since 1.2
115 | print("calling _batchnorm_initialise")
116 | self._batchnorm_initialise()
117 | """
118 |
119 | def _initialise(self):
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
123 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
124 | nn.init.constant_(m.weight, 1)
125 | nn.init.constant_(m.bias, 0)
126 |
127 | def _batchnorm_initialise(self):
128 | for m in self.modules():
129 | if isinstance(m, nn.BatchNorm2d): # pre 1.2 default init
130 | nn.init.uniform_(m.weight, a=0.0, b=1.0)
131 | nn.init.constant_(m.bias, 0)
132 |
133 | def _make_layer(self, block, out_channels, num_blocks, stride, batchnorm_mom=None,
134 | batchnorm_dont_track=False):
135 |     """Make a ResNet layer. Here "layer" does not mean a single neural
136 |     network layer such as a conv layer; one ResNet layer may contain
137 |     more than one residual block.
138 |
139 |     Args:
140 |       block: block type, basic block or bottleneck block
141 |       out_channels: number of output channels of this layer
142 |       num_blocks: how many blocks per layer
143 |       stride: the stride of the first block of this layer
144 |
145 |     Returns:
146 |       a ResNet layer (nn.Sequential)
147 |     """
148 |
149 |     # we have num_blocks blocks per layer; the stride of the first block
150 |     # can be 1 or 2, the other blocks always use stride 1
151 | strides = [stride] + [1] * (num_blocks - 1)
152 | layers = []
153 | for stride in strides:
154 | layers.append(block(self.in_channels, out_channels, stride,
155 | use_batchnorm=self.use_batchnorm, batchnorm_mom=batchnorm_mom,
156 | batchnorm_dont_track=batchnorm_dont_track))
157 | self.in_channels = out_channels * block.expansion
158 |
159 | return nn.Sequential(*layers)
160 |
161 | def forward(self, x):
162 | x = self.conv1(x)
163 | x = self.conv2_x(x)
164 | x = self.conv3_x(x)
165 | x = self.conv4_x(x)
166 | x = self.conv5_x(x)
167 |
168 | if self.num_classes == 10:
169 | x = self.avg_pool(x) # 1x1 (as Aljundi)
170 | else:
171 | x = nn.functional.avg_pool2d(x, 4) # 2x2 (as Aljundi)
172 |
173 | x = x.view(x.size(0), -1)
174 |
175 | return self.fc1(x)
176 |
177 |
178 | def _batch_stats_hook(b, input):
179 | if isinstance(input, tuple):
180 | assert (len(input) == 1)
181 | input = input[0]
182 |
183 | assert (len(input.shape) == 4)
184 |
185 | stored_mean = b.running_mean
186 | stored_var = b.running_var
187 |
188 | curr_mean = input.mean(dim=(0, 2, 3))
189 | curr_var = input.var(dim=(0, 2, 3))
190 |
191 | assert (stored_mean.shape == (input.shape[1],))
192 | assert (stored_var.shape == (input.shape[1],))
193 | assert (curr_mean.shape == (input.shape[1],))
194 | assert (curr_var.shape == (input.shape[1],))
195 |
196 | b.batch_stats_loss = torch.norm(curr_mean - stored_mean, p=2) + torch.norm(curr_var - stored_var,
197 | p=2)
198 | assert (b.batch_stats_loss.shape == torch.Size([])) # scalar
199 |
200 |
201 | class resnet18(ResNet):
202 |   def __init__(self, config):
203 |     # Newer ResNet code uses adaptive average pooling, but the Aljundi code pools with kernel 4.
204 |     if config.data == "miniimagenet":
205 |       num_classes = 100
206 |       linear_sz = 160 * 2 * 2
207 |     elif config.data == "cifar10":
208 |       num_classes = 10
209 |       linear_sz = 160 * 1 * 1
210 |     else:
211 |       raise ValueError("Unsupported config.data for resnet18: %s" % config.data)
212 |
213 |     super(resnet18, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes,
214 |                                    in_channels=20, use_batchnorm=True,
215 |                                    batchnorm_dont_track=False, linear_sz=linear_sz)
216 |
217 |
218 | class resnet18_batch_stats(resnet18):
219 | def __init__(self, config):
220 | super(resnet18_batch_stats, self).__init__(config)
221 |
222 | self._compute_batch_stats_loss = False
223 |
224 | def compute_batch_stats_loss(self):
225 | self._compute_batch_stats_loss = True
226 |
227 | self.num_batchnorms = 0
228 | for m in self.modules():
229 | if isinstance(m, nn.BatchNorm2d):
230 | m.register_forward_pre_hook(_batch_stats_hook)
231 | m.batch_stats_loss = None
232 | self.num_batchnorms += 1
233 |
234 | def forward(self, x):
235 | # return loss for batch stats
236 |
237 | x = self.conv1(x)
238 | x = self.conv2_x(x)
239 | x = self.conv3_x(x)
240 | x = self.conv4_x(x)
241 | x = self.conv5_x(x)
242 |
243 | if self.num_classes == 10:
244 | x = self.avg_pool(x)
245 | else:
246 | x = nn.functional.avg_pool2d(x, 4)
247 | x = x.view(x.size(0), -1)
248 |
249 | x = self.fc1(x)
250 |
251 | if self._compute_batch_stats_loss:
252 | batch_stats_losses = []
253 | for m in self.modules():
254 | if isinstance(m, nn.BatchNorm2d):
255 | assert (m.batch_stats_loss is not None)
256 | batch_stats_losses.append(m.batch_stats_loss)
257 | m.batch_stats_loss = None
258 | assert (len(batch_stats_losses) == self.num_batchnorms)
259 | return x, torch.mean(torch.stack(batch_stats_losses))
260 | else:
261 | return x
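
Note on `linear_sz` in `resnet18`: the two values follow from the feature-map geometry of the network above. The sketch below is illustrative only and is not part of the repository. It assumes 32x32 inputs for CIFAR-10 and 84x84 inputs for MiniImageNet (the 84 is an assumption; the input size is not stated in this file), and the helper names `conv_out`, `feature_map_side` and `linear_size` are hypothetical. It uses only the standard floor formula for convolution and pooling output sizes.

```python
def conv_out(size, kernel, stride, padding):
  """Output side length of a square convolution (floor formula)."""
  return (size + 2 * padding - kernel) // stride + 1


def feature_map_side(input_side, stage_strides=(1, 2, 2, 2)):
  """Side length after conv1 and the four residual stages.

  Only the first 3x3 convolution of each stage can change the resolution
  (kernel 3, padding 1); every other convolution preserves it.
  """
  side = conv_out(input_side, kernel=3, stride=1, padding=1)  # conv1
  for stride in stage_strides:
    side = conv_out(side, kernel=3, stride=stride, padding=1)
  return side


def linear_size(input_side, base_channels=20, pool_kernel=None):
  """Flattened feature count fed to fc1.

  pool_kernel=None models AdaptiveAvgPool2d((1, 1)); an integer models
  avg_pool2d(x, pool_kernel), whose stride defaults to the kernel size.
  """
  channels = base_channels * 8  # conv5_x width, since expansion = 1
  side = feature_map_side(input_side)
  pooled = 1 if pool_kernel is None else (side - pool_kernel) // pool_kernel + 1
  return channels * pooled * pooled


print(linear_size(32))                 # CIFAR-10 path: 160
print(linear_size(84, pool_kernel=4))  # assumed 84x84 MiniImageNet input: 640
```

Under these assumptions an 84x84 input shrinks to 11x11 after the three stride-2 stages, and pooling with kernel 4 leaves a 2x2 map, which matches `160 * 2 * 2`.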
--------------------------------------------------------------------------------
/code/scripts/ADI.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from copy import deepcopy
4 | from datetime import datetime
5 |
6 | import torch.optim as optim
7 |
8 | from code.util.eval import evaluate_basic
9 | from code.util.general import *
10 | from code.util.load import *
11 | from code.util.losses import *
12 | from code.util.render import render_aux_x
13 |
14 | # --------------------------------------------------------------------------------------------------
15 | # Settings
16 | # --------------------------------------------------------------------------------------------------
17 |
18 | orig_config = argparse.ArgumentParser(allow_abbrev=False)
19 |
20 | orig_config.add_argument("--model_ind_start", type=int, required=True)
21 |
22 | orig_config.add_argument("--num_runs", type=int, required=True)
23 |
24 | orig_config.add_argument("--out_root", type=str, required=True)
25 |
26 | # Data and model
27 |
28 | orig_config.add_argument("--data", type=str, required=True)
29 |
30 | orig_config.add_argument("--data_path", type=str, required=True)
31 |
32 | orig_config.add_argument("--stationary", default=False, action="store_true")
33 |
34 | orig_config.add_argument("--classes_per_task", type=int, required=True)
35 |
36 | orig_config.add_argument("--max_t", type=int, required=True)
37 |
38 | orig_config.add_argument("--tasks_train_batch_sz", type=int, default=10)
39 |
40 | orig_config.add_argument("--tasks_eval_batch_sz", type=int, default=10)
41 |
42 | orig_config.add_argument("--num_iterations", type=int, default=1)
43 |
44 | orig_config.add_argument("--task_model_type", type=str, required=True)
45 |
46 | # Adaptive DeepInversion
47 |
48 | orig_config.add_argument("--recall_from_t", type=int, required=True)
49 |
50 | orig_config.add_argument("--M", type=int, default=100)
51 |
52 | orig_config.add_argument("--lr", type=float, default=0.1)
53 |
54 | orig_config.add_argument("--refine_sample_lr", type=float, default=0.1)
55 |
56 | orig_config.add_argument("--refine_sample_steps", type=int, default=10)
57 |
58 | orig_config.add_argument("--aux_weight", type=float, default=1.0)
59 |
60 | orig_config.add_argument("--no_aux_distill", default=False, action="store_true")
61 |
62 | orig_config.add_argument("--aux_distill_weight", type=float, default=1.0)
63 |
64 | orig_config.add_argument("--adaptive", default=False, action="store_true")
65 |
66 | orig_config.add_argument("--adaptive_weight", type=float, default=1.0)
67 |
68 | orig_config.add_argument("--classes_loss_weight", type=float, default=1.0)
69 |
70 | orig_config.add_argument("--choose_past_classes", default=False, action="store_true")
71 |
72 | orig_config.add_argument("--opt_batch_stats", default=False, action="store_true")
73 |
74 | orig_config.add_argument("--opt_batch_stats_weight", type=float, default=1.0)
75 |
76 | orig_config.add_argument("--sharpen_class", default=False, action="store_true")
77 |
78 | orig_config.add_argument("--sharpen_class_weight", type=float, default=1.0)
79 |
80 | orig_config.add_argument("--TV", default=False, action="store_true")
81 |
82 | orig_config.add_argument("--TV_weight", type=float, default=1.0)
83 |
84 | orig_config.add_argument("--L2", default=False, action="store_true")
85 |
86 | orig_config.add_argument("--L2_weight", type=float, default=1.0)
87 |
88 | orig_config.add_argument("--long_window", default=False, action="store_true")
89 |
90 | orig_config.add_argument("--long_window_range", type=int, nargs="+", default=[1, 1]) # inclusive
91 |
92 | orig_config.add_argument("--use_fixed_window", default=False, action="store_true")
93 |
94 | orig_config.add_argument("--fixed_window", type=int, default=1)
95 |
96 | # Included for testing:
97 |
98 | orig_config.add_argument("--refine_theta_steps", type=int, default=1)
99 |
100 | orig_config.add_argument("--refine_sample_from_scratch", default=False, action="store_true")
101 |
102 | orig_config.add_argument("--hard_targets", default=False, action="store_true")
103 |
104 | orig_config.add_argument("--aux_x_random", default=False, action="store_true")
105 |
106 | orig_config.add_argument("--no_classes_loss", default=False, action="store_true")
107 |
108 | # Admin
109 |
110 | orig_config.add_argument("--cuda", default=False, action="store_true")
111 |
112 | orig_config.add_argument("--eval_freq", type=int, required=True)
113 |
114 | orig_config.add_argument("--store_results_freq", type=int, required=True)
115 |
116 | orig_config.add_argument("--store_model_freq", type=int, required=True)
117 |
118 | orig_config.add_argument("--specific_torch_seed", default=False, action="store_true")
119 |
120 | orig_config.add_argument("--torch_seed", type=int)
121 |
122 | orig_config.add_argument("--render_aux_x", default=False, action="store_true")
123 |
124 | orig_config.add_argument("--render_aux_x_freq", type=int, default=50)
125 |
126 | orig_config.add_argument("--render_aux_x_num", type=int, default=3)
127 |
128 | orig_config = orig_config.parse_args()
129 |
130 |
131 | def main(config):
132 | # ------------------------------------------------------------------------------------------------
133 | # Setup
134 | # ------------------------------------------------------------------------------------------------
135 |
136 | reproc_settings(config)
137 |
138 | config.out_dir = osp.join(config.out_root, str(config.model_ind))
139 |
140 | tasks_model, trainloader, testloader, valloader = get_model_and_data(config)
141 |
142 | if not osp.exists(config.out_dir):
143 | os.makedirs(config.out_dir)
144 |
145 | next_t = 0
146 | last_classes = None
147 | seen_classes = None
148 |
149 | if config.long_window:
150 | old_tasks_model = None
151 | config.next_update_old_model_t_history = []
152 | config.next_update_old_model_t = 0 # old model needs to be set in first timestep
153 | assert (len(config.long_window_range) == 2)
154 |
155 | optimizer = optim.SGD(tasks_model.parameters(), lr=config.lr, momentum=0, dampening=0,
156 | weight_decay=0, nesterov=False)
157 |
158 | refine_sample_metrics = []
159 | refine_theta_metrics = []
160 |
161 | if not config.no_aux_distill: refine_theta_metrics.append("loss_aux_distill")
162 | if not config.no_classes_loss: refine_sample_metrics.append("classes_loss")
163 | if config.adaptive: refine_sample_metrics.append("adaptive_loss")
164 | if config.sharpen_class: refine_sample_metrics.append("sharpen_class_loss")
165 | if config.TV: refine_sample_metrics.append("TV_loss")
166 | if config.L2: refine_sample_metrics.append("L2_loss")
167 | if config.opt_batch_stats: refine_sample_metrics.append("opt_batch_stats_loss")
168 |
169 | refine_sample_metrics += ["loss_refine"]
170 | refine_theta_metrics += ["not_present_class", "loss_aux", "final_loss_aux"]
171 |
172 | t = next_t
173 |
174 | while t <= config.max_t:
175 | for xs, ys in trainloader:
176 | present_classes = ys.unique().to(get_device(config.cuda))
177 |
178 | # --------------------------------------------------------------------------------------------
179 | # Eval
180 | # --------------------------------------------------------------------------------------------
181 |
182 | if (t - next_t) < 1000 or t % 100 == 0:
183 | print("m %d t: %d %s, targets %s" % (
184 | config.model_ind, t, datetime.now(), str(list(present_classes.cpu().numpy()))))
185 | sys.stdout.flush()
186 |
187 | save_dict = {"tasks_model": tasks_model, "t": t, "last_classes": last_classes,
188 | "seen_classes": seen_classes}
189 | if config.long_window: # else no need to save because it's always the one just before
190 | # current update
191 | save_dict["old_tasks_model"] = old_tasks_model
192 |
193 | last_step = t == (config.max_t)
194 | if (t % config.eval_freq == 0) or (t % config.batches_per_epoch == 0) or last_step or (
195 | t == 0):
196 | evaluate_basic(config, tasks_model, valloader, t, is_val=True,
197 | last_classes=last_classes, seen_classes=seen_classes)
198 | evaluate_basic(config, tasks_model, testloader, t, is_val=False,
199 | last_classes=last_classes, seen_classes=seen_classes)
200 |
201 | if (t % config.store_model_freq == 0) or last_step:
202 | torch.save(save_dict, osp.join(config.out_dir, "latest_models.pytorch"))
203 |
204 | if (t % config.store_results_freq == 0) or last_step:
205 | render_graphs(config)
206 | store(config)
207 |
208 | if last_step:
209 | return
210 |
211 | # --------------------------------------------------------------------------------------------
212 | # Train
213 | # --------------------------------------------------------------------------------------------
214 |
215 | tasks_model.train()
216 |
217 | xs = xs.to(get_device(config.cuda))
218 | ys = ys.to(get_device(config.cuda))
219 |
220 | curr_classes = ys.unique()
221 | assert (curr_classes.max() < config.task_out_dims[0])
222 |
223 | optimizer.zero_grad()
224 |
225 | # set old_tasks_model if needed
226 |       if not config.long_window or t == config.next_update_old_model_t:
227 | old_tasks_model = deepcopy(tasks_model)
228 | if config.opt_batch_stats: old_tasks_model.compute_batch_stats_loss()
229 |
230 | if config.long_window:
231 | config.next_update_old_model_t_history.append(config.next_update_old_model_t)
232 | if config.use_fixed_window:
233 | window_offset = config.fixed_window
234 | else:
235 |             window_offset = np.random.randint(config.long_window_range[0], high=(
236 |                 config.long_window_range[1] + 1))  # high is exclusive in np.random.randint
237 |
238 | config.next_update_old_model_t = t + window_offset
239 |
240 | # --------------------------------------------------------------------------------------------
241 | # Train on real data
242 | # --------------------------------------------------------------------------------------------
243 |
244 | preds = tasks_model(xs)
245 | loss_orig = F.cross_entropy(preds, ys, reduction="mean")
246 | record_and_check(config, "loss_orig", loss_orig.item(), t)
247 | loss_orig.backward()
248 | optimizer.step()
249 |
250 | # --------------------------------------------------------------------------------------------
251 | # Generate data
252 | # --------------------------------------------------------------------------------------------
253 |
254 | if t >= config.recall_from_t:
255 | optimizer.zero_grad()
256 |
257 | metrics = dict([(metric, 0.) for metric in refine_sample_metrics + refine_theta_metrics])
258 |
259 | # pick classes for classes_loss
260 | num_classes = int(np.prod(config.task_out_dims))
261 | if not config.choose_past_classes:
262 | classes_to_refine = torch.tensor(
263 | np.random.choice(num_classes, config.M, replace=(config.M > num_classes)),
264 | dtype=torch.long, device=get_device(config.cuda))
265 | else:
266 |           # Explicitly pick seen classes, excluding present classes. There will be at least 2
267 |           # because recall starts from the 2nd task
268 | seen_classes_excl_pres = seen_classes.clone()
269 | num_seen = seen_classes.shape[0]
270 | for c in present_classes:
271 | seen_classes_excl_pres = seen_classes_excl_pres[seen_classes_excl_pres != c]
272 | num_seen_excl_pres = seen_classes_excl_pres.shape[0]
273 |
274 | assert (num_seen_excl_pres <= num_seen)
275 | chosen_inds = np.random.choice(num_seen_excl_pres, config.M,
276 | replace=(config.M > num_seen_excl_pres))
277 | classes_to_refine = seen_classes_excl_pres[chosen_inds]
278 |
279 | for r in range(config.refine_theta_steps):
280 | if config.refine_sample_from_scratch or r == 0:
281 | if not config.aux_x_random:
282 | aux_x = xs[np.random.choice(xs.shape[0], config.M, replace=(config.M > xs.shape[0]))]
283 | else:
284 | aux_x = torch.rand((config.M,) + xs.shape[1:]).to(get_device(config.cuda))
285 |
286 | aux_x_orig = aux_x.clone()
287 | aux_x.requires_grad_(True)
288 |
289 | for s in range(config.refine_sample_steps):
290 | aux_preds_old = old_tasks_model(aux_x)
291 | aux_preds_new = tasks_model(aux_x)
292 |
293 | if config.opt_batch_stats:
294 | aux_preds_old, opt_batch_stats_loss = aux_preds_old
295 |
296 | loss_refine = torch.tensor(0.).to(get_device(config.cuda))
297 | if not config.no_classes_loss:
298 | classes_loss = deep_inversion_classes_loss(classes_to_refine, aux_preds_old)
299 | loss_refine += config.classes_loss_weight * classes_loss
300 |
301 | if config.adaptive:
302 | adaptive_loss = neg_symmetric_KL(aux_preds_old, aux_preds_new)
303 | loss_refine += config.adaptive_weight * adaptive_loss
304 |
305 | if config.sharpen_class:
306 | sharpen_class_loss = sharpen_class(aux_preds_old)
307 | loss_refine += config.sharpen_class_weight * sharpen_class_loss
308 |
309 | if config.TV:
310 | TV_loss = TV(aux_x)
311 | loss_refine += config.TV_weight * TV_loss
312 |
313 | if config.L2:
314 | L2_loss = L2(aux_x)
315 | loss_refine += config.L2_weight * L2_loss
316 |
317 | if config.opt_batch_stats:
318 | loss_refine += config.opt_batch_stats_weight * opt_batch_stats_loss
319 |
320 | for metric in refine_sample_metrics:
321 | metrics[metric] += locals()[metric].item()
322 |
323 | aux_x_grads = \
324 | torch.autograd.grad(loss_refine, aux_x, only_inputs=True, retain_graph=False)[0]
325 | aux_x = (aux_x - config.refine_sample_lr * aux_x_grads).detach().requires_grad_(True)
326 |
327 | aux_x.requires_grad_(False)
328 | with torch.no_grad():
329 | aux_y = old_tasks_model(aux_x)
330 | distill_xs_targets = old_tasks_model(xs)
331 |
332 | if config.opt_batch_stats:
333 | aux_y, _ = aux_y
334 | distill_xs_targets, _ = distill_xs_targets
335 |
336 | if config.render_aux_x and t % config.render_aux_x_freq == 0:
337 | render_aux_x(config, t, r, aux_x_orig, aux_x, aux_y, present_classes)
338 |
339 | aux_y_hard = aux_y.argmax(dim=1)
340 |
341 | if not hasattr(config, "aux_y_hard"):
342 | config.aux_y_hard = OrderedDict()
343 | if t not in config.aux_y_hard:
344 | config.aux_y_hard[t] = aux_y_hard.cpu()
345 | else:
346 | config.aux_y_hard[t] = torch.cat((config.aux_y_hard[t], aux_y_hard.cpu())).unique()
347 |
348 | if not hasattr(config, "aux_y_probs"):
349 | config.aux_y_probs = OrderedDict()
350 | aux_y_probs = F.softmax(aux_y, dim=1).mean(dim=0)
351 | if t not in config.aux_y_probs:
352 | config.aux_y_probs[t] = aux_y_probs.cpu() * (1. / config.refine_theta_steps) # n c
353 | else:
354 | config.aux_y_probs[t] += (aux_y_probs.cpu() * (1. / config.refine_theta_steps))
355 |
356 | is_present_class = 0
357 | for c in present_classes:
358 | is_present_class += (aux_y_hard == c).sum().item()
359 | metrics["not_present_class"] += aux_y_hard.shape[0] - is_present_class
360 |
361 | # ----------------------------------------------------------------------------------------
362 | # Train on generated data
363 | # ----------------------------------------------------------------------------------------
364 |
365 | preds = tasks_model(aux_x)
366 | if not config.hard_targets:
367 | loss_aux = F.kl_div(F.log_softmax(preds, dim=1),
368 | F.softmax(aux_y, dim=1), reduction="batchmean")
369 | else:
370 | loss_aux = F.cross_entropy(preds, aux_y_hard, reduction="mean")
371 |
372 | final_loss_aux = loss_aux * config.aux_weight
373 |
374 | if not config.no_aux_distill:
375 | loss_aux_distill = F.kl_div(F.log_softmax(tasks_model(xs), dim=1),
376 | F.softmax(distill_xs_targets, dim=1), reduction="batchmean")
377 | final_loss_aux += config.aux_distill_weight * loss_aux_distill
378 | metrics["loss_aux_distill"] += loss_aux_distill.item()
379 |
380 | metrics["loss_aux"] += loss_aux.item()
381 | metrics["final_loss_aux"] += final_loss_aux.item()
382 |
383 | final_loss_aux.backward()
384 | optimizer.step()
385 |
386 | for metric in refine_sample_metrics:
387 | metrics[metric] /= float(config.refine_sample_steps * config.refine_theta_steps)
388 | record_and_check(config, metric, metrics[metric], t)
389 |
390 | for metric in refine_theta_metrics:
391 | metrics[metric] /= float(config.refine_theta_steps)
392 | record_and_check(config, metric, metrics[metric], t)
393 |
394 | t += 1
395 | if seen_classes is None:
396 | seen_classes = present_classes
397 | else:
398 | seen_classes = torch.cat((seen_classes, present_classes)).unique()
399 | last_classes = present_classes
400 |
401 |
402 | if __name__ == "__main__":
403 | ms = range(orig_config.model_ind_start, orig_config.model_ind_start + orig_config.num_runs)
404 | for m in ms:
405 | c = deepcopy(orig_config)
406 | c.model_ind = m
407 | main(c)
408 | print("Done m %d" % m)
409 |
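
Note on the `--long_window` options: the training loop above re-copies `old_tasks_model` on every step unless `--long_window` is set, in which case the copy is refreshed only when `t` reaches `next_update_old_model_t`. The sketch below is illustrative only and is not part of the repository. It replays that schedule with the standard library; `snapshot_steps` and its parameters are hypothetical names, and `random.Random` stands in for the script's NumPy generator.

```python
import random


def snapshot_steps(num_steps, long_window, window_range=(1, 1), fixed_window=None, seed=0):
  """Return the training steps t at which the old model would be re-copied."""
  rng = random.Random(seed)
  next_update = 0  # the old model must be set on the first step
  steps = []
  for t in range(num_steps):
    if not long_window or t == next_update:
      steps.append(t)  # stands in for old_tasks_model = deepcopy(tasks_model)
      if long_window:
        if fixed_window is not None:
          offset = fixed_window
        else:
          # random.randint includes both ends, unlike np.random.randint,
          # whose upper bound is exclusive (hence the "+ 1" in the script)
          offset = rng.randint(window_range[0], window_range[1])
        next_update = t + offset
  return steps


print(snapshot_steps(5, long_window=False))                  # [0, 1, 2, 3, 4]
print(snapshot_steps(10, long_window=True, fixed_window=3))  # [0, 3, 6, 9]
```

With a random window, consecutive snapshot steps differ by a value drawn from the inclusive range, so `--long_window_range 2 5` would refresh the old model every 2 to 5 steps.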
--------------------------------------------------------------------------------
/code/scripts/ARM.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from copy import deepcopy
3 | from datetime import datetime
4 | import os
5 |
6 | import torch.optim as optim
7 |
8 | from code.models import *
9 | from code.util.data import *
10 | from code.util.eval import evaluate_basic
11 | from code.util.general import *
12 | from code.util.load import *
13 | from code.util.losses import *
14 | from code.util.render import render_aux_x
15 |
16 | # --------------------------------------------------------------------------------------------------
17 | # Settings
18 | # --------------------------------------------------------------------------------------------------
19 |
20 | orig_config = argparse.ArgumentParser(allow_abbrev=False)
21 |
22 | orig_config.add_argument("--model_ind_start", type=int, required=True)
23 |
24 | orig_config.add_argument("--num_runs", type=int, required=True)
25 |
26 | orig_config.add_argument("--out_root", type=str, required=True)
27 |
28 | # Data and model
29 |
30 | orig_config.add_argument("--data", type=str, required=True)
31 |
32 | orig_config.add_argument("--data_path", type=str, required=True)
33 |
34 | orig_config.add_argument("--stationary", default=False, action="store_true")
35 |
36 | orig_config.add_argument("--classes_per_task", type=int, required=True)
37 |
38 | orig_config.add_argument("--max_t", type=int, required=True)
39 |
40 | orig_config.add_argument("--tasks_train_batch_sz", type=int, default=10)
41 |
42 | orig_config.add_argument("--tasks_eval_batch_sz", type=int, default=10)
43 |
44 | orig_config.add_argument("--num_iterations", type=int, default=1)
45 |
46 | orig_config.add_argument("--task_model_type", type=str, required=True)
47 |
48 | # ARM
49 |
50 | orig_config.add_argument("--recall_from_t", type=int, required=True)
51 |
52 | orig_config.add_argument("--M", type=int, default=100)
53 |
54 | orig_config.add_argument("--lr", type=float, default=0.1) # eta_0
55 |
56 | orig_config.add_argument("--refine_sample_lr", type=float, default=0.1) # eta_1
57 |
58 | orig_config.add_argument("--refine_sample_steps", type=int, default=10) # S
59 |
60 | orig_config.add_argument("--aux_weight", type=float, default=1.0) # lambda_0 (loss on recall)
61 |
62 | orig_config.add_argument("--aux_distill", default=False, action="store_true")
63 |
64 | orig_config.add_argument("--aux_distill_weight", type=float, default=1.0) # lambda_0 (loss on real)
65 |
66 | orig_config.add_argument("--divergence_loss_weight", type=float, default=1.0) # optional weight D
67 |
68 | orig_config.add_argument("--notlocal_weight", type=float, default=1.0) # lambda_1
69 |
70 | orig_config.add_argument("--notlocal_new_weight", type=float, default=1.0) # lambda_2
71 |
72 | orig_config.add_argument("--diversity_weight", type=float, default=1.0) # lambda_3
73 |
74 | orig_config.add_argument("--sharpen_class", default=False, action="store_true")
75 |
76 | orig_config.add_argument("--sharpen_class_weight", type=float, default=1.0) # lambda_4
77 |
78 | orig_config.add_argument("--L2", default=False, action="store_true")
79 |
80 | orig_config.add_argument("--L2_weight", type=float, default=1.0) # lambda_5
81 |
82 | orig_config.add_argument("--TV", default=False, action="store_true")
83 |
84 | orig_config.add_argument("--TV_weight", type=float, default=1.0) # lambda_6
85 |
86 | orig_config.add_argument("--long_window", default=False, action="store_true")
87 |
88 | orig_config.add_argument("--long_window_range", type=int, nargs="+", default=[1, 1])
89 |
90 | orig_config.add_argument("--use_fixed_window", default=False, action="store_true")
91 |
92 | orig_config.add_argument("--fixed_window", type=int, default=1) # when to update old theta
93 |
94 | # Usually not used, included for testing:
95 |
96 | orig_config.add_argument("--refine_theta_steps", type=int, default=1)
97 |
98 | orig_config.add_argument("--refine_sample_from_scratch", default=False, action="store_true")
99 |
100 | orig_config.add_argument("--hard_targets", default=False, action="store_true")
101 |
102 | orig_config.add_argument("--aux_x_random", default=False, action="store_true")
103 |
104 | orig_config.add_argument("--use_crossent_as_D", default=False, action="store_true")
105 |
106 | orig_config.add_argument("--opt_batch_stats", default=False, action="store_true")
107 |
108 | orig_config.add_argument("--opt_batch_stats_weight", type=float, default=1.0)
109 |
110 | # Admin
111 |
112 | orig_config.add_argument("--cuda", default=False, action="store_true")
113 |
114 | orig_config.add_argument("--eval_freq", type=int, required=True)
115 |
116 | orig_config.add_argument("--store_results_freq", type=int, required=True)
117 |
118 | orig_config.add_argument("--store_model_freq", type=int, required=True)
119 |
120 | orig_config.add_argument("--specific_torch_seed", default=False, action="store_true")
121 |
122 | orig_config.add_argument("--torch_seed", type=int)
123 |
124 | orig_config.add_argument("--render_aux_x", default=False, action="store_true")
125 |
126 | orig_config.add_argument("--render_aux_x_freq", type=int, default=10)
127 |
128 | orig_config.add_argument("--render_aux_x_num", type=int, default=3)
129 |
130 | orig_config.add_argument("--render_separate", default=False, action="store_true")
131 |
132 | orig_config.add_argument("--count_params_only", default=False, action="store_true")
133 |
134 | orig_config = orig_config.parse_args()
135 |
136 |
137 | def main(config):
138 | # ------------------------------------------------------------------------------------------------
139 | # Setup
140 | # ------------------------------------------------------------------------------------------------
141 |
142 | reproc_settings(config)
143 |
144 | config.out_dir = osp.join(config.out_root, str(config.model_ind))
145 |
146 | tasks_model, trainloader, testloader, valloader = get_model_and_data(config)
147 |
148 | if config.count_params_only:
149 | sz = 0
150 | for n, p in tasks_model.named_parameters():
151 | curr_sz = np.prod(tuple(p.data.shape))
152 | print((n, p.data.shape, curr_sz))
153 | sz += curr_sz
154 | print("Num params: %d" % sz)
155 |     sys.exit(0)
156 |
157 | if not osp.exists(config.out_dir):
158 | os.makedirs(config.out_dir)
159 |
160 | next_t = 0
161 | last_classes = None # for info only
162 | seen_classes = None
163 |
164 | if config.long_window:
165 | old_tasks_model = None
166 | config.next_update_old_model_t_history = []
167 | config.next_update_old_model_t = 0 # Old model set in first timestep
168 | assert (len(config.long_window_range) == 2)
169 |
170 | optimizer = optim.SGD(tasks_model.parameters(), lr=config.lr, momentum=0, dampening=0,
171 | weight_decay=0, nesterov=False)
172 |
173 | refine_sample_metrics = ["divergence_loss", "notlocal_loss", "notlocal_new_loss",
174 | "diversity_loss", "loss_refine"]
175 | refine_theta_metrics = ["not_present_class", "loss_aux", "final_loss_aux"]
176 |
177 | if config.sharpen_class: refine_sample_metrics.append("sharpen_class_loss")
178 | if config.TV: refine_sample_metrics.append("TV_loss")
179 | if config.L2: refine_sample_metrics.append("L2_loss")
180 | if config.opt_batch_stats: refine_sample_metrics.append("opt_batch_stats_loss")
181 | if config.aux_distill: refine_theta_metrics.append("loss_aux_distill")
182 |
183 | t = next_t
184 |
185 | while t <= config.max_t:
186 | for xs, ys in trainloader:
187 | present_classes = ys.unique().to(get_device(config.cuda))
188 |
189 | # --------------------------------------------------------------------------------------------
190 | # Eval
191 | # --------------------------------------------------------------------------------------------
192 |
193 | if (t - next_t) < 1000 or t % 100 == 0:
194 | print("m %d t: %d %s, targets %s" % (
195 | config.model_ind, t, datetime.now(), str(list(present_classes.cpu().numpy()))))
196 | sys.stdout.flush()
197 |
198 | save_dict = {"tasks_model": tasks_model, "t": t, "last_classes": last_classes,
199 | "seen_classes": seen_classes}
200 | if config.long_window: # else no need to save because it's always the one just before
201 | # current update
202 | save_dict["old_tasks_model"] = old_tasks_model
203 |
204 | last_step = t == (config.max_t)
205 | if (t % config.eval_freq == 0) or (t % config.batches_per_epoch == 0) or last_step or (
206 | t == 0):
207 | evaluate_basic(config, tasks_model, valloader, t, is_val=True,
208 | last_classes=last_classes, seen_classes=seen_classes)
209 | evaluate_basic(config, tasks_model, testloader, t, is_val=False,
210 | last_classes=last_classes, seen_classes=seen_classes)
211 |
212 | if (t % config.store_model_freq == 0) or last_step:
213 | torch.save(save_dict, osp.join(config.out_dir, "latest_models.pytorch"))
214 |
215 | if (t % config.store_results_freq == 0) or last_step:
216 | render_graphs(config)
217 | store(config)
218 |
219 | if last_step:
220 | return
221 |
222 | # --------------------------------------------------------------------------------------------
223 | # Train
224 | # --------------------------------------------------------------------------------------------
225 |
226 | tasks_model.train()
227 |
228 | xs = xs.to(get_device(config.cuda))
229 | ys = ys.to(get_device(config.cuda))
230 |
231 | curr_classes = ys.unique()
232 |
233 | optimizer.zero_grad()
234 |
235 | # Update old_tasks_model if needed
236 | if not config.long_window or (config.long_window and t == config.next_update_old_model_t):
237 | old_tasks_model = deepcopy(tasks_model)
238 | if config.opt_batch_stats: old_tasks_model.compute_batch_stats_loss()
239 |
240 | if config.long_window:
241 | config.next_update_old_model_t_history.append(config.next_update_old_model_t)
242 | if config.use_fixed_window:
243 | window_offset = config.fixed_window
244 | else:
245 | window_offset = np.random.randint(config.long_window_range[0], high=(
246 | config.long_window_range[1] + 1)) # randint is excl
247 |
248 | config.next_update_old_model_t = t + window_offset
249 |
250 | # --------------------------------------------------------------------------------------------
251 | # Train on real data
252 | # --------------------------------------------------------------------------------------------
253 |
254 | preds = tasks_model(xs)
255 | loss_orig = F.cross_entropy(preds, ys, reduction="mean")
256 | record_and_check(config, "loss_orig", loss_orig.item(), t)
257 | loss_orig.backward()
258 | optimizer.step()
259 |
260 | # --------------------------------------------------------------------------------------------
261 | # Generate recall
262 | # --------------------------------------------------------------------------------------------
263 |
264 | if t >= config.recall_from_t:
265 | optimizer.zero_grad()
266 |
267 | metrics = dict([(metric, 0.) for metric in refine_sample_metrics + refine_theta_metrics])
268 |
269 | for r in range(config.refine_theta_steps): # each step updates tasks_model
270 | if config.refine_sample_from_scratch or r == 0: # fresh aux_x
271 | if not config.aux_x_random:
272 | chosen_aux_inds = np.random.choice(xs.shape[0], config.M,
273 | replace=(config.M > xs.shape[0]))
274 | aux_x = xs[chosen_aux_inds]
275 | else:
276 | aux_x = torch.rand((config.M,) + xs.shape[1:]).to(get_device(config.cuda))
277 |
278 | aux_x_orig = aux_x.clone()
279 | aux_x.requires_grad_(True)
280 |
281 | for s in range(config.refine_sample_steps):
282 | aux_preds_old = old_tasks_model(aux_x)
283 | aux_preds_new = tasks_model(aux_x)
284 |
285 | if config.opt_batch_stats:
286 | aux_preds_old, opt_batch_stats_loss = aux_preds_old
287 |
288 | if config.use_crossent_as_D:
289 | divergence_loss = - crossent_logits(preds=aux_preds_new, targets=aux_preds_old)
290 | else:
291 | divergence_loss = neg_symmetric_KL(aux_preds_old, aux_preds_new)
292 |
293 | notlocal_loss = notlocal(aux_preds_old, present_classes)
294 | notlocal_new_loss = notlocal(aux_preds_new, present_classes)
295 |
296 | diversity_loss = diversity(aux_preds_old)
297 |
298 | loss_refine = config.divergence_loss_weight * divergence_loss + \
299 | config.diversity_weight * diversity_loss + \
300 | config.notlocal_weight * notlocal_loss + config.notlocal_new_weight * \
301 | notlocal_new_loss
302 |
303 | if config.sharpen_class:
304 | sharpen_class_loss = sharpen_class(aux_preds_old)
305 | loss_refine += config.sharpen_class_weight * sharpen_class_loss
306 |
307 | if config.TV:
308 | TV_loss = TV(aux_x)
309 | loss_refine += config.TV_weight * TV_loss
310 |
311 | if config.L2:
312 | L2_loss = L2(aux_x)
313 | loss_refine += config.L2_weight * L2_loss
314 |
315 | if config.opt_batch_stats:
316 | loss_refine += config.opt_batch_stats_weight * opt_batch_stats_loss
317 |
318 | for metric in refine_sample_metrics:
319 | metrics[metric] += locals()[metric].item()
320 |
321 | aux_x_grads = \
322 | torch.autograd.grad(loss_refine, aux_x, only_inputs=True, retain_graph=False)[0]
323 | aux_x = (aux_x - config.refine_sample_lr * aux_x_grads).detach().requires_grad_(True)
324 |
325 | # Get final predictions on recalled data from old model
326 | aux_x.requires_grad_(False)
327 | with torch.no_grad():
328 | aux_y = old_tasks_model(aux_x)
329 | distill_xs_targets = old_tasks_model(xs)
330 |
331 | if config.opt_batch_stats:
332 | aux_y, _ = aux_y
333 | distill_xs_targets, _ = distill_xs_targets
334 |
335 | if config.render_aux_x and t % config.render_aux_x_freq == 0:
336 | render_aux_x(config, t, r, aux_x_orig, aux_x, aux_y, present_classes)
337 |
338 | # count number whose top class isn't in current classes
339 | aux_y_hard = aux_y.argmax(dim=1)
340 |
341 | if not hasattr(config, "aux_y_hard"):
342 | config.aux_y_hard = OrderedDict()
343 | if t not in config.aux_y_hard:
344 | config.aux_y_hard[t] = aux_y_hard.cpu()
345 | else:
346 | config.aux_y_hard[t] = torch.cat((config.aux_y_hard[t], aux_y_hard.cpu())).unique()
347 |
348 | if not hasattr(config, "aux_y_probs"):
349 | config.aux_y_probs = OrderedDict()
350 | aux_y_probs = F.softmax(aux_y, dim=1).mean(dim=0)
351 | if t not in config.aux_y_probs:
352 | config.aux_y_probs[t] = aux_y_probs.cpu() * (1. / config.refine_theta_steps) # n c
353 | else:
354 | config.aux_y_probs[t] += (aux_y_probs.cpu() * (1. / config.refine_theta_steps))
355 |
356 | is_present_class = 0
357 | for c in present_classes:
358 | is_present_class += (aux_y_hard == c).sum().item()
359 | metrics["not_present_class"] += aux_y_hard.shape[0] - is_present_class
360 |
361 | # ----------------------------------------------------------------------------------------
362 | # Train on recalled data
363 | # ----------------------------------------------------------------------------------------
364 |
365 | preds = tasks_model(aux_x)
366 | if not config.hard_targets:
367 | loss_aux = crossent_logits(preds=preds, targets=aux_y)
368 | else:
369 | loss_aux = F.cross_entropy(preds, aux_y_hard, reduction="mean")
370 |
371 | final_loss_aux = loss_aux * config.aux_weight
372 |
373 | if config.aux_distill:
374 | loss_aux_distill = crossent_logits(tasks_model(xs), distill_xs_targets)
375 | #loss_aux_distill = F.kl_div(F.log_softmax(tasks_model(xs), dim=1),
376 | # F.softmax(distill_xs_targets, dim=1), reduction="batchmean")
377 | final_loss_aux += config.aux_distill_weight * loss_aux_distill
378 | metrics["loss_aux_distill"] += loss_aux_distill.item()
379 |
380 | metrics["loss_aux"] += loss_aux.item()
381 | metrics["final_loss_aux"] += final_loss_aux.item()
382 |
383 | final_loss_aux.backward()
384 | optimizer.step()
385 |
386 | for metric in refine_sample_metrics:
387 | metrics[metric] /= float(config.refine_sample_steps * config.refine_theta_steps)
388 | record_and_check(config, metric, metrics[metric], t)
389 |
390 | for metric in refine_theta_metrics:
391 | metrics[metric] /= float(config.refine_theta_steps)
392 | record_and_check(config, metric, metrics[metric], t)
393 |
394 | if t < config.recall_from_t and config.render_aux_x and t % config.render_aux_x_freq == 0:
395 | render_aux_x(config, t, 0, xs, None, None, present_classes)
396 |
397 | t += 1
398 | if seen_classes is None:
399 | seen_classes = present_classes
400 | else:
401 | seen_classes = torch.cat((seen_classes, present_classes)).unique()
402 | last_classes = present_classes
403 |
404 |
405 | if __name__ == "__main__":
406 |   ms = range(orig_config.model_ind_start, orig_config.model_ind_start + orig_config.num_runs)
407 |   for m in ms:
408 |     print("Doing m %d" % m)
409 |     c = deepcopy(orig_config)
410 |     c.model_ind = m
411 |     main(c)
412 |     print("Done m %d" % m)
413 |
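The `long_window` branch in the training loop above decides at which steps `old_tasks_model` is re-snapshotted. The following is a minimal, framework-free sketch of that schedule only. `snapshot_steps` is a hypothetical helper, not part of the repository, and it uses `random.randint` (inclusive on both ends) in place of `np.random.randint` with `high + 1`.

```python
import random


def snapshot_steps(max_t, long_window, use_fixed_window=False, fixed_window=1,
                   long_window_range=(1, 1), rng=None):
    """Return the training steps at which the old model would be re-copied.

    Hypothetical helper: it mirrors the window logic only, no model is involved.
    """
    rng = rng or random.Random(0)
    steps = []
    next_update = 0
    # Training happens for t = 0 .. max_t - 1; at t == max_t the loop returns
    # before the training part is reached.
    for t in range(max_t):
        if not long_window or t == next_update:
            steps.append(t)
            if long_window:
                if use_fixed_window:
                    offset = fixed_window
                else:
                    low, high = long_window_range
                    offset = rng.randint(low, high)  # inclusive range
                next_update = t + offset
    return steps


print(snapshot_steps(10, long_window=True, use_fixed_window=True, fixed_window=3))
```

Without `long_window` the snapshot is refreshed at every step, so the old model always lags the current one by exactly one update.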
--------------------------------------------------------------------------------
/code/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xu-ji/ARM/704f69d73765f13ca5e2d8aee11f399b53c635e6/code/scripts/__init__.py
--------------------------------------------------------------------------------
/code/scripts/distill.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from copy import deepcopy
4 | from datetime import datetime
5 |
6 | import torch.optim as optim
7 |
8 | from code.util.eval import evaluate_basic
9 | from code.util.general import *
10 | from code.util.load import *
11 | from code.util.losses import *
12 |
13 | # --------------------------------------------------------------------------------------------------
14 | # Settings
15 | # --------------------------------------------------------------------------------------------------
16 |
17 | orig_config = argparse.ArgumentParser(allow_abbrev=False)
18 |
19 | orig_config.add_argument("--model_ind_start", type=int, required=True)
20 |
21 | orig_config.add_argument("--num_runs", type=int, required=True)
22 |
23 | orig_config.add_argument("--out_root", type=str, required=True)
24 |
25 | # Data and model
26 |
27 | orig_config.add_argument("--data", type=str, default="cifar100")
28 |
29 | orig_config.add_argument("--data_path", type=str, required=True)
30 |
31 | orig_config.add_argument("--stationary", default=False, action="store_true")
32 |
33 | orig_config.add_argument("--classes_per_task", type=int, required=True)
34 |
35 | orig_config.add_argument("--max_t", type=int, required=True)
36 |
37 | orig_config.add_argument("--tasks_train_batch_sz", type=int, default=10)
38 |
39 | orig_config.add_argument("--tasks_eval_batch_sz", type=int, default=10)
40 |
41 | orig_config.add_argument("--num_iterations", type=int, default=1)
42 |
43 | orig_config.add_argument("--task_model_type", type=str, required=True)
44 |
45 | # Distill options
46 |
47 | orig_config.add_argument("--lr", type=float, default=0.1)
48 |
49 | orig_config.add_argument("--recall_from_t", type=int, required=True)
50 |
51 | orig_config.add_argument("--long_window", default=False, action="store_true")
52 |
53 | orig_config.add_argument("--long_window_range", type=int, nargs="+", default=[1, 1]) # inclusive
54 |
55 | orig_config.add_argument("--use_fixed_window", default=False, action="store_true")
56 |
57 | orig_config.add_argument("--fixed_window", type=int, default=1)
58 |
59 | orig_config.add_argument("--aux_distill_weight", type=float, default=1.0)
60 |
61 | # Admin
62 |
63 | orig_config.add_argument("--cuda", default=False, action="store_true")
64 |
65 | orig_config.add_argument("--eval_freq", type=int, required=True)
66 |
67 | orig_config.add_argument("--store_results_freq", type=int, required=True)
68 |
69 | orig_config.add_argument("--store_model_freq", type=int, required=True)
70 |
71 | orig_config.add_argument("--specific_torch_seed", default=False, action="store_true")
72 |
73 | orig_config.add_argument("--torch_seed", type=int)
74 |
75 | orig_config = orig_config.parse_args()
76 |
77 |
78 | def main(config):
79 |   # ------------------------------------------------------------------------------------------------
80 |   # Setup
81 |   # ------------------------------------------------------------------------------------------------
82 |
83 |   reproc_settings(config)
84 |
85 |   config.out_dir = osp.join(config.out_root, str(config.model_ind))
86 |
87 |   tasks_model, trainloader, testloader, valloader = get_model_and_data(config)
88 |
89 |   if not osp.exists(config.out_dir):
90 |     os.makedirs(config.out_dir)
91 |
92 |   next_t = 0
93 |   last_classes = None
94 |   seen_classes = None
95 |
96 |   if config.long_window:
97 |     old_tasks_model = None
98 |     config.next_update_old_model_t_history = []
99 |     config.next_update_old_model_t = 0
100 |     assert (len(config.long_window_range) == 2)
101 |
102 |   optimizer = optim.SGD(tasks_model.parameters(), lr=config.lr, momentum=0, dampening=0,
103 |                         weight_decay=0, nesterov=False)
104 |
105 |   refine_theta_metrics = ["final_loss_aux", "loss_aux_distill"]
106 |
107 |   t = next_t
108 |
109 |   while t <= config.max_t:
110 |     for xs, ys in trainloader:
111 |       present_classes = ys.unique().to(get_device(config.cuda))
112 |
113 |       # --------------------------------------------------------------------------------------------
114 |       # Eval
115 |       # --------------------------------------------------------------------------------------------
116 |
117 |       if (t - next_t) < 1000 or t % 100 == 0:
118 |         print("m %d t: %d %s, targets %s" % (
119 |           config.model_ind, t, datetime.now(), str(list(present_classes.cpu().numpy()))))
120 |         sys.stdout.flush()
121 |
122 |       save_dict = {"tasks_model": tasks_model, "t": t, "last_classes": last_classes,
123 |                    "seen_classes": seen_classes}
124 |       if config.long_window:  # else no need to save because it's always the one just before
125 |         # current update
126 |         save_dict["old_tasks_model"] = old_tasks_model
127 |
128 |       last_step = t == (config.max_t)
129 |       if (t % config.eval_freq == 0) or (t % config.batches_per_epoch == 0) or last_step or (
130 |           t == 0):
131 |         evaluate_basic(config, tasks_model, valloader, t, is_val=True,
132 |                        last_classes=last_classes, seen_classes=seen_classes)
133 |         evaluate_basic(config, tasks_model, testloader, t, is_val=False,
134 |                        last_classes=last_classes, seen_classes=seen_classes)
135 |
136 |       if (t % config.store_model_freq == 0) or last_step:
137 |         torch.save(save_dict, osp.join(config.out_dir, "latest_models.pytorch"))
138 |
139 |       if (t % config.store_results_freq == 0) or last_step:
140 |         render_graphs(config)
141 |         store(config)
142 |
143 |       if last_step:
144 |         return
145 |
146 |       # --------------------------------------------------------------------------------------------
147 |       # Train
148 |       # --------------------------------------------------------------------------------------------
149 |
150 |       tasks_model.train()
151 |
152 |       xs = xs.to(get_device(config.cuda))
153 |       ys = ys.to(get_device(config.cuda))
154 |
155 |       curr_classes = ys.unique()
156 |       assert (curr_classes.max() < config.task_out_dims[0])
157 |
158 |       optimizer.zero_grad()
159 |
160 |       # set old_tasks_model if needed
161 |       if not config.long_window or (config.long_window and t == config.next_update_old_model_t):
162 |         old_tasks_model = deepcopy(tasks_model)
163 |
164 |         if config.long_window:
165 |           config.next_update_old_model_t_history.append(config.next_update_old_model_t)
166 |           if config.use_fixed_window:
167 |             window_offset = config.fixed_window
168 |           else:
169 |             window_offset = np.random.randint(config.long_window_range[0],
170 |                                               high=(config.long_window_range[1] + 1))
171 |
172 |           config.next_update_old_model_t = t + window_offset
173 |
174 |       # --------------------------------------------------------------------------------------------
175 |       # Train on real data
176 |       # --------------------------------------------------------------------------------------------
177 |
178 |       preds = tasks_model(xs)
179 |       loss_orig = F.cross_entropy(preds, ys, reduction="mean")
180 |       record_and_check(config, "loss_orig", loss_orig.item(), t)  # added!
181 |       loss_orig.backward()
182 |       optimizer.step()  # updates tasks_model, which is now \theta'
183 |
184 |       # --------------------------------------------------------------------------------------------
185 |       # Distill
186 |       # --------------------------------------------------------------------------------------------
187 |       if t >= config.recall_from_t:
188 |         optimizer.zero_grad()
189 |         metrics = dict([(metric, 0.) for metric in refine_theta_metrics])
190 |
191 |         with torch.no_grad():
192 |           distill_xs_targets = old_tasks_model(xs)
193 |
194 |         loss_aux_distill = F.kl_div(F.log_softmax(tasks_model(xs), dim=1),
195 |                                     F.softmax(distill_xs_targets, dim=1), reduction="batchmean")
196 |         final_loss_aux = config.aux_distill_weight * loss_aux_distill
197 |
198 |         metrics["loss_aux_distill"] += loss_aux_distill.item()
199 |         metrics["final_loss_aux"] += final_loss_aux.item()
200 |
201 |         final_loss_aux.backward()
202 |         optimizer.step()  # updates tasks_model
203 |
204 |         for metric in refine_theta_metrics:
205 |           record_and_check(config, metric, metrics[metric], t)
206 |
207 |       t += 1
208 |       if seen_classes is None:
209 |         seen_classes = present_classes
210 |       else:
211 |         seen_classes = torch.cat((seen_classes, present_classes)).unique()
212 |       last_classes = present_classes
213 |
214 |
215 | if __name__ == "__main__":
216 |   ms = range(orig_config.model_ind_start, orig_config.model_ind_start + orig_config.num_runs)
217 |   for m in ms:
218 |     print("Doing m %d" % m)
219 |     c = deepcopy(orig_config)
220 |     c.model_ind = m
221 |     main(c)
222 |     print("Done m %d" % m)
223 |
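The distillation step in `distill.py` calls `F.kl_div(log_softmax(new), softmax(old), reduction="batchmean")`, which sums `p_old * (log p_old - log p_new)` over all entries and divides by the batch size. A minimal pure-Python sketch of the same quantity follows, for checking values by hand. `log_softmax` and `distill_kl` here are illustrative helpers operating on nested lists, not the repository's or PyTorch's functions.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def distill_kl(new_logits, old_logits):
    """Batch-mean KL(softmax(old) || softmax(new)) over rows of logits."""
    total = 0.0
    for new_row, old_row in zip(new_logits, old_logits):
        log_q = log_softmax(new_row)  # current model
        log_p = log_softmax(old_row)  # old model, treated as the fixed target
        total += sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    return total / len(new_logits)


# Identical logits give zero divergence; the loss grows as the models disagree.
print(distill_kl([[1.0, 2.0, 3.0]], [[1.0, 2.0, 3.0]]))
print(distill_kl([[0.0, 0.0]], [[math.log(3.0), 0.0]]))
```

Because the target distribution comes from the old model under `torch.no_grad()`, only the current model receives gradients from this term.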
--------------------------------------------------------------------------------
/code/scripts/naive.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | from copy import deepcopy
5 | from datetime import datetime
6 |
7 | import torch.optim as optim
8 |
9 | from code.models import *
10 | from code.util.eval import evaluate_basic
11 | from code.util.general import *
12 | from code.util.load import *
13 | # --------------------------------------------------------------------------------------------------
14 | # Settings
15 | # --------------------------------------------------------------------------------------------------
16 |
17 | orig_config = argparse.ArgumentParser(allow_abbrev=False)
18 |
19 | orig_config.add_argument("--model_ind_start", type=int, required=True)
20 |
21 | orig_config.add_argument("--num_runs", type=int, required=True)
22 |
23 | orig_config.add_argument("--out_root", type=str, required=True)
24 |
25 | # Data and model
26 |
27 | orig_config.add_argument("--data", type=str, default="cifar100")
28 |
29 | orig_config.add_argument("--data_path", type=str, required=True)
30 |
31 | orig_config.add_argument("--stationary", default=False, action="store_true")
32 |
33 | orig_config.add_argument("--classes_per_task", type=int, required=True)
34 |
35 | orig_config.add_argument("--max_t", type=int, required=True)
36 |
37 | orig_config.add_argument("--tasks_train_batch_sz", type=int, default=10)
38 |
39 | orig_config.add_argument("--tasks_eval_batch_sz", type=int, default=10)
40 |
41 | orig_config.add_argument("--num_iterations", type=int, default=1)
42 |
43 | orig_config.add_argument("--task_model_type", type=str, required=True)
44 |
45 | orig_config.add_argument("--lr", type=float, default=0.1)
46 |
47 | # Admin
48 |
49 | orig_config.add_argument("--cuda", default=False, action="store_true")
50 |
51 | orig_config.add_argument("--eval_freq", type=int, required=True)
52 |
53 | orig_config.add_argument("--store_results_freq", type=int, required=True)
54 |
55 | orig_config.add_argument("--store_model_freq", type=int, required=True)
56 |
57 | orig_config.add_argument("--specific_torch_seed", default=False, action="store_true")
58 |
59 | orig_config.add_argument("--torch_seed", type=int)
60 |
61 | orig_config = orig_config.parse_args()
62 |
63 |
64 | def main(config):
65 |   # ------------------------------------------------------------------------------------------------
66 |   # Setup
67 |   # ------------------------------------------------------------------------------------------------
68 |
69 |   reproc_settings(config)
70 |
71 |   config.out_dir = osp.join(config.out_root, str(config.model_ind))
72 |
73 |   tasks_model, trainloader, testloader, valloader = get_model_and_data(config)
74 |
75 |   if not osp.exists(config.out_dir):
76 |     os.makedirs(config.out_dir)
77 |
78 |   next_t = 0
79 |   last_classes = None
80 |   seen_classes = None
81 |
82 |   optimizer = optim.SGD(tasks_model.parameters(), lr=config.lr, momentum=0, dampening=0,
83 |                         weight_decay=0, nesterov=False)  # basic
84 |
85 |   t = next_t
86 |
87 |   while t <= config.max_t:
88 |     for xs, ys in trainloader:
89 |       present_classes = ys.unique().to(get_device(config.cuda))
90 |
91 |       # --------------------------------------------------------------------------------------------
92 |       # Eval
93 |       # --------------------------------------------------------------------------------------------
94 |
95 |       if (t - next_t) < 1000 or t % 100 == 0:
96 |         print("m %d t %d %s, fst targets %s" % (
97 |           config.model_ind, t, datetime.now(), str(list(present_classes.cpu().numpy()))))
98 |         sys.stdout.flush()
99 |
100 |       save_dict = {"tasks_model": tasks_model, "t": t, "last_classes": last_classes,
101 |                    "seen_classes": seen_classes}
102 |
103 |       last_step = t == (config.max_t)
104 |       if (t % config.eval_freq == 0) or (t % config.batches_per_epoch == 0) or last_step or (
105 |           t == 0):
106 |         evaluate_basic(config, tasks_model, valloader, t, is_val=True, last_classes=last_classes,
107 |                        seen_classes=seen_classes)
108 |         evaluate_basic(config, tasks_model, testloader, t, is_val=False, last_classes=last_classes,
109 |                        seen_classes=seen_classes)
110 |
111 |       if (t % config.store_model_freq == 0) or last_step:
112 |         torch.save(save_dict, osp.join(config.out_dir, "latest_models.pytorch"))
113 |
114 |       if (t % config.store_results_freq == 0) or last_step:
115 |         render_graphs(config)
116 |         store(config)
117 |
118 |       if last_step:
119 |         return
120 |
121 |       # --------------------------------------------------------------------------------------------
122 |       # Train
123 |       # --------------------------------------------------------------------------------------------
124 |
125 |       tasks_model.train()
126 |
127 |       optimizer.zero_grad()
128 |
129 |       xs = xs.to(get_device(config.cuda))
130 |       ys = ys.to(get_device(config.cuda))
131 |
132 |       preds = tasks_model(xs)
133 |       loss_orig = torch.nn.functional.cross_entropy(preds, ys, reduction="mean")
134 |
135 |       loss_orig.backward()
136 |
137 |       optimizer.step()
138 |
139 |       record_and_check(config, "loss_orig", loss_orig.item(), t)
140 |
141 |       t += 1
142 |       if seen_classes is None:
143 |         seen_classes = present_classes
144 |       else:
145 |         seen_classes = torch.cat((seen_classes, present_classes)).unique()
146 |       last_classes = present_classes
147 |
148 |
149 | if __name__ == "__main__":
150 |   ms = range(orig_config.model_ind_start, orig_config.model_ind_start + orig_config.num_runs)
151 |   for m in ms:
152 |     print("Doing m %d" % m)
153 |     c = deepcopy(orig_config)
154 |     c.model_ind = m
155 |     main(c)
156 |     print("Done m %d" % m)
157 |
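All three training scripts share the same bookkeeping at the top of each step: evaluate, checkpoint, store results, and stop on the last step. The sketch below isolates those conditions as a pure function so the schedule can be inspected without a model. `scheduled_actions` is a hypothetical helper, and the frequencies in the usage line are made-up values.

```python
def scheduled_actions(t, max_t, eval_freq, batches_per_epoch,
                      store_model_freq, store_results_freq):
    """Return which bookkeeping actions fire at step t (hypothetical helper)."""
    last_step = (t == max_t)
    return {
        # The scripts also test `t == 0`; the modulo checks already cover it.
        "evaluate": t % eval_freq == 0 or t % batches_per_epoch == 0 or last_step,
        "save_model": t % store_model_freq == 0 or last_step,
        "store_results": t % store_results_freq == 0 or last_step,
        # On the last step the scripts return before training on the batch.
        "stop": last_step,
    }


for step in (0, 7, 50, 99):
    print(step, scheduled_actions(step, max_t=99, eval_freq=20, batches_per_epoch=50,
                                  store_model_freq=50, store_results_freq=25))
```

Note that `get_data` asserts `store_model_freq == batches_per_epoch`, so in practice the checkpoint and the per-epoch evaluation fall on the same steps.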
--------------------------------------------------------------------------------
/code/scripts/print_results.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import torch
5 | import numpy as np
6 |
7 | args = argparse.ArgumentParser(allow_abbrev=False)
8 | args.add_argument("--root", type=str, required=True)
9 | args.add_argument("--start", type=int, required=True)
10 | args.add_argument("--load_model", default=False, action="store_true")
11 | args.add_argument("--num_runs", type=int, default=5)
12 | args = args.parse_args()
13 |
14 | def treat_underscores(x):
15 |   res = []
16 |   for c in x:
17 |     if c == "_":
18 |       res.append("\\_")
19 |     else:
20 |       res.append(c)
21 |
22 |   return "".join(res)
23 |
24 |
25 | def print_results(args):
26 |   ms_avg = {"val": {"acc": [], "forgetting": []},
27 |             "test": {"acc": [], "forgetting": []}}
28 |
29 |   for m in range(args.start, args.start + args.num_runs):
30 |     out_dir = os.path.join(args.root, str(m))
31 |     config_p = os.path.join(out_dir, "config.pickle")
32 |
33 |     config = None
34 |     tries = 0
35 |     while tries < 1000:
36 |       try:
37 |         with open(config_p, "rb") as config_f:
38 |           config = pickle.load(config_f)
39 |         break
40 |       except Exception:  # retry on any load error, but let KeyboardInterrupt through
41 |         tries += 1
42 |
43 |     if config is None:
44 |       continue
45 |
46 |     if args.load_model:
47 |       torch.load(os.path.join(config.out_dir, "latest_models.pytorch"))
48 |
49 |     actual_t = config.max_t
50 |     for prefix in ["val", "test"]:
51 |       if not config.stationary:
52 |         accs_dict = getattr(config, "%s_accs" % prefix)
53 |
54 |         ms_avg[prefix]["acc"].append(accs_dict[actual_t])
55 |
56 |         forgetting_dict = getattr(config, "%s_forgetting" % prefix)
57 |         if actual_t in forgetting_dict:
58 |           ms_avg[prefix]["forgetting"].append(forgetting_dict[actual_t])
59 |         forgetting = forgetting_dict[actual_t] if actual_t in forgetting_dict else float("nan")
60 |         print("model %d, %s: acc %.4f, forgetting %.4f" % (
61 |           config.model_ind, prefix, accs_dict[actual_t], forgetting))
62 |       else:
63 |         accs_dict = getattr(config, "%s_accs_data" % prefix)
64 |         ms_avg[prefix]["acc"].append(accs_dict[actual_t])
65 |         print("model %d, %s: acc %.4f" % (config.model_ind, prefix, accs_dict[actual_t]))
66 |
67 |     print("---")
68 |
69 |   for prefix in ["val", "test"]:
70 |     for metric in ["acc", "forgetting"]:
71 |       if len(ms_avg[prefix][metric]) == 0:
72 |         ms_avg[prefix][metric] = (-1, -1)
73 |       else:
74 |         avg = np.array(ms_avg[prefix][metric]).mean()
75 |         std = np.array(ms_avg[prefix][metric]).std()
76 |         ms_avg[prefix][metric] = (avg, std)
77 |
78 |     print("average %s: acc %.4f +- %.4f, forgetting %.4f +- %.4f" % (
79 |       prefix, ms_avg[prefix]["acc"][0], ms_avg[prefix]["acc"][1],
80 |       ms_avg[prefix]["forgetting"][0], ms_avg[prefix]["forgetting"][1]))
81 |
82 |
83 | if __name__ == "__main__":
84 |   print_results(args)
85 |
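`print_results.py` reduces each per-run metric list to a `(mean, std)` pair and uses `(-1, -1)` as a sentinel when no run contributed a value. Since `np.std` defaults to the population standard deviation (`ddof=0`), an equivalent standard-library sketch uses `statistics.pstdev`. `summarize` is an illustrative name, not a function from the repository.

```python
from statistics import fmean, pstdev


def summarize(values):
    """Return (mean, population std) of values, or (-1, -1) if there are none."""
    if not values:
        return (-1, -1)  # same sentinel the script prints for missing metrics
    return (fmean(values), pstdev(values))


print("acc %.4f +- %.4f" % summarize([0.5, 0.7]))
print("forgetting %.4f +- %.4f" % summarize([]))
```

With only a handful of runs per experiment, the population and sample standard deviations differ noticeably, so it is worth stating which one a reported `+-` refers to.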
--------------------------------------------------------------------------------
/code/scripts/print_table.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import argparse
4 | import pickle
5 | import numpy as np
6 |
7 | args = argparse.ArgumentParser(allow_abbrev=False)
8 | args.add_argument("--root", type=str, required=True)
9 | args.add_argument("--num_runs", type=int, default=5)
10 | args = args.parse_args()
11 |
12 | experiments = [
13 | ("ARM MNIST", 6579),
14 | ("ARM Cifar10", 3717),
15 | ("ARM MiniImageNet", 4821),
16 |
17 | ("ADI MNIST", 6232),
18 | ("ADI Cifar10", 2262),
19 | ("ADI MiniImageNet", 6042),
20 |
21 | ("Distill MNIST", 6102),
22 | ("Distill Cifar10", 4082),
23 | ("Distill MiniImageNet", 6092),
24 |
25 | ("Naive MNIST", 4967),
26 | ("Naive Cifar10", 2522),
27 | ("Naive MiniImageNet", 4557),
28 |
29 | ("Naive stationary MNIST", 4947),
30 | ("Naive stationary Cifar10", 6462),
31 | ("Naive stationary MiniImageNet", 4527),
32 |
33 | # table 4
34 | ("Distill Cifar10, unit lag", 6452),
35 | ("ADI Cifar10, unit lag", 6467),
36 | ("ADI Cifar10, no distill", 2327),
37 | ("ARM Cifar10, unit lag", 6602),
38 | ("ARM Cifar10, no distill", 2717),
39 |
40 | # table 7
41 |   (r"$\lambda_1 = 0, \lambda_2 = 0$", 3982),
42 |   (r"$\lambda_3 = 0$", 6562),
43 |   (r"$\lambda_4 = 0$", 3977),
44 |   (r"$\lambda_5 = 0$", 3967),
45 |   (r"$\lambda_6 = 0$", 3972),
46 | ("$M = 150$ (+50)", 6624),
47 | ("$M = 50$ (-50)", 5922),
48 | ("$S = 20$ (doubled)", 3957),
49 | ("$S = 5$ (halved)", 3962),
50 | ("Cross-entropy as D", 3617),
51 | ("Random noise init", 4077),
52 | ("Recall 2x per t", 6502),
53 | ("Recall 4x per t", 6507),
54 | ]
55 |
56 | num_runs = 5
57 |
58 | print("LaTeX table:")
59 | print("\\begin{table}[h]")
60 | print("\\centering")
61 | print("\\fontsize{7}{7}\\selectfont")
62 | print("\\begin{tabular}{l c c c c}")
63 | print("\\toprule")
64 | print("& \\multicolumn{2}{c}{Val} & \\multicolumn{2}{c}{Test} \\\\")
65 | print("& Accuracy & Forgetting & Accuracy & Forgetting \\\\")
66 | print("\\midrule")
67 | for name, m_start in experiments:
68 |   ms_avg = {"val": {"acc": [], "forgetting": []},
69 |             "test": {"acc": [], "forgetting": []}}
70 |
71 |   counts = 0
72 |   for m in range(m_start, m_start + args.num_runs):
73 |     out_dir = os.path.join(args.root, str(m))
74 |     config_p = os.path.join(out_dir, "config.pickle")
75 |
76 |     config = None
77 |     tries = 0
78 |     while tries < 1000:
79 |       try:
80 |         with open(config_p, "rb") as config_f:
81 |           config = pickle.load(config_f)
82 |         break
83 |       except Exception:  # retry on any load error, but let KeyboardInterrupt through
84 |         tries += 1
85 |
86 |     if config is None:
87 |       continue
88 |
89 |     actual_t = config.max_t
90 |
91 |     if actual_t not in config.test_accs:
92 |       continue
93 |
94 |     for prefix in ["val", "test"]:
95 |       if not config.stationary:
96 |         accs_dict = getattr(config, "%s_accs" % prefix)
97 |
98 |         ms_avg[prefix]["acc"].append(accs_dict[actual_t])
99 |
100 |         forgetting_dict = getattr(config, "%s_forgetting" % prefix)
101 |         if actual_t in forgetting_dict:
102 |           ms_avg[prefix]["forgetting"].append(forgetting_dict[actual_t])
103 |       else:
104 |         accs_dict = getattr(config, "%s_accs_data" % prefix)
105 |         ms_avg[prefix]["acc"].append(accs_dict[actual_t])
106 |
107 |     counts += 1
108 |
109 |   for prefix in ["val", "test"]:
110 |     for metric in ["acc", "forgetting"]:
111 |       if len(ms_avg[prefix][metric]) == 0:
112 |         ms_avg[prefix][metric] = (-1, -1)
113 |       else:
114 |         avg = np.array(ms_avg[prefix][metric]).mean()
115 |         std = np.array(ms_avg[prefix][metric]).std()
116 |         ms_avg[prefix][metric] = (avg, std)
117 |
118 |   print("%s (%d) & %.4f $\\pm$ %.4f & %.4f $\\pm$ %.4f & %.4f $\\pm$ %.4f & %.4f $\\pm$ %.4f \\\\" %
119 |         (name, counts,
120 |
121 |          ms_avg["val"]["acc"][0], ms_avg["val"]["acc"][1],
122 |          ms_avg["val"]["forgetting"][0], ms_avg["val"]["forgetting"][1],
123 |
124 |          ms_avg["test"]["acc"][0], ms_avg["test"]["acc"][1],
125 |          ms_avg["test"]["forgetting"][0], ms_avg["test"]["forgetting"][1],
126 |          ))
127 |
128 | print("\\bottomrule")
129 | print("\\end{tabular}")
130 | print("\\end{table}")
131 |
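The table script builds LaTeX rows with `%` formatting, where sequences such as `\p` and `\l` in ordinary string literals are invalid escapes that newer Python versions warn about. A small sketch of the same row layout using raw strings follows. `latex_row` is a hypothetical helper and the numbers in the usage line are placeholders, not results.

```python
def latex_row(name, count, stats):
    """Format one table row from (mean, std) pairs (hypothetical helper).

    Raw strings keep the LaTeX backslashes literal without doubling them.
    """
    cells = " & ".join(r"%.4f $\pm$ %.4f" % tuple(pair) for pair in stats)
    return "%s (%d) & %s %s" % (name, count, cells, r"\\")


# Order of pairs: val accuracy, val forgetting, test accuracy, test forgetting.
print(latex_row("Example", 5, [(0.5, 0.01), (0.1, 0.02), (0.4, 0.03), (0.2, 0.04)]))
```

Keeping the formatting in one function also makes it easy to check a row's output without loading any pickled configs.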
--------------------------------------------------------------------------------
/code/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xu-ji/ARM/704f69d73765f13ca5e2d8aee11f399b53c635e6/code/util/__init__.py
--------------------------------------------------------------------------------
/code/util/data.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, SequentialSampler
2 | from torchvision import transforms
3 |
4 | from code.data import *
5 |
6 |
7 | def get_data(config):
8 |   dataloaders, datasets = globals()["get_%s_loaders" % config.data](config)
9 |   num_train_samples = len(datasets[0])
10 |   assert (num_train_samples % config.tasks_train_batch_sz) == 0
11 |   config.batches_per_epoch = int(num_train_samples / config.tasks_train_batch_sz)
12 |
13 |   print("length of training d, test d, val d: %d %d %d, batches_per_epoch %d" %
14 |         (num_train_samples, len(datasets[1]), len(datasets[2]), config.batches_per_epoch))
15 |   assert (config.store_model_freq == config.batches_per_epoch)  # once per epoch
16 |
17 |   return dataloaders
18 |
19 |
20 | def get_cifar10_loaders(config):
21 | assert (config.data == "cifar10")
22 | #config.task_in_dims = (3, 32, 32)
23 | #config.task_out_dims = (10,)
24 | two_classes_per_block = (config.classes_per_task == 2)
25 |
26 | train_fns = [transforms.ToTensor()]
27 | transform_train = transforms.Compose(train_fns)
28 |
29 | cifar10_training = cifar10(root=config.data_path, train=True, transform=transform_train,
30 | non_stat=(not config.stationary),
31 | two_classes_per_block=two_classes_per_block,
32 | num_iterations=config.num_iterations)
33 |
34 | if not config.stationary:
35 | cifar10_training_loader = DataLoader(cifar10_training,
36 | sampler=SequentialSampler(cifar10_training), shuffle=False,
37 | batch_size=config.tasks_train_batch_sz)
38 | else:
39 | cifar10_training_loader = DataLoader(cifar10_training, shuffle=True,
40 | batch_size=config.tasks_train_batch_sz)
41 |
42 | test_fns = [transforms.ToTensor()]
43 | transform_test = transforms.Compose(test_fns)
44 |
45 | evals_non_stat = False # does not make a difference, test time behavior independent of rest of
46 | # batch
47 |
48 | cifar10_val = cifar10val(root=config.data_path, transform=transform_test,
49 | non_stat=evals_non_stat, two_classes_per_block=two_classes_per_block)
50 | cifar10_val_loader = DataLoader(cifar10_val, shuffle=False, batch_size=config.tasks_eval_batch_sz)
51 |
52 | cifar10_test = cifar10(root=config.data_path, train=False, transform=transform_test,
53 | num_iterations=1,
54 | non_stat=evals_non_stat, two_classes_per_block=two_classes_per_block)
55 | cifar10_test_loader = DataLoader(cifar10_test, shuffle=False,
56 | batch_size=config.tasks_eval_batch_sz)
57 |
58 | return (cifar10_training_loader, cifar10_test_loader, cifar10_val_loader), \
59 | (cifar10_training, cifar10_test, cifar10_val)
60 |
61 |
62 | def get_miniimagenet_loaders(config):
63 | assert (config.data == "miniimagenet")
64 | #config.task_in_dims = (3, 84, 84)
65 | #config.task_out_dims = (100,)
66 |
67 | train_fns = [
68 | transforms.Resize(84),
69 | transforms.CenterCrop(84),
70 | transforms.ToTensor(),
71 | ]
72 |
73 | transform_train = transforms.Compose(train_fns)
74 |
75 | miniimagenet_training = miniimagenet(root=config.data_path, data_type="train",
76 | transform=transform_train,
77 | non_stat=(not config.stationary),
78 | classes_per_task=config.classes_per_task,
79 | num_iterations=config.num_iterations)
80 |
81 | if not config.stationary:
82 | miniimagenet_training_loader = DataLoader(miniimagenet_training,
83 | sampler=SequentialSampler(miniimagenet_training),
84 | shuffle=False, batch_size=config.tasks_train_batch_sz)
85 | else:
86 | miniimagenet_training_loader = DataLoader(miniimagenet_training, shuffle=True,
87 | batch_size=config.tasks_train_batch_sz)
88 |
89 | test_fns = [
90 | transforms.Resize(84),
91 | transforms.CenterCrop(84),
92 | transforms.ToTensor(),
93 | ]
94 |
95 | transform_test = transforms.Compose(test_fns)
96 |
97 | evals_non_stat = False
98 |
99 | miniimagenet_val = miniimagenetval(root=config.data_path, transform=transform_test,
100 | non_stat=evals_non_stat,
101 | classes_per_task=config.classes_per_task)
102 | miniimagenet_val_loader = DataLoader(miniimagenet_val, shuffle=False,
103 | batch_size=config.tasks_eval_batch_sz)
104 |
105 | miniimagenet_test = miniimagenet(root=config.data_path, data_type="test",
106 | transform=transform_test, num_iterations=1,
107 | non_stat=evals_non_stat,
108 | classes_per_task=config.classes_per_task)
109 | miniimagenet_test_loader = DataLoader(miniimagenet_test, shuffle=False,
110 | batch_size=config.tasks_eval_batch_sz)
111 |
112 | return (miniimagenet_training_loader, miniimagenet_test_loader, miniimagenet_val_loader), \
113 | (miniimagenet_training, miniimagenet_test, miniimagenet_val)
114 |
115 |
116 | def get_mnist5k_loaders(config):
117 | assert (config.data == "mnist5k")
118 | #config.task_in_dims = (28 * 28,)
119 | #config.task_out_dims = (10,)
120 |
121 | mnist5k_training = mnist5k(root=config.data_path, data_type="train",
122 | non_stat=(not config.stationary), num_iterations=config.num_iterations,
123 | classes_per_task=config.classes_per_task)
124 |
125 | if not config.stationary:
126 | mnist5k_training_loader = DataLoader(mnist5k_training,
127 | sampler=SequentialSampler(mnist5k_training),
128 | shuffle=False, batch_size=config.tasks_train_batch_sz)
129 | else:
130 | mnist5k_training_loader = DataLoader(mnist5k_training, shuffle=True,
131 | batch_size=config.tasks_train_batch_sz)
132 |
133 | evals_non_stat = False
134 |
135 | mnist5k_val = mnist5k(root=config.data_path, data_type="val", non_stat=evals_non_stat,
136 | num_iterations=1, classes_per_task=config.classes_per_task)
137 | mnist5k_val_loader = DataLoader(mnist5k_val, shuffle=False, batch_size=config.tasks_eval_batch_sz)
138 |
139 | mnist5k_test = mnist5k(root=config.data_path, data_type="test", non_stat=evals_non_stat,
140 | num_iterations=1, classes_per_task=config.classes_per_task)
141 | mnist5k_test_loader = DataLoader(mnist5k_test, shuffle=False,
142 | batch_size=config.tasks_eval_batch_sz)
143 |
144 | return (mnist5k_training_loader, mnist5k_test_loader, mnist5k_val_loader), \
145 | (mnist5k_training, mnist5k_test, mnist5k_val)
146 |
--------------------------------------------------------------------------------
/code/util/eval.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict, defaultdict
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from code.util.general import get_device
7 |
8 |
9 | # accs_data: average over all data (old metric)
10 | # per_label_accs: acc per class
11 | # per_task_accs: acc per task
12 | # accs: average over all seen tasks (after the last task, same as the Chaudhry definition)
13 | # forgetting: average over all seen tasks
14 |
15 | def evaluate_basic(config, tasks_model, data_loader, t, is_val, last_classes=None,
16 | seen_classes=None, tag=""):
17 | if is_val:
18 | prefix = "val"
19 | else:
20 | prefix = "test"
21 |
22 | prefix = "%s%s" % (tag, prefix)
23 |
24 | tasks_model.eval()
25 |
26 | acc_data = 0.
27 | counts = 0
28 |
29 | num_out = int(np.prod(config.task_out_dims))
30 | per_label_acc = np.zeros(num_out)
31 | per_label_counts = np.zeros(num_out)
32 |
33 | for x, y in data_loader:
34 | x = x.to(get_device(config.cuda))
35 | y = y.to(get_device(config.cuda))
36 |
37 | with torch.no_grad():
38 | preds = tasks_model(x)
39 |
40 | preds_flat = torch.argmax(preds, dim=1)
41 |
42 | acc_data += (preds_flat == y).sum().item()
43 | counts += y.shape[0]
44 |
45 | for c in range(num_out):
46 | pos = (y == c)
47 | per_label_acc[c] += (pos * (preds_flat == c)).sum().item()
48 | per_label_counts[c] += pos.sum().item()
49 |
50 | # acc over all data
51 | acc_data /= counts
52 |
53 | # acc per class
54 | per_label_counts = np.maximum(per_label_counts, 1) # avoid div 0
55 | per_label_acc /= per_label_counts
56 |
57 | # acc per seen task and avg
58 | acc = None
59 | if hasattr(config, "%s_accs_data" % prefix) and (not config.stationary): # start from after training starts
60 | last_classes = last_classes.cpu().numpy()
61 | seen_classes = seen_classes.cpu().numpy()
62 |
63 | per_task_acc = defaultdict(list)
64 | for c in seen_classes: # seen tasks only
65 | per_task_acc[config.class_dict_tasks[c]].append(per_label_acc[c])
66 |
67 | acc = 0.
68 | for task_i in per_task_acc:
69 | assert (len(per_task_acc[task_i]) == config.classes_per_task)
70 | per_task_acc[task_i] = np.array(per_task_acc[task_i]).mean()
71 | acc += per_task_acc[task_i]
72 | acc /= len(per_task_acc)
73 |
74 | if not hasattr(config, "%s_accs" % prefix):
75 | setattr(config, "%s_accs_data" % prefix, OrderedDict())
76 | setattr(config, "%s_per_label_accs" % prefix, OrderedDict())
77 |
78 | setattr(config, "%s_per_task_accs" % prefix, OrderedDict())
79 | setattr(config, "%s_accs" % prefix, OrderedDict())
80 |
81 | setattr(config, "%s_forgetting" % prefix, OrderedDict())
82 |
83 | getattr(config, "%s_accs_data" % prefix)[t] = acc_data
84 | getattr(config, "%s_per_label_accs" % prefix)[t] = per_label_acc
85 | if acc is not None:
86 | getattr(config, "%s_per_task_accs" % prefix)[t] = per_task_acc
87 | getattr(config, "%s_accs" % prefix)[t] = acc
88 |
89 | # for all previous (excl latest) tasks, find the maximum drop to curr acc
90 | if not config.stationary:
91 | if len(getattr(config, "%s_accs_data" % prefix)) >= 3: # at least 1 previous (non pre training) eval
92 | assert (last_classes is not None)
93 | getattr(config, "%s_forgetting" % prefix)[t] = compute_forgetting(config, t,
94 | getattr(config,
95 | "%s_per_task_accs" % prefix),
96 | last_classes)
97 |
98 |
99 | def compute_forgetting(config, t, per_task_accs, last_classes):
100 | # per_task_accs holds a different number of tasks at each timestep, so it cannot be stored as an array
101 |
102 | assert (t % config.eval_freq == 0)
103 |
104 | # find task that just finished
105 | last_task_i = None
106 | for c in last_classes:
107 | task_i = config.class_dict_tasks[c]
108 | if last_task_i is None:
109 | last_task_i = task_i
110 | else:
111 | assert (last_task_i == task_i)
112 |
113 | forgetting_per_task = {}
114 | for task_i in range(last_task_i): # excl last (tasks are numbered chronologically)
115 | best_acc = None
116 | for past_t in per_task_accs:
117 | if past_t == 0: continue # not used
118 | if past_t == t: continue
119 |
120 | if task_i in per_task_accs[past_t]:
121 | if best_acc is None or per_task_accs[past_t][task_i] > best_acc:
122 | best_acc = per_task_accs[past_t][task_i]
123 | assert (best_acc is not None)
124 |
125 | forgetting_per_task[task_i] = best_acc - per_task_accs[t][task_i]
126 |
127 | assert (len(forgetting_per_task) == last_task_i)
128 | return np.array(list(forgetting_per_task.values())).mean()
129 |
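The forgetting metric computed by `compute_forgetting` can be illustrated with a small, dependency-free sketch. It is not part of the repository: `forgetting_at` and the accuracy values are hypothetical. Like the original, the sketch assumes that tasks are numbered chronologically and that the evaluation stored at t = 0 is ignored.

```python
def forgetting_at(per_task_accs, t, last_task_i):
    """Mean drop from each earlier task's best past accuracy to its accuracy at t."""
    drops = []
    for task_i in range(last_task_i):  # tasks before the one that just finished
        # Accuracies of this task at earlier evaluations (t = 0 and t itself excluded).
        past = [accs[task_i] for past_t, accs in per_task_accs.items()
                if past_t not in (0, t) and task_i in accs]
        drops.append(max(past) - per_task_accs[t][task_i])
    return sum(drops) / len(drops)


# Made-up accuracies: {eval step: {task index: accuracy}}
per_task_accs = {
    0: {},
    100: {0: 0.9},
    200: {0: 0.7, 1: 0.8},
    300: {0: 0.6, 1: 0.75, 2: 0.85},
}
result = forgetting_at(per_task_accs, t=300, last_task_i=2)
print(round(result, 3))
```

In this example task 0 falls from its best value of 0.9 to 0.6 and task 1 from 0.8 to 0.75, so the reported forgetting is the mean of 0.3 and 0.05.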
--------------------------------------------------------------------------------
/code/util/general.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pickle
3 |
4 | import matplotlib
5 | import numpy as np
6 | import torch
7 |
8 | matplotlib.use('Agg')
9 | import matplotlib.pyplot as plt
10 | from collections import OrderedDict
11 | from copy import deepcopy
12 |
13 |
14 | def get_device(cuda):
15 | if cuda:
16 | return torch.device("cuda:0")
17 | else:
18 | return torch.device("cpu")
19 |
20 |
21 | def store(config):
22 | # anonymization
23 | data_path = config.data_path
24 | out_root = config.out_root
25 | out_dir = config.out_dir
26 |
27 | config.data_path = ""
28 | config.out_root = ""
29 | config.out_dir = ""
30 |
31 | with open(osp.join(out_dir, "config.pickle"),
32 | 'wb') as outfile:
33 | pickle.dump(config, outfile)
34 |
35 | with open(osp.join(out_dir, "config.txt"),
36 | "w") as text_file:
37 | text_file.write("%s" % config)
38 |
39 | config.data_path = data_path
40 | config.out_root = out_root
41 | config.out_dir = out_dir
42 |
43 |
44 | def get_avg_grads(model):
45 | total = None
46 | count = 0
47 | for p in model.parameters():
48 | sz = np.prod(p.grad.shape)
49 | grad_sum = p.grad.abs().sum()
50 | if total is None:
51 | total = grad_sum
52 | count = sz
53 | else:
54 | total += grad_sum
55 | count += sz
56 |
57 | return total / float(count)
58 |
59 |
60 | def record_and_check(config, name, val, t):
61 | if hasattr(val, "item"):
62 | val = val.item()
63 |
64 | record(config, name, val, t)
65 | if not np.isfinite(val):
66 | print("value (probably loss) not finite, aborting:")
67 | print(name)
68 | print("t %d" % t)
69 | print(val)
70 | store(config) # to store nan values
71 | exit(1)
72 |
73 |
74 | def check(config, val, t):
75 | if not np.isfinite(val):
76 | print("value (probably loss) not finite, aborting:")
77 | print("t %d" % t)
78 | print(val)
79 | store(config) # to store nan values
80 | exit(1)
81 |
82 |
83 | def record(config, val_name, val, t, abs=False):
84 | if not hasattr(config, val_name):
85 | setattr(config, val_name, OrderedDict())
86 |
87 | storage = getattr(config, val_name)
88 |
89 | if isinstance(val, torch.Tensor):
90 | if abs:
91 | val = torch.abs(val)
92 |
93 | if val.dtype == torch.int64:
94 | assert (val.shape == torch.Size([]))
95 | else:
96 | val = torch.mean(val)
97 |
98 | storage[t] = val.item() # either scalar loss or vector of grads
99 | else:
100 | if abs:
101 | val = np.abs(val) # the builtin abs is shadowed by the boolean `abs` parameter, so use np.abs
102 |
103 | storage[t] = val
104 |
105 | if not hasattr(config, "record_names"):
106 | config.record_names = []
107 |
108 | if val_name not in config.record_names:
109 | config.record_names.append(val_name)
110 |
111 |
112 | def get_gpu_mem(nvsmi, gpu_ind):
113 | mem_stats = nvsmi.DeviceQuery('memory.free, memory.total')["gpu"][gpu_ind]["fb_memory_usage"]
114 | return mem_stats["total"] - mem_stats["free"]
115 |
116 |
117 | def get_task_in_dims(config):
118 | return config.task_in_dims
119 |
120 |
121 | def get_task_out_dims(config):
122 | return config.task_out_dims
123 |
124 |
125 | def render_graphs(config):
126 | if not hasattr(config, "record_names"):
127 | return
128 |
129 | training_val_names = config.record_names
130 | fig0, axarr0 = plt.subplots(max(len(training_val_names), 2), sharex=False,
131 | figsize=(8, len(training_val_names) * 4))
132 |
133 | for i, val_name in enumerate(training_val_names):
134 | if hasattr(config, val_name):
135 | storage = getattr(config, val_name)
136 | axarr0[i].clear()
137 | axarr0[i].plot(list(storage.keys()), list(storage.values())) # ordereddict
138 |
139 | axarr0[i].set_title(val_name)
140 |
141 | fig0.suptitle("Model %d" % (config.model_ind), fontsize=8)
142 | fig0.savefig(osp.join(config.out_dir, "plots_0.png"))
143 |
144 | if hasattr(config, "val_accs"):
145 | fig1, axarr1 = plt.subplots(4, sharex=False, figsize=(8, 4 * 4))
146 |
147 | for pi, prefix in enumerate(["val", "test"]):
148 | accs_name = "%s_accs" % prefix
149 | axarr1[pi * 2].clear()
150 | axarr1[pi * 2].plot(list(getattr(config, accs_name).keys()),
151 | list(getattr(config, accs_name).values())) # ordereddict
152 | axarr1[pi * 2].set_title(accs_name)
153 |
154 | per_label_accs_name = "%s_per_label_accs" % prefix
155 | axarr1[pi * 2 + 1].clear()
156 | per_label_accs_t = getattr(config, per_label_accs_name).keys()
157 | per_label_accs_np = np.array(list(getattr(config, per_label_accs_name).values()))
158 | for c in range(int(np.prod(get_task_out_dims(config)))):
159 | axarr1[pi * 2 + 1].plot(list(per_label_accs_t), list(per_label_accs_np[:, c]), label=str(c))
160 | axarr1[pi * 2 + 1].legend()
161 | axarr1[pi * 2 + 1].set_title(per_label_accs_name)
162 |
163 | fig1.suptitle("Model %d" % (config.model_ind), fontsize=8)
164 | fig1.savefig(osp.join(config.out_dir, "plots_1.png"))
165 |
166 | # render predictions, if exist
167 | if hasattr(config, "aux_y_probs"):
168 | # time along x axis, classes along y axis
169 | fig2, ax2 = plt.subplots(1, figsize=(16, 8)) # width, height
170 |
171 | num_t = len(config.aux_y_probs)
172 | num_classes = int(np.prod(get_task_out_dims(config)))
173 |
174 | aux_y_probs = list(config.aux_y_probs.values())
175 | aux_y_probs = [aux_y_prob.numpy() for aux_y_prob in aux_y_probs]
176 | aux_y_probs = np.array(aux_y_probs)
177 |
178 | # print(aux_y_probs.shape)
179 | assert (aux_y_probs.shape == (len(config.aux_y_probs), int(np.prod(get_task_out_dims(config)))))
180 |
181 | aux_y_probs = aux_y_probs.transpose() # now num classes, time
182 | min_val = aux_y_probs.min()
183 | max_val = aux_y_probs.max()
184 |
185 | # tile along y axis to make each class fatter. Should be same number of pixels altogether as
186 | # current t / 2
187 | scale = int(0.5 * float(num_t) / num_classes)
188 | if scale > 1:
189 | aux_y_probs = np.repeat(aux_y_probs, scale, axis=0)
190 | ax2.set_yticks(np.arange(num_classes) * scale)
191 | ax2.set_yticklabels(np.arange(num_classes))
192 |
193 | num_thousands = int(num_t / 1000)
194 | ax2.set_xticks(np.arange(num_thousands) * 1000)
195 | ax2.set_xticklabels(np.arange(num_thousands) * 1000 + list(config.aux_y_probs.keys())[0])
196 |
197 | im = ax2.imshow(aux_y_probs)
198 | fig2.colorbar(im, ax=ax2)
199 | # ax2.colorbar()
200 |
201 | fig2.suptitle("Model %d, max %f min %f" % (config.model_ind, max_val, min_val), fontsize=8)
202 | fig2.savefig(osp.join(config.out_dir, "plots_2.png"))
203 |
204 | plt.close("all")
205 |
206 |
207 | def trim_config(config, next_t):
208 | # trim everything down to next_t numbers
209 | # we are starting at top of loop *before* eval step
210 |
211 | for val_name in config.record_names:
212 | storage = getattr(config, val_name)
213 | if isinstance(storage, list):
214 | assert (len(storage) >= (next_t))
215 | setattr(config, val_name, storage[:next_t])
216 | else:
217 | assert (isinstance(storage, OrderedDict))
218 | storage_copy = deepcopy(storage)
219 | for k, v in storage.items():
220 | if k >= next_t:
221 | del storage_copy[k]
222 | setattr(config, val_name, storage_copy)
223 |
224 | for prefix in ["val", "test"]:
225 | accs_storage = getattr(config, "%s_accs" % prefix)
226 | per_label_accs_storage = getattr(config, "%s_per_label_accs" % prefix)
227 |
228 | if isinstance(accs_storage, list):
229 | assert (isinstance(per_label_accs_storage, list))
230 | assert (len(accs_storage) >= (next_t) and len(per_label_accs_storage) >= (
231 | next_t)) # at least next_t stored
232 |
233 | setattr(config, "%s_accs" % prefix, accs_storage[:next_t])
234 | setattr(config, "%s_per_label_accs" % prefix, per_label_accs_storage[:next_t])
235 | else:
236 | assert (
237 | isinstance(accs_storage, OrderedDict) and isinstance(per_label_accs_storage, OrderedDict))
238 | for dn, d in [("accs", accs_storage), ("per_label_accs", per_label_accs_storage)]:
239 | d_copy = deepcopy(d)
240 | for k, v in d.items():
241 | if k >= next_t:
242 | del d_copy[k]
243 | setattr(config, "%s_%s" % (prefix, dn), d_copy) # write back to e.g. config.val_accs
244 |
245 | # deal with window
246 | if config.long_window:
247 | # find index of first historical t for update >= next_t
248 | # set config.next_update_old_model_t = that t
249 | # trim history behind it, backing onto next_t
250 |
251 | next_t_i = None
252 | for i, update_t in enumerate(config.next_update_old_model_t_history):
253 | if update_t > next_t:
254 | next_t_i = i
255 | break
256 |
257 | # there must be a t in update history that is greater than next_t
258 | # unless config.next_update_old_model_t >= next_t and we stopped before it was added to history
259 | # in which case we don't need to trim any update history
260 |
261 | if next_t_i is None:
262 | print("no trimming:")
263 | print(("config.next_update_old_model_t", config.next_update_old_model_t))
264 | print(("next_t", next_t))
265 | assert (config.next_update_old_model_t >= next_t)
266 | else:
267 | config.next_update_old_model_t = config.next_update_old_model_t_history[next_t_i]
268 | config.next_update_old_model_t_history = config.next_update_old_model_t_history[:next_t_i]
269 |
270 |
271 | def sum_seq(seq):
272 | res = None
273 | for elem in seq:
274 | if res is None:
275 | res = elem
276 | else:
277 | res += elem
278 | return res
279 |
280 |
281 | def np_rand_seed(): # fixed classes shuffling
282 | return 111
283 |
284 |
285 | def reproc_settings(config):
286 | np.random.seed(0) # set separately when shuffling data too
287 | if config.specific_torch_seed:
288 | torch.manual_seed(config.torch_seed)
289 | else:
290 | torch.manual_seed(config.model_ind) # allow initialisations different per model
291 |
292 | torch.backends.cudnn.deterministic = True
293 | torch.backends.cudnn.benchmark = False
294 |
295 |
296 | def copy_parameter_values(from_model, to_model):
297 | to_params = list(to_model.named_parameters())
298 | assert (isinstance(to_params[0], tuple) and len(to_params[0]) == 2)
299 |
300 | to_params = dict(to_params)
301 |
302 | for n, p in from_model.named_parameters():
303 | to_params[n].data.copy_(p.data) # not clone
304 |
305 |
306 | def make_valid_from_train(dataset, cut):
307 | tr_ds, val_ds = [], []
308 | for task_ds in dataset:
309 | x_t, y_t = task_ds
310 |
311 | # shuffle before splitting
312 | perm = torch.randperm(len(x_t))
313 | x_t, y_t = x_t[perm], y_t[perm]
314 |
315 | split = int(len(x_t) * cut)
316 | x_tr, y_tr = x_t[:split], y_t[:split]
317 | x_val, y_val = x_t[split:], y_t[split:]
318 |
319 | tr_ds += [(x_tr, y_tr)]
320 | val_ds += [(x_val, y_val)]
321 |
322 | return tr_ds, val_ds
323 |
324 |
325 | def invert_dict(dict_to_invert):
326 | new_dict = {}
327 | for k, vs in dict_to_invert.items():
328 | for v in vs:
329 | new_dict[v] = k
330 | return new_dict
--------------------------------------------------------------------------------
/code/util/load.py:
--------------------------------------------------------------------------------
1 | from .data import get_data
2 | from .general import get_device, invert_dict
3 | from code.models import *
4 |
5 | def get_model_and_data(config):
6 | config.task_in_dims = {"mnist5k": (28 * 28,), "miniimagenet": (3, 84, 84), "cifar10": (3, 32, 32)}[config.data]
7 | config.task_out_dims = {"mnist5k": (10,), "miniimagenet": (100,), "cifar10": (10,)}[config.data]
8 |
9 | tasks_model = globals()[config.task_model_type](config).to(get_device(config.cuda))
10 |
11 | trainloader, testloader, valloader = get_data(config)
12 |
13 | config.class_dict_tasks = invert_dict(trainloader.dataset.task_dict_classes)
14 |
15 | return tasks_model, trainloader, testloader, valloader
16 |
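As a rough illustration of the `config.class_dict_tasks` mapping built in `get_model_and_data`, the sketch below applies the same inversion as `invert_dict` to a hypothetical `task_dict_classes`. The example mapping (five tasks of two classes each, in order) is an assumption for illustration only; the real mapping comes from the dataset object and its class order may differ.

```python
def invert_dict(dict_to_invert):
    # Same logic as code/util/general.py: {task: [classes]} -> {class: task}
    return {v: k for k, vs in dict_to_invert.items() for v in vs}


# Hypothetical task layout: task i holds classes 2i and 2i + 1.
task_dict_classes = {task_i: [2 * task_i, 2 * task_i + 1] for task_i in range(5)}
class_dict_tasks = invert_dict(task_dict_classes)

print(class_dict_tasks[7])  # task index that class 7 belongs to
```

The inverted dictionary is what `evaluate_basic` and `compute_forgetting` use to look up the task of a given class label.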
--------------------------------------------------------------------------------
/code/util/losses.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def TV(x):
8 | if len(x.shape) == 2: # mnist
9 | side = int(x.shape[1] ** 0.5)
10 | assert (side == 28) # sanity
11 | x = x.view(-1, 1, side, side)
12 |
13 | batch_sz = x.shape[0]
14 | h_x = x.shape[2]
15 | w_x = x.shape[3]
16 | count_h = batch_sample_size(
17 | x[:, :, 1:, :]) # num points in 1 image inc channels except 1 row missing
18 | count_w = batch_sample_size(x[:, :, :, 1:])
19 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
20 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
21 | return 2 * (h_tv / count_h + w_tv / count_w) / batch_sz
22 |
23 |
24 | def batch_sample_size(x):
25 | return x.shape[1] * x.shape[2] * x.shape[3]
26 |
27 |
28 | def L2(x):
29 | if len(x.shape) == 2: # mnist
30 | side = int(x.shape[1] ** 0.5)
31 | assert (side == 28) # sanity
32 | x = x.view(-1, 1, side, side)
33 |
34 | sz = batch_sample_size(x)
35 | x = x.view(x.shape[0], -1)
36 | res = torch.norm(x, p=2, dim=1).sum() # this is not size normalised!
37 | return res / (x.shape[0] * sz)
38 |
39 |
40 | def sharpen_class(aux_preds):
41 | # minimise cross entropy with its own argmax
42 | hard_targets = aux_preds.argmax(dim=1).detach()
43 | return F.cross_entropy(aux_preds, hard_targets, reduction="mean")
44 |
45 |
46 | def img_distance(aux_x, xs):
47 | aux_x_exp = aux_x.repeat(xs.shape[0], 1, 1, 1)
48 | xs_exp = xs.repeat_interleave(aux_x.shape[0], dim=0)
49 | assert (aux_x_exp.shape[1:] == xs_exp.shape[1:])
50 | assert (aux_x.shape[1:] == aux_x_exp.shape[1:])
51 |
52 | assert (aux_x_exp.shape[0] == aux_x.shape[0] * xs.shape[0])
53 | assert (xs_exp.shape[0] == aux_x.shape[0] * xs.shape[0])
54 |
55 | # maximise the distance
56 | return -F.mse_loss(aux_x_exp, xs_exp, reduction="mean")
57 |
58 |
59 | def diversity_raw(preds):
60 | # maximise the entropy, use avg of logs
61 | assert (len(preds.shape) == 2)
62 | cross_batch_preds = preds.mean(dim=0, keepdim=True)
63 | entrop = crossent_logits(cross_batch_preds, cross_batch_preds)
64 | return -entrop
65 |
66 |
67 | def diversity(preds, EPS=sys.float_info.epsilon):
68 | # maximise the entropy, use avg after softmax!!
69 | assert (len(preds.shape) == 2)
70 |
71 | preds = F.softmax(preds, dim=1)
72 | cross_batch_preds = preds.mean(dim=0)
73 | cross_batch_preds[(cross_batch_preds < EPS).detach()] = EPS
74 | loss = (cross_batch_preds * torch.log(cross_batch_preds)).sum() # minimise negative entropy
75 | return loss
76 |
77 |
78 | def notlocal(preds, present_classes):
79 | # by minimizing this, maximise distance with present classes
80 |
81 | num_samples, _ = preds.shape
82 | assert (len(present_classes.shape) == 1)
83 |
84 | cum_losses = preds.new_zeros(num_samples)
85 | for c in present_classes:
86 | c_targets = preds.new_full((num_samples,), c, dtype=torch.long)
87 | y_loss = - F.cross_entropy(preds, c_targets, reduction="none")
88 | cum_losses += y_loss
89 |
90 | loss = cum_losses.mean() / present_classes.shape[0]
91 | assert (loss.requires_grad)
92 |
93 | return loss
94 |
95 |
96 | def crossent_logits(preds, targets):
97 | # targets are not yet softmaxed probabilities
98 | assert (len(targets.shape) == 2)
99 | targets = F.softmax(targets, dim=1)
100 |
101 | num_samples, num_outputs = preds.shape
102 | cum_losses = preds.new_zeros(num_samples)
103 |
104 | for c in range(num_outputs):
105 | # grads should still bp back through pred and target if nec
106 | target_temp = preds.new_full((num_samples,), c, dtype=torch.long) # filled with c
107 | y_loss = F.cross_entropy(preds, target_temp, reduction="none")
108 | cum_losses += targets[:, c] * y_loss # weight each one by prob c given by actual target
109 |
110 | loss = cum_losses.mean()
111 | assert (loss.requires_grad)
112 | return loss
113 |
114 |
115 | def deep_inversion_classes_loss(classes_to_refine, aux_preds_old):
116 | # try to get characteristic images for these classes
117 | assert (classes_to_refine.shape == (aux_preds_old.shape[0],))
118 | # aux preds old not softmaxed
119 | return F.cross_entropy(aux_preds_old, classes_to_refine, reduction="mean")
120 |
121 |
122 | def neg_symmetric_KL(aux_preds_old, aux_preds_new):
123 | assert (len(aux_preds_old.shape) == 2)
124 | assert (len(aux_preds_new.shape) == 2)
125 | assert (aux_preds_new.shape == aux_preds_old.shape)
126 |
127 | # inputs logsoftmax, targets softmax
128 | avg_preds = 0.5 * (F.softmax(aux_preds_old, dim=1) + F.softmax(aux_preds_new, dim=1))
129 |
130 | return 1 - 0.5 * (
131 | F.kl_div(F.log_softmax(aux_preds_old, dim=1), avg_preds, reduction="batchmean") +
132 | F.kl_div(F.log_softmax(aux_preds_new, dim=1), avg_preds, reduction="batchmean"))
133 |
134 |
135 | def avoid_unseen_classes(aux_preds_old, seen_classes):
136 | # max KL div with each of seen classes
137 | num_samples, num_outputs = aux_preds_old.shape
138 |
139 | total = None
140 | for c in seen_classes:
141 | target_temp = aux_preds_old.new_full((num_samples,), c, dtype=torch.long) # filled with c
142 | y_loss = F.cross_entropy(aux_preds_old, target_temp, reduction="mean")
143 |
144 | if total is None:
145 | total = y_loss
146 | else:
147 | total += y_loss
148 |
149 | return - total / seen_classes.shape[0] # avg and negate, because encouraging separation
150 |
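To make the behaviour of `diversity` concrete, here is a pure-Python sketch (standard library only, not the torch implementation in this file) of the same quantity: the negative entropy of the batch-averaged softmax. The helper names `softmax` and `diversity_sketch` and the logits are illustrative assumptions.

```python
import math
import sys


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def diversity_sketch(batch_logits, eps=sys.float_info.epsilon):
    probs = [softmax(row) for row in batch_logits]
    num_classes = len(batch_logits[0])
    # Average the predicted distribution over the batch, floored at eps.
    mean_probs = [max(sum(p[c] for p in probs) / len(probs), eps)
                  for c in range(num_classes)]
    # Negative entropy: lowest when the batch covers all classes evenly.
    return sum(p * math.log(p) for p in mean_probs)


spread = diversity_sketch([[10.0, 0.0], [0.0, 10.0]])     # batch covers both classes
collapsed = diversity_sketch([[10.0, 0.0], [10.0, 0.0]])  # batch predicts one class
print(spread, collapsed)
```

Minimising this loss pushes the batch towards the `spread` case, whose value is close to -log(2) for two classes, and away from the `collapsed` case, whose value is close to zero.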
--------------------------------------------------------------------------------
/code/util/render.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib
4 | import numpy as np
5 | import torch
6 |
7 | matplotlib.use('Agg')
8 | import matplotlib.pyplot as plt
9 | from colorsys import hsv_to_rgb
10 |
11 |
12 | def render_aux_x(config, t, r, aux_x_orig, aux_x, aux_y, present_classes):
13 | if len(aux_x_orig.shape) == 2:
14 | assert (config.data == "mnist5k")
15 | side = int(aux_x_orig.shape[1] ** 0.5)
16 | aux_x_orig = aux_x_orig.view(-1, side, side)
17 | if aux_x is not None:
18 | aux_x = aux_x.view(-1, side, side)
19 |
20 | render_out_dir = os.path.join(config.out_dir, "render")
21 | if not os.path.exists(render_out_dir):
22 | os.makedirs(render_out_dir)
23 |
24 | if aux_x is not None:
25 | # render a random selection of render_aux_x_num samples
26 |
27 | if not config.render_separate:
28 | fig0, axarr0 = plt.subplots(2, max(config.render_aux_x_num, 2), sharex=False,
29 | figsize=(config.render_aux_x_num * 4, 2 * 4))
30 |
31 | num_aux_x = aux_x.shape[0]
32 | assert (num_aux_x == config.M)
33 | assert (config.render_aux_x_num <= num_aux_x)
34 | selected = np.random.choice(num_aux_x, config.render_aux_x_num, replace=False)
35 |
36 | s_aux_x = aux_x[selected]
37 | s_aux_y = aux_y[selected]
38 | s_aux_x_orig = aux_x_orig[selected]
39 | if config.render_aux_x_num == 1:
40 | s_aux_x = s_aux_x.unsqueeze(dim=0)
41 | s_aux_x_orig = s_aux_x_orig.unsqueeze(dim=0)
42 | s_aux_y = s_aux_y.unsqueeze(dim=0)
43 |
44 | s_aux_y = torch.nn.functional.softmax(s_aux_y, dim=1) # pre softmax
45 |
46 | diff_sum = (s_aux_x_orig - s_aux_x).abs().sum().item() / config.render_aux_x_num
47 |
48 | for i in range(config.render_aux_x_num):
49 | img_orig = s_aux_x_orig[i]
50 | img = s_aux_x[i]
51 |
52 | if len(s_aux_x_orig.shape) == 4:
53 | img_orig = img_orig.permute(1, 2, 0)
54 | img = img.permute(1, 2, 0) # h, w, 3
55 |
56 | orig_min, orig_max = img.min(), img.max()
57 | img = img - orig_min
58 | img = img / (orig_max - orig_min) # rescale to [0, 1]; assumes the image is not constant
59 |
60 | top_class = s_aux_y[i].argmax()
61 |
62 | if not config.render_separate:
63 | axarr0[0, i].imshow(img.cpu().numpy())
64 |
65 | vals = (
66 | top_class.item(), s_aux_y[i, top_class].item(), str(list(present_classes.cpu().numpy())),
67 | orig_min.item(), orig_max.item())
68 |
69 | axarr0[0, i].set_title("top c %d: val %f, pres %s, img min %f max %f" % vals, fontsize=6)
70 | axarr0[1, i].imshow(img_orig.cpu().numpy())
71 | else:
72 | fig_curr_orig, ax_curr_orig = plt.subplots(1, figsize=(4, 4))
73 | ax_curr_orig.imshow(img_orig.cpu().numpy())
74 | # ax_curr_orig.set_axis_off() # barebones
75 |
76 | ax_curr_orig.axis('off')
77 | fig_curr_orig.patch.set_visible(False)
78 | ax_curr_orig.patch.set_visible(False)
79 |
80 | fig_curr_orig.savefig(
81 | os.path.join(render_out_dir, "m_%d_t_%d_r_%d_i_%d_orig.png" % (config.model_ind, t, r, i)),
82 | bbox_inches=0)
83 |
84 | fig_curr_aux, ax_curr_aux = plt.subplots(1, figsize=(4, 4))
85 | ax_curr_aux.imshow(img.cpu().numpy())
86 | # ax_curr_aux.set_axis_off() # barebones
87 |
88 | ax_curr_aux.axis('off')
89 | fig_curr_aux.patch.set_visible(False)
90 | ax_curr_aux.patch.set_visible(False)
91 |
92 | fig_curr_aux.savefig(os.path.join(render_out_dir, "m_%d_t_%d_r_%d_i_%d_aux_%d_%.3f.png" %
93 | (config.model_ind, t, r, i, top_class.item(),
94 | s_aux_y[i, top_class].item())), bbox_inches=0)
95 |
96 | plt.close("all")
97 |
98 | if not config.render_separate:
99 | fig0.suptitle("Model %d t %d r %d, diffs %f" % (config.model_ind, t, r, diff_sum), fontsize=8)
100 | fig0.savefig(
101 | os.path.join(render_out_dir, "render_m_%d_t_%d_r_%d.png" % (config.model_ind, t, r)))
102 | plt.close("all")
103 | else:
104 | # just render original batch
105 | for i in range(config.render_aux_x_num):
106 | img_orig = aux_x_orig[i]
107 | if len(img_orig.shape) == 3:
108 | img_orig = img_orig.permute(1, 2, 0) # channels last
109 |
110 | fig_curr_orig, ax_curr_orig = plt.subplots(1, figsize=(4, 4))
111 | ax_curr_orig.imshow(img_orig.cpu().numpy())
112 | # ax_curr_orig.set_axis_off() # barebones
113 |
114 | ax_curr_orig.axis('off')
115 | fig_curr_orig.patch.set_visible(False)
116 | ax_curr_orig.patch.set_visible(False)
117 |
118 | fig_curr_orig.savefig(
119 | os.path.join(render_out_dir, "m_%d_t_%d_r_%d_i_%d_orig.png" % (config.model_ind, t, r, i)),
120 | bbox_inches=0)
121 |
122 |
123 | def get_colours(n):
124 | hues = np.linspace(0.0, 1.0, n + 1)[0:-1] # ignore last one
125 | all_colours = [np.array(hsv_to_rgb(hue, 0.75, 0.75)) for hue in hues]
126 | return all_colours
127 |
--------------------------------------------------------------------------------
/commands.txt:
--------------------------------------------------------------------------------
1 | # ------------------------------------------
2 | To print a results summary after running the commands below:
3 | # ------------------------------------------
4 | Experiments are named by arbitrary integers for easy reference.
5 |
6 | E.g. ARM CIFAR10
7 | python -m code.scripts.print_results --root /scratch/shared/nfs1/xuji/ARM --start 3717
8 |
9 | average val: acc 0.2586 +- 0.0145, forgetting 0.1046 +- 0.0330
10 | average test: acc 0.2687 +- 0.0107, forgetting 0.0959 +- 0.0371
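The `acc ... +- ...` lines above aggregate one number per run. A minimal sketch of such an aggregate follows, assuming it is the mean and standard deviation over the `--num_runs` runs. Whether `print_results` uses the population or the sample standard deviation is not shown here, and the per-run accuracies below are made up for illustration (they are not results from this repository).

```python
import statistics

# Hypothetical final accuracies of 5 runs (not real results).
run_accs = [0.25, 0.27, 0.26, 0.24, 0.28]

mean_acc = statistics.mean(run_accs)
std_acc = statistics.pstdev(run_accs)  # population std; this choice is an assumption
summary = "acc %.4f +- %.4f" % (mean_acc, std_acc)
print(summary)
```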
11 |
12 | Print all results:
13 | python -m code.scripts.print_table --root /scratch/shared/nfs1/xuji/ARM
14 |
15 |
16 | # ------------------------------------------
17 | Tables 1 - 3
18 | # ------------------------------------------
19 |
20 | ARM
21 | MNIST
22 | nohup python -m code.scripts.ARM --model_ind_start 6579 --num_runs 5 --data mnist5k --lr 0.05 --task_model_type mlp --classes_per_task 2 --recall_from_t 100 --num_iterations 1 --M 10 --refine_sample_steps 10 --refine_sample_lr 25.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 100 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 500 --store_model_freq 500 --store_results_freq 100 --eval_freq 100 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/ > out/m6579s.out &
23 |
24 | CIFAR10
25 | nohup python -m code.scripts.ARM --model_ind_start 3717 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3717s.out &
26 |
27 | MiniImageNet
28 | nohup python -m code.scripts.ARM --model_ind_start 4821 --num_runs 5 --data miniimagenet --lr 0.01 --task_model_type resnet18 --classes_per_task 5 --recall_from_t 684 --num_iterations 3 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 684 --sharpen_class --sharpen_class_weight 1.0 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 1.0 --aux_distill --aux_distill_weight 2.0 --max_t 13680 --store_model_freq 13680 --store_results_freq 684 --eval_freq 684 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/MINIIMAGENET/ > out/m4821.out &
29 |
30 | ADI
31 | MNIST
32 | nohup python -m code.scripts.ADI --model_ind_start 6232 --num_runs 5 --data mnist5k --lr 0.05 --task_model_type mlp --classes_per_task 2 --recall_from_t 100 --num_iterations 1 --M 10 --refine_theta_steps 1 --refine_sample_steps 10 --refine_sample_lr 25.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 100 --classes_loss_weight 1.0 --choose_past_classes --adaptive --adaptive_weight 1.0 --aux_distill_weight 0.5 --max_t 500 --store_model_freq 500 --store_results_freq 100 --eval_freq 100 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/ > out/m6232s.out &
33 |
34 | CIFAR10
35 | nohup python -m code.scripts.ADI --model_ind_start 2262 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18_batch_stats --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_theta_steps 1 --refine_sample_steps 10 --refine_sample_lr 10.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --classes_loss_weight 1.0 --opt_batch_stats --opt_batch_stats_weight 0.1 --choose_past_classes --adaptive --adaptive_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m2262s.out &
36 |
37 | MiniImageNet
38 | nohup python -m code.scripts.ADI --model_ind_start 6042 --num_runs 5 --data miniimagenet --lr 0.01 --task_model_type resnet18_batch_stats --classes_per_task 5 --recall_from_t 684 --num_iterations 3 --M 100 --refine_theta_steps 1 --refine_sample_steps 10 --refine_sample_lr 10.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 684 --classes_loss_weight 1.0 --opt_batch_stats --opt_batch_stats_weight 0.1 --choose_past_classes --adaptive --adaptive_weight 1.0 --aux_distill_weight 2.0 --max_t 13680 --store_model_freq 13680 --store_results_freq 684 --eval_freq 684 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/MINIIMAGENET/ > out/m6042s.out &
39 |
40 | Distill
41 | MNIST
42 | nohup python -m code.scripts.distill --model_ind_start 6102 --num_runs 5 --data mnist5k --lr 0.05 --task_model_type mlp --classes_per_task 2 --recall_from_t 100 --num_iterations 1 --long_window --use_fixed_window --fixed_window 100 --aux_distill_weight 1.0 --max_t 500 --store_model_freq 500 --store_results_freq 100 --eval_freq 100 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/ > out/m6102s.out &
43 |
44 | CIFAR10
45 | nohup python -m code.scripts.distill --model_ind_start 4082 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --long_window --use_fixed_window --fixed_window 950 --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m4082s.out &
46 |
47 | MiniImageNet
48 | nohup python -m code.scripts.distill --model_ind_start 6092 --num_runs 5 --data miniimagenet --lr 0.01 --task_model_type resnet18 --classes_per_task 5 --recall_from_t 684 --num_iterations 3 --long_window --use_fixed_window --fixed_window 684 --aux_distill_weight 2.0 --max_t 13680 --store_model_freq 13680 --store_results_freq 684 --eval_freq 684 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/MINIIMAGENET/ > out/m6092s.out &
49 |
50 | Naive
51 | MNIST
52 | nohup python -m code.scripts.naive --model_ind_start 4967 --num_runs 5 --data mnist5k --lr 0.05 --task_model_type mlp --classes_per_task 2 --num_iterations 1 --max_t 500 --store_model_freq 500 --store_results_freq 100 --eval_freq 100 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/ > out/m4967s.out &
53 |
54 | CIFAR10
55 | nohup python -m code.scripts.naive --model_ind_start 2522 --num_runs 5 --data cifar10 --lr 0.1 --task_model_type resnet18 --classes_per_task 2 --num_iterations 1 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m2522s.out &
56 |
57 | MiniImageNet
58 | nohup python -m code.scripts.naive --model_ind_start 4557 --num_runs 5 --data miniimagenet --classes_per_task 5 --lr 0.1 --task_model_type resnet18 --num_iterations 3 --max_t 13680 --store_model_freq 13680 --store_results_freq 684 --eval_freq 684 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/MINIIMAGENET/ > out/m4557s.out &
59 |
60 | Naive stationary
61 | MNIST
62 | nohup python -m code.scripts.naive --model_ind_start 4947 --num_runs 5 --data mnist5k --lr 0.05 --task_model_type mlp --classes_per_task 2 --stationary --num_iterations 1 --max_t 500 --store_model_freq 500 --store_results_freq 100 --eval_freq 100 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/ > out/m4947s.out &
63 |
64 | CIFAR10
65 | nohup python -m code.scripts.naive --model_ind_start 6462 --num_runs 5 --data cifar10 --lr 0.1 --task_model_type resnet18 --classes_per_task 2 --stationary --num_iterations 1 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6462s.out &
66 |
67 | MiniImageNet
68 | nohup python -m code.scripts.naive --model_ind_start 4527 --num_runs 5 --data miniimagenet --classes_per_task 5 --stationary --lr 0.1 --task_model_type resnet18 --num_iterations 3 --max_t 13680 --store_model_freq 13680 --store_results_freq 684 --eval_freq 684 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/MINIIMAGENET/ > out/m4527s.out &
69 |
70 | # ------------------------------------------
71 | Table 4
72 | # ------------------------------------------
73 |
74 | Distill unit lag
75 | nohup python -m code.scripts.distill --model_ind_start 6452 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6452s.out &
76 |
77 | ADI unit lag
78 | nohup python -m code.scripts.ADI --model_ind_start 6467 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18_batch_stats --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_theta_steps 1 --refine_sample_steps 10 --refine_sample_lr 10.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --classes_loss_weight 1.0 --opt_batch_stats --opt_batch_stats_weight 0.1 --choose_past_classes --adaptive --adaptive_weight 8.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6467s.out &
79 |
80 | ADI no distill
81 | nohup python -m code.scripts.ADI --model_ind_start 2327 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18_batch_stats --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_theta_steps 1 --refine_sample_steps 10 --refine_sample_lr 10.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --classes_loss_weight 1.0 --opt_batch_stats --opt_batch_stats_weight 0.1 --choose_past_classes --adaptive --adaptive_weight 1.0 --no_aux_distill --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m2327s.out &
82 |
83 | ARM unit lag
84 | nohup python -m code.scripts.ARM --model_ind_start 6602 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 8.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m4127s.out &
85 |
86 | ARM no distill
87 | nohup python -m code.scripts.ARM --model_ind_start 2717 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 1.0 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m2717s.out &
88 |
89 | # ------------------------------------------
90 | Table 7
91 | # ------------------------------------------
92 |
93 | \lambda_1 = 0, \lambda_2 = 0
94 | nohup python -m code.scripts.ARM --model_ind_start 3982 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 0.0 --notlocal_new_weight 0.0 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3982s.out &
95 |
96 | \lambda_3 = 0
97 | nohup python -m code.scripts.ARM --model_ind_start 6562 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 0.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6562s.out &
98 |
99 | \lambda_4 = 0
100 | nohup python -m code.scripts.ARM --model_ind_start 3977 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3977s.out &
101 |
102 | \lambda_5 = 0
103 | nohup python -m code.scripts.ARM --model_ind_start 3967 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3967s.out &
104 |
105 | \lambda_6 = 0
106 | nohup python -m code.scripts.ARM --model_ind_start 3972 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3972s.out &
107 |
108 | M = 150 (+50)
109 | nohup python -m code.scripts.ARM --model_ind_start 6624 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 150 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6478s.out &
110 |
111 | M = 50 (-50)
112 | nohup python -m code.scripts.ARM --model_ind_start 5922 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 50 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m5922s.out &
113 |
114 | S = 20 (doubled)
115 | nohup python -m code.scripts.ARM --model_ind_start 3957 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 20 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3957s.out &
116 |
117 | S = 5 (halved)
118 | nohup python -m code.scripts.ARM --model_ind_start 3962 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 5 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3962s.out &
119 |
120 | Cross-entropy as D
121 | nohup python -m code.scripts.ARM --model_ind_start 3617 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --use_crossent_as_D --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3617s.out &
122 |
123 | Random noise init \mathcal{\hat{B}}_X
124 | nohup python -m code.scripts.ARM --model_ind_start 4077 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_x_random --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m4077s.out &
125 |
126 | Recall 2x per t
127 | nohup python -m code.scripts.ARM --model_ind_start 6502 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_theta_steps 2 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6502s.out &
128 |
129 | Recall 4x per t
130 | nohup python -m code.scripts.ARM --model_ind_start 6507 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_theta_steps 4 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6507s.out &
131 |
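Most of the Table 4 and Table 7 commands above differ from the corresponding CIFAR10 baseline by only one or two flags (for example, the `\lambda_3 = 0` run sets `--diversity_weight 0.0`). The following is a minimal sketch, not part of the repository, for listing those differences between two command lines. The helper names `parse_flags` and `diff_flags` are hypothetical, and the example commands are shortened copies of the ones above.

```python
import shlex


def parse_flags(command):
    """Map each --flag in a command line to its value.

    Bare switches such as --cuda map to True. Shell redirection
    tokens (">", "&") are never treated as flag values.
    """
    tokens = shlex.split(command)
    flags = {}
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok.startswith("--"):
            nxt = tokens[i + 1] if i + 1 < len(tokens) else None
            has_value = (
                nxt is not None
                and not nxt.startswith("--")
                and nxt not in (">", "&")
            )
            if has_value:
                flags[tok] = nxt
                i += 2
                continue
            flags[tok] = True
        i += 1
    return flags


def diff_flags(base, variant, ignore=("--model_ind_start",)):
    """Return {flag: (base_value, variant_value)} for flags that differ.

    A flag missing from one command is reported with None on that side.
    """
    a, b = parse_flags(base), parse_flags(variant)
    changed = {}
    for flag in sorted(set(a) | set(b)):
        if flag in ignore:
            continue
        if a.get(flag) != b.get(flag):
            changed[flag] = (a.get(flag), b.get(flag))
    return changed


# Shortened versions of the CIFAR10 ARM baseline and the \lambda_3 = 0 run.
baseline = (
    "python -m code.scripts.ARM --model_ind_start 3717 --M 100 "
    "--L2 --L2_weight 1.0 --TV --TV_weight 1.0 "
    "--diversity_weight 16.0 --cuda > out/m3717s.out &"
)
lambda3_zero = (
    "python -m code.scripts.ARM --model_ind_start 6562 --M 100 "
    "--L2 --L2_weight 1.0 --TV --TV_weight 1.0 "
    "--diversity_weight 0.0 --cuda > out/m6562s.out &"
)

print(diff_flags(baseline, lambda3_zero))
```

Pasting the full command strings from this file in place of the shortened ones should show which settings each ablation changes. `--model_ind_start` is ignored by default because it differs for every run.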
--------------------------------------------------------------------------------
/summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xu-ji/ARM/704f69d73765f13ca5e2d8aee11f399b53c635e6/summary.png
--------------------------------------------------------------------------------