├── .gitignore ├── LICENSE ├── README.md ├── code ├── __init__.py ├── data │ ├── __init__.py │ ├── cifar.py │ ├── miniimagenet.py │ └── mnist.py ├── models │ ├── __init__.py │ ├── mlp.py │ └── resnet18.py ├── scripts │ ├── ADI.py │ ├── ARM.py │ ├── __init__.py │ ├── distill.py │ ├── naive.py │ ├── print_results.py │ └── print_table.py └── util │ ├── __init__.py │ ├── data.py │ ├── eval.py │ ├── general.py │ ├── load.py │ ├── losses.py │ └── render.py ├── commands.txt └── summary.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Xu Ji 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Automatic Recall Machines 2 | 3 | This repository contains the code for Automatic Recall Machines: Internal Replay, Continual Learning and the Brain. 4 | 5 | ARM 6 | 7 | As well as ARM, we include implementations of Adaptive DeepInversion and LwF-style distillation. 8 | 9 | # Dependencies 10 | 11 | Our environment used: 12 | - python 3.6.8 13 | - pytorch 1.4.0 14 | - torchvision 0.5.0 15 | - numpy 1.18.4 16 | 17 | # Run the code 18 | 19 | Commands for all our results on CIFAR10, MiniImageNet and MNIST are given in `commands.txt`. 20 | For example, to run recall on CIFAR10: 21 | ``` 22 | python -m code.scripts.ARM --model_ind_start 3717 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/ARM --data_path /scratch/CIFAR 23 | ``` 24 | Print results (pass the same directory as `--out_root` above): 25 | ``` 26 | python -m code.scripts.print_results --root /scratch/ARM --start 3717 27 | 28 | average val: acc 0.2586 +- 0.0145, forgetting 0.1046 +- 0.0330 29 | average test: acc 0.2687 +- 0.0107, forgetting 0.0959 +- 0.0371 30 | ``` 31 | -------------------------------------------------------------------------------- /code/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xu-ji/ARM/704f69d73765f13ca5e2d8aee11f399b53c635e6/code/__init__.py
-------------------------------------------------------------------------------- /code/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar import cifar10, cifar10val 2 | from .miniimagenet import miniimagenet, miniimagenetval 3 | from .mnist import mnist5k -------------------------------------------------------------------------------- /code/data/cifar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import pickle 4 | import sys 5 | 6 | import numpy as np 7 | from PIL import Image 8 | from torchvision.datasets.vision import VisionDataset 9 | 10 | from code.util.general import np_rand_seed 11 | 12 | # Reference: https://github.com/optimass/Maximally_Interfered_Retrieval/blob/master/data.py 13 | # We use 1 dataloader rather than one per task 14 | 15 | __all__ = ["cifar10", "cifar10val"] 16 | 17 | 18 | class cifar(VisionDataset): 19 | train_pc = 0.95 20 | 21 | def __init__(self, root, num_classes, train=True, transform=None, target_transform=None, 22 | non_stat=False, 23 | two_classes_per_block=False, is_val=False, shuffle_classes=False, 24 | num_iterations=None): 25 | print("initialising cifar%d, is val: %s..." 
% (num_classes, is_val)) 26 | super(cifar, self).__init__(root, transform=transform, target_transform=target_transform) 27 | 28 | self.num_classes = num_classes 29 | 30 | self.train = train 31 | self.is_val = is_val 32 | self.non_stat = non_stat 33 | self.two_classes_per_block = two_classes_per_block 34 | self.shuffle_classes = shuffle_classes 35 | 36 | self.num_iterations = num_iterations 37 | assert (num_iterations is not None) 38 | 39 | if self.is_val: 40 | assert (self.train) 41 | 42 | if self.train: 43 | downloaded_list = self.train_list 44 | else: 45 | downloaded_list = self.test_list 46 | 47 | self.data = [] 48 | self.targets = [] 49 | 50 | # now load the picked numpy arrays 51 | for file_name, checksum in downloaded_list: 52 | file_path = os.path.join(self.root, self.base_folder, file_name) 53 | with open(file_path, 'rb') as f: 54 | if sys.version_info[0] == 2: 55 | entry = pickle.load(f) 56 | else: 57 | entry = pickle.load(f, encoding='latin1') 58 | self.data.append(entry['data']) 59 | if 'labels' in entry: 60 | self.targets.extend(entry['labels']) 61 | else: 62 | self.targets.extend(entry['fine_labels']) 63 | 64 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 65 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 66 | 67 | sz = self.data.shape[0] 68 | assert (len(self.targets) == sz) 69 | 70 | # Split train from val deterministically ------------------------------------------------------- 71 | 72 | if self.train: 73 | if not self.is_val: 74 | per_class_sz = int((sz / self.num_classes) * cifar.train_pc) 75 | data_inds = range(sz) # from start 76 | else: 77 | per_class_sz = int((sz / self.num_classes) * (1 - cifar.train_pc)) 78 | data_inds = reversed(range(sz)) # from back 79 | 80 | class_counts = [0 for _ in range(self.num_classes)] 81 | chosen_inds = [] 82 | for i in data_inds: 83 | c = self.targets[i] 84 | if class_counts[c] < per_class_sz: 85 | chosen_inds.append(i) 86 | class_counts[c] += 1 87 | assert (len(chosen_inds) == 
per_class_sz * self.num_classes) 88 | 89 | else: 90 | chosen_inds = list(range(sz)) 91 | 92 | # Deterministic shuffle 93 | np.random.seed(np_rand_seed()) 94 | chosen_inds = np.array(chosen_inds) 95 | np.random.shuffle(chosen_inds) 96 | self.data = [self.data[i] for i in chosen_inds] 97 | self.targets = [self.targets[i] for i in chosen_inds] 98 | 99 | # Rearrange into contiguous format for non-stationary training --------------------------------- 100 | self.task_dict_classes = {} 101 | 102 | if non_stat: 103 | # organise data and targets by targets 104 | per_class = [[] for _ in range(self.num_classes)] 105 | for i, label in enumerate(self.targets): 106 | per_class[label].append(self.data[i]) 107 | 108 | new_data = [] 109 | new_targets = [] 110 | if not two_classes_per_block: # classes contiguous 111 | for c in range(self.num_classes): 112 | new_data += per_class[c] 113 | new_targets += [c] * len(per_class[c]) 114 | self.task_dict_classes[c] = [c] 115 | else: 116 | classes = np.arange(self.num_classes) 117 | assert (not self.shuffle_classes) # sanity 118 | if self.shuffle_classes: 119 | np.random.seed(np_rand_seed()) 120 | np.random.shuffle(classes) 121 | 122 | num_tasks = int(np.ceil(self.num_classes / 2.)) 123 | print("two_classes_per_block: num tasks %d" % num_tasks) 124 | for t in range(num_tasks): 125 | inds = [t * 2, t * 2 + 1] 126 | if (t * 2 + 1 >= self.num_classes): # odd number of tasks 127 | assert (t == num_tasks - 1 and num_tasks % 2 == 1) 128 | inds = [t * 2] 129 | 130 | self.task_dict_classes[t] = [classes[i] for i in inds] 131 | 132 | t_data = [] 133 | t_targets = [] 134 | for i in inds: 135 | c = classes[i] 136 | t_data += per_class[c] 137 | t_targets += [c] * len(per_class[c]) 138 | 139 | order = np.arange(len(t_data)) 140 | np.random.shuffle(order) 141 | new_data += [t_data[i] for i in order] 142 | new_targets += [t_targets[i] for i in order] 143 | 144 | self.data = new_data 145 | self.targets = new_targets 146 | 147 | class_lengths = [] 148 | 
targets_np = np.array(self.targets) 149 | for c in range(num_classes): 150 | class_lengths.append((targets_np == c).sum()) 151 | 152 | print( 153 | "... finished initialising cifar train %s val %s non stat %s classes %d " 154 | "two_classes_per_block %s iterations %d shuffle %s" % 155 | (self.train, self.is_val, self.non_stat, self.num_classes, self.two_classes_per_block, 156 | self.num_iterations, self.shuffle_classes)) 157 | 158 | self.orig_len = len(self.data) 159 | self.actual_len = self.orig_len * self.num_iterations 160 | 161 | if self.non_stat: # we need to care about looping over in task order 162 | assert (self.orig_len % self.num_classes == 0) 163 | 164 | self.orig_samples_per_task = int(self.orig_len / self.num_classes) 165 | 166 | if self.two_classes_per_block: 167 | self.orig_samples_per_task *= 2 168 | 169 | self.actual_samples_per_task = self.orig_samples_per_task * self.num_iterations 170 | print("orig samples per task: %d, actual samples per task: %d, orig len %d actual len %d" % ( 171 | self.orig_samples_per_task, self.actual_samples_per_task, self.orig_len, self.actual_len)) 172 | 173 | def __getitem__(self, index): 174 | assert (index < self.actual_len) 175 | 176 | if not self.non_stat: 177 | index = index % self.orig_len # looping over stationary data is arbitrary 178 | else: 179 | task_i, actual_offset = divmod(index, self.actual_samples_per_task) 180 | _, orig_offset = divmod(actual_offset, self.orig_samples_per_task) 181 | index = task_i * self.orig_samples_per_task + orig_offset 182 | 183 | img, target = self.data[index], self.targets[index] 184 | 185 | # doing this so that it is consistent with all other datasets 186 | # to return a PIL Image 187 | img = Image.fromarray(img) 188 | 189 | if self.transform is not None: 190 | img = self.transform(img) 191 | 192 | if self.target_transform is not None: 193 | target = self.target_transform(target) 194 | 195 | return img, target 196 | 197 | def __len__(self): 198 | return self.actual_len 199 | 
200 | 201 | class cifar10(cifar): 202 | base_folder = 'cifar-10-batches-py' 203 | 204 | train_list = [ 205 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 206 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 207 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 208 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 209 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 210 | ] 211 | 212 | test_list = [ 213 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 214 | ] 215 | 216 | def __init__(self, root, train=True, transform=None, target_transform=None, non_stat=None, 217 | two_classes_per_block=False, is_val=False, num_iterations=None): 218 | assert (non_stat is not None) 219 | if not train: 220 | assert (num_iterations == 1) 221 | super(cifar10, self).__init__(root, num_classes=10, train=train, transform=transform, 222 | target_transform=target_transform, 223 | non_stat=non_stat, two_classes_per_block=two_classes_per_block, 224 | is_val=is_val, shuffle_classes=False, 225 | num_iterations=num_iterations) 226 | 227 | 228 | class cifar10val(cifar10): 229 | # 5% validation set 230 | def __init__(self, root, transform=None, non_stat=False, two_classes_per_block=False): 231 | super(cifar10val, self).__init__(root, train=True, transform=transform, 232 | non_stat=non_stat, two_classes_per_block=two_classes_per_block, 233 | is_val=True, num_iterations=1) 234 | -------------------------------------------------------------------------------- /code/data/miniimagenet.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import os.path 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from torchvision.datasets.vision import VisionDataset 8 | 9 | from code.util.general import make_valid_from_train 10 | 11 | 12 | # Reference: https://github.com/optimass/Maximally_Interfered_Retrieval/blob/master/data.py 13 | # We use 1 dataloader rather than one per 
task 14 | 15 | # Can download from https://www.dropbox.com/s/ed1s1dgei9kxd2p/mini-imagenet.zip?dl=0 16 | 17 | def get_data(setname, root_csv, root_images): 18 | csv_path = os.path.join(root_csv, setname + '.csv') 19 | lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:] 20 | 21 | data = [] 22 | label = [] 23 | lb = -1 24 | 25 | wnids = [] 26 | 27 | for l in lines: 28 | name, wnid = l.split(',') 29 | path = os.path.join(root_images, name) 30 | if wnid not in wnids: 31 | wnids.append(wnid) 32 | lb += 1 33 | data.append(path) 34 | label.append(lb) 35 | 36 | return data, label 37 | 38 | 39 | class MiniImagenetDatasetFolder(VisionDataset): 40 | train_val_pc = 0.95 41 | train_test_pc = 0.8 42 | 43 | def __init__(self, root, data_type=None, transform=None, target_transform=None, 44 | non_stat=False, classes_per_task=None, num_iterations=None): 45 | super(MiniImagenetDatasetFolder, self).__init__(root, transform=transform, 46 | target_transform=target_transform) 47 | 48 | self.data_type = data_type 49 | self.non_stat = non_stat 50 | self.classes_per_task = classes_per_task 51 | 52 | self.num_classes = 100 53 | 54 | self.num_iterations = num_iterations 55 | assert (num_iterations is not None) 56 | 57 | # Load data ------------------------------------------------------------------------------------ 58 | # splits are deterministic 59 | 60 | images_path = os.path.join(root, "images") 61 | train_data, train_label = get_data("train", root, 62 | images_path) # zero indexed labels for all calls 63 | valid_data, valid_label = get_data("val", root, images_path) 64 | test_data, test_label = get_data("test", root, images_path) 65 | 66 | train_amt = np.unique(train_label).shape[0] 67 | valid_amt = np.unique(valid_label).shape[0] 68 | test_amt = np.unique(test_label).shape[0] 69 | 70 | assert (train_amt + valid_amt + test_amt == self.num_classes) 71 | valid_label = [x + train_amt for x in valid_label] 72 | test_label = [x + train_amt + valid_amt for x in test_label] 73 | 74 
| all_data = np.array(train_data + valid_data + test_data) # np array of strings 75 | all_label = np.array(train_label + valid_label + test_label) 76 | 77 | train_ds, test_ds = [], [] 78 | current_train, current_test = None, None 79 | 80 | cat = lambda x, y: np.concatenate((x, y), axis=0) 81 | 82 | self.task_dict_classes = defaultdict(list) 83 | task_i = 0 84 | for i in range(self.num_classes): 85 | self.task_dict_classes[task_i].append(i) 86 | 87 | class_indices = np.argwhere(all_label == i).reshape(-1) 88 | class_data = all_data[class_indices] 89 | class_label = all_label[class_indices] 90 | split = int(MiniImagenetDatasetFolder.train_test_pc * class_data.shape[0]) # train/test 91 | 92 | data_train, data_test = class_data[:split], class_data[split:] 93 | label_train, label_test = class_label[:split], class_label[split:] 94 | 95 | if current_train is None: 96 | current_train, current_test = (data_train, label_train), ( 97 | data_test, label_test) # multiple samples here 98 | else: 99 | current_train = cat(current_train[0], data_train), cat(current_train[1], label_train) 100 | current_test = cat(current_test[0], data_test), cat(current_test[1], label_test) 101 | 102 | if i % self.classes_per_task == (self.classes_per_task - 1): 103 | train_ds += [current_train] 104 | test_ds += [current_test] 105 | current_train, current_test = None, None 106 | task_i += 1 107 | 108 | train_ds, val_ds = make_valid_from_train(train_ds, 109 | cut=MiniImagenetDatasetFolder.train_val_pc) # uses 110 | # random split, but seed set in main script 111 | 112 | # now we have list of list of (path, label), one list per task 113 | # pick the right source, flatten into one list and load images 114 | 115 | data_summary = {"train": train_ds, "val": val_ds, "test": test_ds}[self.data_type] 116 | 117 | self.data = [] 118 | self.targets = [] 119 | task_lengths = [] 120 | for task_ds in data_summary: 121 | num_samples_task = len(task_ds[0]) 122 | assert (len(task_ds[1]) == num_samples_task) 123 | 
task_lengths.append(num_samples_task) 124 | for i in range(num_samples_task): 125 | img_path = task_ds[0][i] 126 | label = task_ds[1][i] 127 | self.data.append(img_path) 128 | self.targets.append(label) 129 | 130 | print(self.task_dict_classes) 131 | 132 | # if stationary, shuffle 133 | if not self.non_stat: 134 | perm = np.random.permutation(len(self.data)) 135 | self.data, self.targets = [self.data[perm_i] for perm_i in perm], [self.targets[perm_i] for 136 | perm_i in perm] 137 | 138 | self.orig_len = len(self.data) 139 | self.actual_len = self.orig_len * self.num_iterations 140 | 141 | if self.non_stat: # we need to care about looping over in task order 142 | assert (self.orig_len % self.num_classes == 0) 143 | self.orig_samples_per_task = int( 144 | self.orig_len / self.num_classes) * self.classes_per_task # equally split among tasks 145 | 146 | self.actual_samples_per_task = self.orig_samples_per_task * self.num_iterations 147 | 148 | # sanity 149 | if self.data_type == "train": 150 | assert (self.orig_samples_per_task == (int( 151 | 600 * MiniImagenetDatasetFolder.train_test_pc * MiniImagenetDatasetFolder.train_val_pc) 152 | * self.classes_per_task)) 153 | 154 | if self.data_type == "val": 155 | assert (self.orig_samples_per_task == (int(600 * MiniImagenetDatasetFolder.train_test_pc * ( 156 | 1. - MiniImagenetDatasetFolder.train_val_pc)) * self.classes_per_task)) 157 | 158 | if self.data_type == "test": 159 | assert (self.orig_samples_per_task == ( 160 | int(600 * (1. 
- MiniImagenetDatasetFolder.train_test_pc)) * self.classes_per_task)) 161 | 162 | print("orig samples per task: %d, actual samples per task: %d" % ( 163 | self.orig_samples_per_task, self.actual_samples_per_task)) 164 | 165 | def __len__(self): 166 | return self.actual_len 167 | 168 | def __getitem__(self, index): 169 | assert (index < self.actual_len) 170 | 171 | if not self.non_stat: 172 | index = index % self.orig_len # looping over stationary data is arbitrary 173 | else: 174 | task_i, actual_offset = divmod(index, self.actual_samples_per_task) 175 | _, orig_offset = divmod(actual_offset, self.orig_samples_per_task) 176 | index = task_i * self.orig_samples_per_task + orig_offset 177 | 178 | sample_path, target = self.data[index], self.targets[index] 179 | 180 | with open(sample_path, "rb") as f: 181 | sample = Image.open(f).convert('RGB') 182 | 183 | if self.transform is not None: 184 | sample = self.transform(sample) 185 | if self.target_transform is not None: 186 | target = self.target_transform(target) 187 | 188 | return sample, target 189 | 190 | def __len__(self): 191 | return self.actual_len 192 | 193 | 194 | class miniimagenet(MiniImagenetDatasetFolder): 195 | def __init__(self, root, data_type, transform=None, target_transform=None, non_stat=None, 196 | classes_per_task=None, num_iterations=None): 197 | assert (non_stat is not None) 198 | if data_type == "val" or data_type == "test": 199 | assert (num_iterations == 1) 200 | 201 | super(miniimagenet, self).__init__(root, 202 | data_type=data_type, 203 | transform=transform, target_transform=target_transform, 204 | non_stat=non_stat, classes_per_task=classes_per_task, 205 | num_iterations=num_iterations) 206 | 207 | 208 | class miniimagenetval(miniimagenet): 209 | def __init__(self, root, transform=None, non_stat=False, classes_per_task=None): 210 | super(miniimagenetval, self).__init__(root, data_type="val", 211 | non_stat=non_stat, classes_per_task=classes_per_task, 212 | transform=transform, 
num_iterations=1) 213 | -------------------------------------------------------------------------------- /code/data/mnist.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | import torch 5 | from torchvision import datasets 6 | from torchvision.datasets.vision import VisionDataset 7 | 8 | from code.util.general import make_valid_from_train 9 | 10 | 11 | # Reference: https://github.com/optimass/Maximally_Interfered_Retrieval/blob/master/data.py 12 | # We use 1 dataloader rather than one per task 13 | 14 | class mnist5k(VisionDataset): 15 | train_val_pc = 0.95 16 | 17 | def __init__(self, root, data_type=None, non_stat=False, num_iterations=None, classes_per_task=None): 18 | super(mnist5k, self).__init__(root, transform=None, target_transform=None) 19 | 20 | self.data_type = data_type 21 | self.non_stat = non_stat 22 | self.classes_per_task = classes_per_task 23 | assert(self.classes_per_task == 2) 24 | self.num_classes = 10 25 | self.orig_train_samples_per_class = 500 26 | 27 | self.num_iterations = num_iterations 28 | assert (num_iterations is not None) 29 | 30 | # Load data ------------------------------------------------------------------------------------ 31 | # splits are deterministic 32 | 33 | # follows https://github.com/optimass/Maximally_Interfered_Retrieval/ 34 | 35 | train = datasets.MNIST(root, train=True, download=False) 36 | test = datasets.MNIST(root, train=False, download=False) 37 | 38 | train_x, train_y = train.data, train.targets # 60000, 28, 28; 60000 39 | test_x, test_y = test.data, test.targets 40 | 41 | # sort by label 42 | train_ds, test_ds = [], [] # doesn't really matter for test_ds because of batchnorm tracking 43 | # stats 44 | task_i = 0 45 | current_train, current_test = None, None 46 | self.task_dict_classes = defaultdict(list) 47 | for i in range(self.num_classes): 48 | self.task_dict_classes[task_i].append(i) 49 | train_i = 
train_y == i 50 | test_i = test_y == i 51 | 52 | if current_train is None: 53 | current_train, current_test = (train_x[train_i], train_y[train_i]), ( 54 | test_x[test_i], test_y[test_i]) 55 | else: 56 | current_train = (torch.cat((current_train[0], train_x[train_i]), dim=0), 57 | torch.cat((current_train[1], train_y[train_i]), dim=0)) 58 | current_test = (torch.cat((current_test[0], test_x[test_i]), dim=0), 59 | torch.cat((current_test[1], test_y[test_i]), dim=0)) 60 | 61 | if i % self.classes_per_task == (self.classes_per_task - 1): 62 | train_ds += [current_train] 63 | test_ds += [current_test] 64 | current_train, current_test = None, None 65 | task_i += 1 66 | 67 | # separate validation set (randomised) 68 | train_ds, val_ds = make_valid_from_train(train_ds, cut=mnist5k.train_val_pc) 69 | 70 | # flatten into single list, and truncate training data into 500 per class 71 | data_summary = {"train": train_ds, "val": val_ds, "test": test_ds}[self.data_type] 72 | self.data = [] # list of tensors 73 | self.targets = [] 74 | counts_per_class = torch.zeros(self.num_classes, dtype=torch.int) 75 | task_lengths = [] 76 | for task_ds in data_summary: 77 | assert (len(task_ds[1]) == len(task_ds[0])) 78 | 79 | num_samples_task = 0 80 | for i in range(len(task_ds[1])): 81 | target = task_ds[1][i] 82 | if self.data_type == "train" and counts_per_class[ 83 | target] == self.orig_train_samples_per_class: 84 | continue 85 | else: 86 | self.data.append(task_ds[0][i]) 87 | self.targets.append(target) 88 | counts_per_class[target] += 1 89 | num_samples_task += 1 90 | 91 | task_lengths.append(num_samples_task) 92 | 93 | print(self.task_dict_classes) 94 | 95 | # if stationary, shuffle 96 | if not self.non_stat: 97 | perm = np.random.permutation(len(self.data)) 98 | self.data, self.targets = [self.data[perm_i] for perm_i in perm], [self.targets[perm_i] for 99 | perm_i in perm] 100 | 101 | self.orig_len = len(self.data) 102 | self.actual_len = self.orig_len * self.num_iterations 103 | 104 
| if self.non_stat: # we need to care about looping over in task order 105 | assert (self.orig_len % self.num_classes == 0) 106 | self.orig_samples_per_task = int( 107 | self.orig_len / self.num_classes) * self.classes_per_task # equally split among tasks 108 | self.actual_samples_per_task = self.orig_samples_per_task * self.num_iterations 109 | 110 | # sanity 111 | if self.data_type == "train": assert (self.orig_samples_per_task == 1000) 112 | 113 | print("orig samples per task: %d, actual samples per task: %d" % ( 114 | self.orig_samples_per_task, self.actual_samples_per_task)) 115 | 116 | def __len__(self): 117 | return self.actual_len 118 | 119 | def __getitem__(self, index): 120 | assert (index < self.actual_len) 121 | 122 | if not self.non_stat: 123 | index = index % self.orig_len # looping over stationary data is arbitrary 124 | else: 125 | task_i, actual_offset = divmod(index, self.actual_samples_per_task) 126 | _, orig_offset = divmod(actual_offset, self.orig_samples_per_task) 127 | index = task_i * self.orig_samples_per_task + orig_offset 128 | 129 | sample, target = self.data[index], self.targets[index] 130 | sample = sample.view(-1).float() / 255. # flatten and turn from uint8 (255) -> [0., 1.] 
131 | 132 | assert (self.transform is None) 133 | assert (self.target_transform is None) 134 | 135 | return sample, target 136 | 137 | def __len__(self): 138 | return self.actual_len 139 | -------------------------------------------------------------------------------- /code/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import mlp 2 | from .resnet18 import * -------------------------------------------------------------------------------- /code/models/mlp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | 4 | 5 | # Consistent with https://github.com/optimass/Maximally_Interfered_Retrieval 6 | 7 | class mlp(nn.Module): 8 | def __init__(self, config): 9 | super(mlp, self).__init__() 10 | 11 | self.nf = 400 12 | 13 | self.input_size = np.prod(config.task_in_dims) 14 | self.hidden = nn.Sequential(nn.Linear(self.input_size, self.nf), 15 | nn.ReLU(True), 16 | nn.Linear(self.nf, self.nf), 17 | nn.ReLU(True)) 18 | 19 | self.linear = nn.Linear(self.nf, np.prod(config.task_out_dims)) 20 | 21 | def forward(self, x): 22 | x = x.view(-1, self.input_size) 23 | x = self.hidden(x) 24 | return self.linear(x) 25 | -------------------------------------------------------------------------------- /code/models/resnet18.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Consistent with https://github.com/optimass/Maximally_Interfered_Retrieval 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | """Basic Block for resnet 18 and resnet 34 9 | """ 10 | 11 | # BasicBlock and BottleNeck block 12 | # have different output size 13 | # we use class attribute expansion 14 | # to distinct 15 | expansion = 1 16 | 17 | def __init__(self, in_channels, out_channels, stride=1, use_batchnorm=True, batchnorm_mom=None, 18 | batchnorm_dont_track=False): 19 | 
super().__init__() 20 | 21 | # residual function 22 | seq = [ 23 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)] 24 | if use_batchnorm: 25 | if batchnorm_mom is None: 26 | seq += [nn.BatchNorm2d(out_channels, track_running_stats=(not batchnorm_dont_track))] 27 | else: 28 | seq += [nn.BatchNorm2d(out_channels, momentum=batchnorm_mom, 29 | track_running_stats=(not batchnorm_dont_track))] 30 | 31 | seq += [ 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, 34 | bias=False), 35 | ] 36 | if use_batchnorm: 37 | if batchnorm_mom is None: 38 | seq += [nn.BatchNorm2d(out_channels * BasicBlock.expansion, 39 | track_running_stats=(not batchnorm_dont_track))] 40 | else: 41 | seq += [nn.BatchNorm2d(out_channels * BasicBlock.expansion, momentum=batchnorm_mom, 42 | track_running_stats=(not batchnorm_dont_track))] 43 | 44 | self.residual_function = nn.Sequential(*seq) 45 | 46 | # shortcut 47 | self.shortcut = nn.Sequential() 48 | 49 | # the shortcut output dimension is not the same with residual function 50 | # use 1*1 convolution to match the dimension 51 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 52 | shortcut_seq = [ 53 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, 54 | bias=False)] 55 | if use_batchnorm: 56 | if batchnorm_mom is None: 57 | shortcut_seq += [nn.BatchNorm2d(out_channels * BasicBlock.expansion, 58 | track_running_stats=(not batchnorm_dont_track))] 59 | else: 60 | shortcut_seq += [ 61 | nn.BatchNorm2d(out_channels * BasicBlock.expansion, momentum=batchnorm_mom, 62 | track_running_stats=(not batchnorm_dont_track))] 63 | 64 | self.shortcut = nn.Sequential(*shortcut_seq) 65 | 66 | def forward(self, x): 67 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 68 | 69 | 70 | class ResNet(nn.Module): 71 | def __init__(self, block, num_block, num_classes=100, 
use_batchnorm=True, batchnorm_mom=None, 72 | batchnorm_dont_track=False, in_channels=64, init=False, batchnorm_init=False, 73 | linear_sz=None): 74 | super().__init__() 75 | 76 | self.num_classes = num_classes 77 | self.use_batchnorm = use_batchnorm 78 | self.batchnorm_mom = batchnorm_mom 79 | 80 | self.in_channels = in_channels 81 | 82 | self.conv1 = nn.Sequential( 83 | nn.Conv2d(3, in_channels, kernel_size=3, padding=1, bias=False), 84 | nn.BatchNorm2d(in_channels, momentum=self.batchnorm_mom, 85 | track_running_stats=(not batchnorm_dont_track)), 86 | nn.ReLU(inplace=True)) 87 | # we use a different inputsize than the original paper 88 | # so conv2_x's stride is 1 89 | 90 | # make layer resets in_channels to be out channels 91 | self.conv2_x = self._make_layer(block, in_channels, num_block[0], 1, 92 | batchnorm_mom=self.batchnorm_mom, 93 | batchnorm_dont_track=batchnorm_dont_track) 94 | self.conv3_x = self._make_layer(block, in_channels * 2, num_block[1], 2, 95 | batchnorm_mom=self.batchnorm_mom, 96 | batchnorm_dont_track=batchnorm_dont_track) 97 | self.conv4_x = self._make_layer(block, in_channels * 4, num_block[2], 2, 98 | batchnorm_mom=self.batchnorm_mom, 99 | batchnorm_dont_track=batchnorm_dont_track) 100 | self.conv5_x = self._make_layer(block, in_channels * 8, num_block[3], 2, 101 | batchnorm_mom=self.batchnorm_mom, 102 | batchnorm_dont_track=batchnorm_dont_track) 103 | 104 | if self.num_classes == 10: 105 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 106 | 107 | self.fc1 = nn.Linear(linear_sz, self.num_classes) 108 | 109 | """ 110 | if init: # else default, which is uniform 111 | print("calling _initialise") 112 | self._initialise() 113 | 114 | if batchnorm_init: # else default, which is all 1s since 1.2 115 | print("calling _batchnorm_initialise") 116 | self._batchnorm_initialise() 117 | """ 118 | 119 | def _initialise(self): 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 
nonlinearity='relu') 123 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 124 | nn.init.constant_(m.weight, 1) 125 | nn.init.constant_(m.bias, 0) 126 | 127 | def _batchnorm_initialise(self): 128 | for m in self.modules(): 129 | if isinstance(m, nn.BatchNorm2d): # pre 1.2 default init 130 | nn.init.uniform_(m.weight, a=0.0, b=1.0) 131 | nn.init.constant_(m.bias, 0) 132 | 133 | def _make_layer(self, block, out_channels, num_blocks, stride, batchnorm_mom=None, 134 | batchnorm_dont_track=False): 135 | """make resnet layers(by layer i didnt mean this 'layer' was the 136 | same as a neuron netowork layer, ex. conv layer), one layer may 137 | contain more than one residual block 138 | 139 | Args: 140 | block: block type, basic block or bottle neck block 141 | out_channels: output depth channel number of this layer 142 | num_blocks: how many blocks per layer 143 | stride: the stride of the first block of this layer 144 | 145 | Return: 146 | return a resnet layer 147 | """ 148 | 149 | # we have num_block blocks per layer, the first block 150 | # could be 1 or 2, other blocks would always be 1 151 | strides = [stride] + [1] * (num_blocks - 1) 152 | layers = [] 153 | for stride in strides: 154 | layers.append(block(self.in_channels, out_channels, stride, 155 | use_batchnorm=self.use_batchnorm, batchnorm_mom=batchnorm_mom, 156 | batchnorm_dont_track=batchnorm_dont_track)) 157 | self.in_channels = out_channels * block.expansion 158 | 159 | return nn.Sequential(*layers) 160 | 161 | def forward(self, x): 162 | x = self.conv1(x) 163 | x = self.conv2_x(x) 164 | x = self.conv3_x(x) 165 | x = self.conv4_x(x) 166 | x = self.conv5_x(x) 167 | 168 | if self.num_classes == 10: 169 | x = self.avg_pool(x) # 1x1 (as Aljundi) 170 | else: 171 | x = nn.functional.avg_pool2d(x, 4) # 2x2 (as Aljundi) 172 | 173 | x = x.view(x.size(0), -1) 174 | 175 | return self.fc1(x) 176 | 177 | 178 | def _batch_stats_hook(b, input): 179 | if isinstance(input, tuple): 180 | assert (len(input) == 1) 181 | 
input = input[0] 182 | 183 | assert (len(input.shape) == 4) 184 | 185 | stored_mean = b.running_mean 186 | stored_var = b.running_var 187 | 188 | curr_mean = input.mean(dim=(0, 2, 3)) 189 | curr_var = input.var(dim=(0, 2, 3)) 190 | 191 | assert (stored_mean.shape == (input.shape[1],)) 192 | assert (stored_var.shape == (input.shape[1],)) 193 | assert (curr_mean.shape == (input.shape[1],)) 194 | assert (curr_var.shape == (input.shape[1],)) 195 | 196 | b.batch_stats_loss = torch.norm(curr_mean - stored_mean, p=2) + torch.norm(curr_var - stored_var, 197 | p=2) 198 | assert (b.batch_stats_loss.shape == torch.Size([])) # scalar 199 | 200 | 201 | class resnet18(ResNet): 202 | def __init__(self, config): 203 | # Newer resnet code uses avgpool but Aljundi code uses 4 pool. 204 | if config.data == "miniimagenet": 205 | num_classes = 100 206 | linear_sz = 160 * 2 * 2 207 | elif config.data == "cifar10": 208 | num_classes = 10 209 | linear_sz = 160 * 1 * 1 210 | 211 | super(resnet18, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, 212 | in_channels=20, 213 | use_batchnorm=True, 214 | batchnorm_dont_track=False, 215 | linear_sz=linear_sz) 216 | 217 | 218 | class resnet18_batch_stats(resnet18): 219 | def __init__(self, config): 220 | super(resnet18_batch_stats, self).__init__(config) 221 | 222 | self._compute_batch_stats_loss = False 223 | 224 | def compute_batch_stats_loss(self): 225 | self._compute_batch_stats_loss = True 226 | 227 | self.num_batchnorms = 0 228 | for m in self.modules(): 229 | if isinstance(m, nn.BatchNorm2d): 230 | m.register_forward_pre_hook(_batch_stats_hook) 231 | m.batch_stats_loss = None 232 | self.num_batchnorms += 1 233 | 234 | def forward(self, x): 235 | # return loss for batch stats 236 | 237 | x = self.conv1(x) 238 | x = self.conv2_x(x) 239 | x = self.conv3_x(x) 240 | x = self.conv4_x(x) 241 | x = self.conv5_x(x) 242 | 243 | if self.num_classes == 10: 244 | x = self.avg_pool(x) 245 | else: 246 | x = nn.functional.avg_pool2d(x, 4) 
247 | x = x.view(x.size(0), -1) 248 | 249 | x = self.fc1(x) 250 | 251 | if self._compute_batch_stats_loss: 252 | batch_stats_losses = [] 253 | for m in self.modules(): 254 | if isinstance(m, nn.BatchNorm2d): 255 | assert (m.batch_stats_loss is not None) 256 | batch_stats_losses.append(m.batch_stats_loss) 257 | m.batch_stats_loss = None 258 | assert (len(batch_stats_losses) == self.num_batchnorms) 259 | return x, torch.mean(torch.stack(batch_stats_losses)) 260 | else: 261 | return x -------------------------------------------------------------------------------- /code/scripts/ADI.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from copy import deepcopy 4 | from datetime import datetime 5 | 6 | import torch.optim as optim 7 | 8 | from code.util.eval import evaluate_basic 9 | from code.util.general import * 10 | from code.util.load import * 11 | from code.util.losses import * 12 | from code.util.render import render_aux_x 13 | 14 | # -------------------------------------------------------------------------------------------------- 15 | # Settings 16 | # -------------------------------------------------------------------------------------------------- 17 | 18 | orig_config = argparse.ArgumentParser(allow_abbrev=False) 19 | 20 | orig_config.add_argument("--model_ind_start", type=int, required=True) 21 | 22 | orig_config.add_argument("--num_runs", type=int, required=True) 23 | 24 | orig_config.add_argument("--out_root", type=str, required=True) 25 | 26 | # Data and model 27 | 28 | orig_config.add_argument("--data", type=str, required=True) 29 | 30 | orig_config.add_argument("--data_path", type=str, required=True) 31 | 32 | orig_config.add_argument("--stationary", default=False, action="store_true") 33 | 34 | orig_config.add_argument("--classes_per_task", type=int, required=True) 35 | 36 | orig_config.add_argument("--max_t", type=int, required=True) 37 | 38 | 
orig_config.add_argument("--tasks_train_batch_sz", type=int, default=10) 39 | 40 | orig_config.add_argument("--tasks_eval_batch_sz", type=int, default=10) 41 | 42 | orig_config.add_argument("--num_iterations", type=int, default=1) 43 | 44 | orig_config.add_argument("--task_model_type", type=str, required=True) 45 | 46 | # Adaptive DeepInversion 47 | 48 | orig_config.add_argument("--recall_from_t", type=int, required=True) 49 | 50 | orig_config.add_argument("--M", type=int, default=100) 51 | 52 | orig_config.add_argument("--lr", type=float, default=0.1) 53 | 54 | orig_config.add_argument("--refine_sample_lr", type=float, default=0.1) 55 | 56 | orig_config.add_argument("--refine_sample_steps", type=int, default=10) 57 | 58 | orig_config.add_argument("--aux_weight", type=float, default=1.0) 59 | 60 | orig_config.add_argument("--no_aux_distill", default=False, action="store_true") 61 | 62 | orig_config.add_argument("--aux_distill_weight", type=float, default=1.0) 63 | 64 | orig_config.add_argument("--adaptive", default=False, action="store_true") 65 | 66 | orig_config.add_argument("--adaptive_weight", type=float, default=1.0) 67 | 68 | orig_config.add_argument("--classes_loss_weight", type=float, default=1.0) 69 | 70 | orig_config.add_argument("--choose_past_classes", default=False, action="store_true") 71 | 72 | orig_config.add_argument("--opt_batch_stats", default=False, action="store_true") 73 | 74 | orig_config.add_argument("--opt_batch_stats_weight", type=float, default=1.0) 75 | 76 | orig_config.add_argument("--sharpen_class", default=False, action="store_true") 77 | 78 | orig_config.add_argument("--sharpen_class_weight", type=float, default=1.0) 79 | 80 | orig_config.add_argument("--TV", default=False, action="store_true") 81 | 82 | orig_config.add_argument("--TV_weight", type=float, default=1.0) 83 | 84 | orig_config.add_argument("--L2", default=False, action="store_true") 85 | 86 | orig_config.add_argument("--L2_weight", type=float, default=1.0) 87 | 88 | 
orig_config.add_argument("--long_window", default=False, action="store_true") 89 | 90 | orig_config.add_argument("--long_window_range", type=int, nargs="+", default=[1, 1]) # inclusive 91 | 92 | orig_config.add_argument("--use_fixed_window", default=False, action="store_true") 93 | 94 | orig_config.add_argument("--fixed_window", type=int, default=1) 95 | 96 | # Included for testing: 97 | 98 | orig_config.add_argument("--refine_theta_steps", type=int, default=1) 99 | 100 | orig_config.add_argument("--refine_sample_from_scratch", default=False, action="store_true") 101 | 102 | orig_config.add_argument("--hard_targets", default=False, action="store_true") 103 | 104 | orig_config.add_argument("--aux_x_random", default=False, action="store_true") 105 | 106 | orig_config.add_argument("--no_classes_loss", default=False, action="store_true") 107 | 108 | # Admin 109 | 110 | orig_config.add_argument("--cuda", default=False, action="store_true") 111 | 112 | orig_config.add_argument("--eval_freq", type=int, required=True) 113 | 114 | orig_config.add_argument("--store_results_freq", type=int, required=True) 115 | 116 | orig_config.add_argument("--store_model_freq", type=int, required=True) 117 | 118 | orig_config.add_argument("--specific_torch_seed", default=False, action="store_true") 119 | 120 | orig_config.add_argument("--torch_seed", type=int) 121 | 122 | orig_config.add_argument("--render_aux_x", default=False, action="store_true") 123 | 124 | orig_config.add_argument("--render_aux_x_freq", type=int, default=50) 125 | 126 | orig_config.add_argument("--render_aux_x_num", type=int, default=3) 127 | 128 | orig_config = orig_config.parse_args() 129 | 130 | 131 | def main(config): 132 | # ------------------------------------------------------------------------------------------------ 133 | # Setup 134 | # ------------------------------------------------------------------------------------------------ 135 | 136 | reproc_settings(config) 137 | 138 | config.out_dir = 
osp.join(config.out_root, str(config.model_ind)) 139 | 140 | tasks_model, trainloader, testloader, valloader = get_model_and_data(config) 141 | 142 | if not osp.exists(config.out_dir): 143 | os.makedirs(config.out_dir) 144 | 145 | next_t = 0 146 | last_classes = None 147 | seen_classes = None 148 | 149 | if config.long_window: 150 | old_tasks_model = None 151 | config.next_update_old_model_t_history = [] 152 | config.next_update_old_model_t = 0 # old model needs to be set in first timestep 153 | assert (len(config.long_window_range) == 2) 154 | 155 | optimizer = optim.SGD(tasks_model.parameters(), lr=config.lr, momentum=0, dampening=0, 156 | weight_decay=0, nesterov=False) 157 | 158 | refine_sample_metrics = [] 159 | refine_theta_metrics = [] 160 | 161 | if not config.no_aux_distill: refine_theta_metrics.append("loss_aux_distill") 162 | if not config.no_classes_loss: refine_sample_metrics.append("classes_loss") 163 | if config.adaptive: refine_sample_metrics.append("adaptive_loss") 164 | if config.sharpen_class: refine_sample_metrics.append("sharpen_class_loss") 165 | if config.TV: refine_sample_metrics.append("TV_loss") 166 | if config.L2: refine_sample_metrics.append("L2_loss") 167 | if config.opt_batch_stats: refine_sample_metrics.append("opt_batch_stats_loss") 168 | 169 | refine_sample_metrics += ["loss_refine"] 170 | refine_theta_metrics += ["not_present_class", "loss_aux", "final_loss_aux"] 171 | 172 | t = next_t 173 | 174 | while t <= config.max_t: 175 | for xs, ys in trainloader: 176 | present_classes = ys.unique().to(get_device(config.cuda)) 177 | 178 | # -------------------------------------------------------------------------------------------- 179 | # Eval 180 | # -------------------------------------------------------------------------------------------- 181 | 182 | if (t - next_t) < 1000 or t % 100 == 0: 183 | print("m %d t: %d %s, targets %s" % ( 184 | config.model_ind, t, datetime.now(), str(list(present_classes.cpu().numpy())))) 185 | 
sys.stdout.flush() 186 | 187 | save_dict = {"tasks_model": tasks_model, "t": t, "last_classes": last_classes, 188 | "seen_classes": seen_classes} 189 | if config.long_window: # else no need to save because it's always the one just before 190 | # current update 191 | save_dict["old_tasks_model"] = old_tasks_model 192 | 193 | last_step = t == (config.max_t) 194 | if (t % config.eval_freq == 0) or (t % config.batches_per_epoch == 0) or last_step or ( 195 | t == 0): 196 | evaluate_basic(config, tasks_model, valloader, t, is_val=True, 197 | last_classes=last_classes, seen_classes=seen_classes) 198 | evaluate_basic(config, tasks_model, testloader, t, is_val=False, 199 | last_classes=last_classes, seen_classes=seen_classes) 200 | 201 | if (t % config.store_model_freq == 0) or last_step: 202 | torch.save(save_dict, osp.join(config.out_dir, "latest_models.pytorch")) 203 | 204 | if (t % config.store_results_freq == 0) or last_step: 205 | render_graphs(config) 206 | store(config) 207 | 208 | if last_step: 209 | return 210 | 211 | # -------------------------------------------------------------------------------------------- 212 | # Train 213 | # -------------------------------------------------------------------------------------------- 214 | 215 | tasks_model.train() 216 | 217 | xs = xs.to(get_device(config.cuda)) 218 | ys = ys.to(get_device(config.cuda)) 219 | 220 | curr_classes = ys.unique() 221 | assert (curr_classes.max() < config.task_out_dims[0]) 222 | 223 | optimizer.zero_grad() 224 | 225 | # set old_tasks_model if needed 226 | if not config.long_window or (config.long_window and t == config.next_update_old_model_t): 227 | old_tasks_model = deepcopy(tasks_model) 228 | if config.opt_batch_stats: old_tasks_model.compute_batch_stats_loss() 229 | 230 | if config.long_window: 231 | config.next_update_old_model_t_history.append(config.next_update_old_model_t) 232 | if config.use_fixed_window: 233 | window_offset = config.fixed_window 234 | else: 235 | window_offset = 
np.random.randint(config.long_window_range[0], high=( 236 | config.long_window_range[1] + 1)) # randint is excl 237 | 238 | config.next_update_old_model_t = t + window_offset 239 | 240 | # -------------------------------------------------------------------------------------------- 241 | # Train on real data 242 | # -------------------------------------------------------------------------------------------- 243 | 244 | preds = tasks_model(xs) 245 | loss_orig = F.cross_entropy(preds, ys, reduction="mean") 246 | record_and_check(config, "loss_orig", loss_orig.item(), t) 247 | loss_orig.backward() 248 | optimizer.step() 249 | 250 | # -------------------------------------------------------------------------------------------- 251 | # Generate data 252 | # -------------------------------------------------------------------------------------------- 253 | 254 | if t >= config.recall_from_t: 255 | optimizer.zero_grad() 256 | 257 | metrics = dict([(metric, 0.) for metric in refine_sample_metrics + refine_theta_metrics]) 258 | 259 | # pick classes for classes_loss 260 | num_classes = int(np.prod(config.task_out_dims)) 261 | if not config.choose_past_classes: 262 | classes_to_refine = torch.tensor( 263 | np.random.choice(num_classes, config.M, replace=(config.M > num_classes)), 264 | dtype=torch.long, device=get_device(config.cuda)) 265 | else: 266 | # Explicitly pick seen classes excluding present classes. 
There will be at least 2 bc 267 | # recall from 2nd task 268 | seen_classes_excl_pres = seen_classes.clone() 269 | num_seen = seen_classes.shape[0] 270 | for c in present_classes: 271 | seen_classes_excl_pres = seen_classes_excl_pres[seen_classes_excl_pres != c] 272 | num_seen_excl_pres = seen_classes_excl_pres.shape[0] 273 | 274 | assert (num_seen_excl_pres <= num_seen) 275 | chosen_inds = np.random.choice(num_seen_excl_pres, config.M, 276 | replace=(config.M > num_seen_excl_pres)) 277 | classes_to_refine = seen_classes_excl_pres[chosen_inds] 278 | 279 | for r in range(config.refine_theta_steps): 280 | if config.refine_sample_from_scratch or r == 0: 281 | if not config.aux_x_random: 282 | aux_x = xs[np.random.choice(xs.shape[0], config.M, replace=(config.M > xs.shape[0]))] 283 | else: 284 | aux_x = torch.rand((config.M,) + xs.shape[1:]).to(get_device(config.cuda)) 285 | 286 | aux_x_orig = aux_x.clone() 287 | aux_x.requires_grad_(True) 288 | 289 | for s in range(config.refine_sample_steps): 290 | aux_preds_old = old_tasks_model(aux_x) 291 | aux_preds_new = tasks_model(aux_x) 292 | 293 | if config.opt_batch_stats: 294 | aux_preds_old, opt_batch_stats_loss = aux_preds_old 295 | 296 | loss_refine = torch.tensor(0.).to(get_device(config.cuda)) 297 | if not config.no_classes_loss: 298 | classes_loss = deep_inversion_classes_loss(classes_to_refine, aux_preds_old) 299 | loss_refine += config.classes_loss_weight * classes_loss 300 | 301 | if config.adaptive: 302 | adaptive_loss = neg_symmetric_KL(aux_preds_old, aux_preds_new) 303 | loss_refine += config.adaptive_weight * adaptive_loss 304 | 305 | if config.sharpen_class: 306 | sharpen_class_loss = sharpen_class(aux_preds_old) 307 | loss_refine += config.sharpen_class_weight * sharpen_class_loss 308 | 309 | if config.TV: 310 | TV_loss = TV(aux_x) 311 | loss_refine += config.TV_weight * TV_loss 312 | 313 | if config.L2: 314 | L2_loss = L2(aux_x) 315 | loss_refine += config.L2_weight * L2_loss 316 | 317 | if 
config.opt_batch_stats: 318 | loss_refine += config.opt_batch_stats_weight * opt_batch_stats_loss 319 | 320 | for metric in refine_sample_metrics: 321 | metrics[metric] += locals()[metric].item() 322 | 323 | aux_x_grads = \ 324 | torch.autograd.grad(loss_refine, aux_x, only_inputs=True, retain_graph=False)[0] 325 | aux_x = (aux_x - config.refine_sample_lr * aux_x_grads).detach().requires_grad_(True) 326 | 327 | aux_x.requires_grad_(False) 328 | with torch.no_grad(): 329 | aux_y = old_tasks_model(aux_x) 330 | distill_xs_targets = old_tasks_model(xs) 331 | 332 | if config.opt_batch_stats: 333 | aux_y, _ = aux_y 334 | distill_xs_targets, _ = distill_xs_targets 335 | 336 | if config.render_aux_x and t % config.render_aux_x_freq == 0: 337 | render_aux_x(config, t, r, aux_x_orig, aux_x, aux_y, present_classes) 338 | 339 | aux_y_hard = aux_y.argmax(dim=1) 340 | 341 | if not hasattr(config, "aux_y_hard"): 342 | config.aux_y_hard = OrderedDict() 343 | if t not in config.aux_y_hard: 344 | config.aux_y_hard[t] = aux_y_hard.cpu() 345 | else: 346 | config.aux_y_hard[t] = torch.cat((config.aux_y_hard[t], aux_y_hard.cpu())).unique() 347 | 348 | if not hasattr(config, "aux_y_probs"): 349 | config.aux_y_probs = OrderedDict() 350 | aux_y_probs = F.softmax(aux_y, dim=1).mean(dim=0) 351 | if t not in config.aux_y_probs: 352 | config.aux_y_probs[t] = aux_y_probs.cpu() * (1. / config.refine_theta_steps) # n c 353 | else: 354 | config.aux_y_probs[t] += (aux_y_probs.cpu() * (1. 
/ config.refine_theta_steps)) 355 | 356 | is_present_class = 0 357 | for c in present_classes: 358 | is_present_class += (aux_y_hard == c).sum().item() 359 | metrics["not_present_class"] += aux_y_hard.shape[0] - is_present_class 360 | 361 | # ---------------------------------------------------------------------------------------- 362 | # Train on generated data 363 | # ---------------------------------------------------------------------------------------- 364 | 365 | preds = tasks_model(aux_x) 366 | if not config.hard_targets: 367 | loss_aux = F.kl_div(F.log_softmax(preds, dim=1), 368 | F.softmax(aux_y, dim=1), reduction="batchmean") 369 | else: 370 | loss_aux = F.cross_entropy(preds, aux_y_hard, reduction="mean") 371 | 372 | final_loss_aux = loss_aux * config.aux_weight 373 | 374 | if not config.no_aux_distill: 375 | loss_aux_distill = F.kl_div(F.log_softmax(tasks_model(xs), dim=1), 376 | F.softmax(distill_xs_targets, dim=1), reduction="batchmean") 377 | final_loss_aux += config.aux_distill_weight * loss_aux_distill 378 | metrics["loss_aux_distill"] += loss_aux_distill.item() 379 | 380 | metrics["loss_aux"] += loss_aux.item() 381 | metrics["final_loss_aux"] += final_loss_aux.item() 382 | 383 | final_loss_aux.backward() 384 | optimizer.step() 385 | 386 | for metric in refine_sample_metrics: 387 | metrics[metric] /= float(config.refine_sample_steps * config.refine_theta_steps) 388 | record_and_check(config, metric, metrics[metric], t) 389 | 390 | for metric in refine_theta_metrics: 391 | metrics[metric] /= float(config.refine_theta_steps) 392 | record_and_check(config, metric, metrics[metric], t) 393 | 394 | t += 1 395 | if seen_classes is None: 396 | seen_classes = present_classes 397 | else: 398 | seen_classes = torch.cat((seen_classes, present_classes)).unique() 399 | last_classes = present_classes 400 | 401 | 402 | if __name__ == "__main__": 403 | ms = range(orig_config.model_ind_start, orig_config.model_ind_start + orig_config.num_runs) 404 | for m in ms: 405 
| c = deepcopy(orig_config) 406 | c.model_ind = m 407 | main(c) 408 | print("Done m %d" % m) 409 | -------------------------------------------------------------------------------- /code/scripts/ARM.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from copy import deepcopy 3 | from datetime import datetime 4 | import os 5 | 6 | import torch.optim as optim 7 | 8 | from code.models import * 9 | from code.util.data import * 10 | from code.util.eval import evaluate_basic 11 | from code.util.general import * 12 | from code.util.load import * 13 | from code.util.losses import * 14 | from code.util.render import render_aux_x 15 | 16 | # -------------------------------------------------------------------------------------------------- 17 | # Settings 18 | # -------------------------------------------------------------------------------------------------- 19 | 20 | orig_config = argparse.ArgumentParser(allow_abbrev=False) 21 | 22 | orig_config.add_argument("--model_ind_start", type=int, required=True) 23 | 24 | orig_config.add_argument("--num_runs", type=int, required=True) 25 | 26 | orig_config.add_argument("--out_root", type=str, required=True) 27 | 28 | # Data and model 29 | 30 | orig_config.add_argument("--data", type=str, required=True) 31 | 32 | orig_config.add_argument("--data_path", type=str, required=True) 33 | 34 | orig_config.add_argument("--stationary", default=False, action="store_true") 35 | 36 | orig_config.add_argument("--classes_per_task", type=int, required=True) 37 | 38 | orig_config.add_argument("--max_t", type=int, required=True) 39 | 40 | orig_config.add_argument("--tasks_train_batch_sz", type=int, default=10) 41 | 42 | orig_config.add_argument("--tasks_eval_batch_sz", type=int, default=10) 43 | 44 | orig_config.add_argument("--num_iterations", type=int, default=1) 45 | 46 | orig_config.add_argument("--task_model_type", type=str, required=True) 47 | 48 | # ARM 49 | 50 | 
orig_config.add_argument("--recall_from_t", type=int, required=True) 51 | 52 | orig_config.add_argument("--M", type=int, default=100) 53 | 54 | orig_config.add_argument("--lr", type=float, default=0.1) # eta_0 55 | 56 | orig_config.add_argument("--refine_sample_lr", type=float, default=0.1) # eta_1 57 | 58 | orig_config.add_argument("--refine_sample_steps", type=int, default=10) # S 59 | 60 | orig_config.add_argument("--aux_weight", type=float, default=1.0) # lambda_0 (loss on recall) 61 | 62 | orig_config.add_argument("--aux_distill", default=False, action="store_true") 63 | 64 | orig_config.add_argument("--aux_distill_weight", type=float, default=1.0) # lambda_0 (loss on real) 65 | 66 | orig_config.add_argument("--divergence_loss_weight", type=float, default=1.0) # optional weight D 67 | 68 | orig_config.add_argument("--notlocal_weight", type=float, default=1.0) # lambda_1 69 | 70 | orig_config.add_argument("--notlocal_new_weight", type=float, default=1.0) # lambda_2 71 | 72 | orig_config.add_argument("--diversity_weight", type=float, default=1.0) # lambda_3 73 | 74 | orig_config.add_argument("--sharpen_class", default=False, action="store_true") 75 | 76 | orig_config.add_argument("--sharpen_class_weight", type=float, default=1.0) # lambda_4 77 | 78 | orig_config.add_argument("--L2", default=False, action="store_true") 79 | 80 | orig_config.add_argument("--L2_weight", type=float, default=1.0) # lambda_5 81 | 82 | orig_config.add_argument("--TV", default=False, action="store_true") 83 | 84 | orig_config.add_argument("--TV_weight", type=float, default=1.0) # lambda_6 85 | 86 | orig_config.add_argument("--long_window", default=False, action="store_true") 87 | 88 | orig_config.add_argument("--long_window_range", type=int, nargs="+", default=[1, 1]) 89 | 90 | orig_config.add_argument("--use_fixed_window", default=False, action="store_true") 91 | 92 | orig_config.add_argument("--fixed_window", type=int, default=1) # when to update old theta 93 | 94 | # Usually not 
used, included for testing: 95 | 96 | orig_config.add_argument("--refine_theta_steps", type=int, default=1) 97 | 98 | orig_config.add_argument("--refine_sample_from_scratch", default=False, action="store_true") 99 | 100 | orig_config.add_argument("--hard_targets", default=False, action="store_true") 101 | 102 | orig_config.add_argument("--aux_x_random", default=False, action="store_true") 103 | 104 | orig_config.add_argument("--use_crossent_as_D", default=False, action="store_true") 105 | 106 | orig_config.add_argument("--opt_batch_stats", default=False, action="store_true") 107 | 108 | orig_config.add_argument("--opt_batch_stats_weight", type=float, default=1.0) 109 | 110 | # Admin 111 | 112 | orig_config.add_argument("--cuda", default=False, action="store_true") 113 | 114 | orig_config.add_argument("--eval_freq", type=int, required=True) 115 | 116 | orig_config.add_argument("--store_results_freq", type=int, required=True) 117 | 118 | orig_config.add_argument("--store_model_freq", type=int, required=True) 119 | 120 | orig_config.add_argument("--specific_torch_seed", default=False, action="store_true") 121 | 122 | orig_config.add_argument("--torch_seed", type=int) 123 | 124 | orig_config.add_argument("--render_aux_x", default=False, action="store_true") 125 | 126 | orig_config.add_argument("--render_aux_x_freq", type=int, default=10) 127 | 128 | orig_config.add_argument("--render_aux_x_num", type=int, default=3) 129 | 130 | orig_config.add_argument("--render_separate", default=False, action="store_true") 131 | 132 | orig_config.add_argument("--count_params_only", default=False, action="store_true") 133 | 134 | orig_config = orig_config.parse_args() 135 | 136 | 137 | def main(config): 138 | # ------------------------------------------------------------------------------------------------ 139 | # Setup 140 | # ------------------------------------------------------------------------------------------------ 141 | 142 | reproc_settings(config) 143 | 144 | config.out_dir 
= osp.join(config.out_root, str(config.model_ind)) 145 | 146 | tasks_model, trainloader, testloader, valloader = get_model_and_data(config) 147 | 148 | if config.count_params_only: 149 | sz = 0 150 | for n, p in tasks_model.named_parameters(): 151 | curr_sz = np.prod(tuple(p.data.shape)) 152 | print((n, p.data.shape, curr_sz)) 153 | sz += curr_sz 154 | print("Num params: %d" % sz) 155 | exit(0) 156 | 157 | if not osp.exists(config.out_dir): 158 | os.makedirs(config.out_dir) 159 | 160 | next_t = 0 161 | last_classes = None # for info only 162 | seen_classes = None 163 | 164 | if config.long_window: 165 | old_tasks_model = None 166 | config.next_update_old_model_t_history = [] 167 | config.next_update_old_model_t = 0 # Old model set in first timestep 168 | assert (len(config.long_window_range) == 2) 169 | 170 | optimizer = optim.SGD(tasks_model.parameters(), lr=config.lr, momentum=0, dampening=0, 171 | weight_decay=0, nesterov=False) 172 | 173 | refine_sample_metrics = ["divergence_loss", "notlocal_loss", "notlocal_new_loss", 174 | "diversity_loss", "loss_refine"] 175 | refine_theta_metrics = ["not_present_class", "loss_aux", "final_loss_aux"] 176 | 177 | if config.sharpen_class: refine_sample_metrics.append("sharpen_class_loss") 178 | if config.TV: refine_sample_metrics.append("TV_loss") 179 | if config.L2: refine_sample_metrics.append("L2_loss") 180 | if config.opt_batch_stats: refine_sample_metrics.append("opt_batch_stats_loss") 181 | if config.aux_distill: refine_theta_metrics.append("loss_aux_distill") 182 | 183 | t = next_t 184 | 185 | while t <= config.max_t: 186 | for xs, ys in trainloader: 187 | present_classes = ys.unique().to(get_device(config.cuda)) 188 | 189 | # -------------------------------------------------------------------------------------------- 190 | # Eval 191 | # -------------------------------------------------------------------------------------------- 192 | 193 | if (t - next_t) < 1000 or t % 100 == 0: 194 | print("m %d t: %d %s, targets 
%s" % ( 195 | config.model_ind, t, datetime.now(), str(list(present_classes.cpu().numpy())))) 196 | sys.stdout.flush() 197 | 198 | save_dict = {"tasks_model": tasks_model, "t": t, "last_classes": last_classes, 199 | "seen_classes": seen_classes} 200 | if config.long_window: # else no need to save because it's always the one just before 201 | # current update 202 | save_dict["old_tasks_model"] = old_tasks_model 203 | 204 | last_step = t == (config.max_t) 205 | if (t % config.eval_freq == 0) or (t % config.batches_per_epoch == 0) or last_step or ( 206 | t == 0): 207 | evaluate_basic(config, tasks_model, valloader, t, is_val=True, 208 | last_classes=last_classes, seen_classes=seen_classes) 209 | evaluate_basic(config, tasks_model, testloader, t, is_val=False, 210 | last_classes=last_classes, seen_classes=seen_classes) 211 | 212 | if (t % config.store_model_freq == 0) or last_step: 213 | torch.save(save_dict, osp.join(config.out_dir, "latest_models.pytorch")) 214 | 215 | if (t % config.store_results_freq == 0) or last_step: 216 | render_graphs(config) 217 | store(config) 218 | 219 | if last_step: 220 | return 221 | 222 | # -------------------------------------------------------------------------------------------- 223 | # Train 224 | # -------------------------------------------------------------------------------------------- 225 | 226 | tasks_model.train() 227 | 228 | xs = xs.to(get_device(config.cuda)) 229 | ys = ys.to(get_device(config.cuda)) 230 | 231 | curr_classes = ys.unique() 232 | 233 | optimizer.zero_grad() 234 | 235 | # Update old_tasks_model if needed 236 | if not config.long_window or (config.long_window and t == config.next_update_old_model_t): 237 | old_tasks_model = deepcopy(tasks_model) 238 | if config.opt_batch_stats: old_tasks_model.compute_batch_stats_loss() 239 | 240 | if config.long_window: 241 | config.next_update_old_model_t_history.append(config.next_update_old_model_t) 242 | if config.use_fixed_window: 243 | window_offset = 
config.fixed_window 244 | else: 245 | window_offset = np.random.randint(config.long_window_range[0], high=( 246 | config.long_window_range[1] + 1)) # randint is excl 247 | 248 | config.next_update_old_model_t = t + window_offset 249 | 250 | # -------------------------------------------------------------------------------------------- 251 | # Train on real data 252 | # -------------------------------------------------------------------------------------------- 253 | 254 | preds = tasks_model(xs) 255 | loss_orig = F.cross_entropy(preds, ys, reduction="mean") 256 | record_and_check(config, "loss_orig", loss_orig.item(), t) 257 | loss_orig.backward() 258 | optimizer.step() 259 | 260 | # -------------------------------------------------------------------------------------------- 261 | # Generate recall 262 | # -------------------------------------------------------------------------------------------- 263 | 264 | if t >= config.recall_from_t: 265 | optimizer.zero_grad() 266 | 267 | metrics = dict([(metric, 0.) 
for metric in refine_sample_metrics + refine_theta_metrics]) 268 | 269 | for r in range(config.refine_theta_steps): # each step updates tasks_model 270 | if config.refine_sample_from_scratch or r == 0: # fresh aux_x 271 | if not config.aux_x_random: 272 | chosen_aux_inds = np.random.choice(xs.shape[0], config.M, 273 | replace=(config.M > xs.shape[0])) 274 | aux_x = xs[chosen_aux_inds] 275 | else: 276 | aux_x = torch.rand((config.M,) + xs.shape[1:]).to(get_device(config.cuda)) 277 | 278 | aux_x_orig = aux_x.clone() 279 | aux_x.requires_grad_(True) 280 | 281 | for s in range(config.refine_sample_steps): 282 | aux_preds_old = old_tasks_model(aux_x) 283 | aux_preds_new = tasks_model(aux_x) 284 | 285 | if config.opt_batch_stats: 286 | aux_preds_old, opt_batch_stats_loss = aux_preds_old 287 | 288 | if config.use_crossent_as_D: 289 | divergence_loss = - crossent_logits(preds=aux_preds_new, targets=aux_preds_old) 290 | else: 291 | divergence_loss = neg_symmetric_KL(aux_preds_old, aux_preds_new) 292 | 293 | notlocal_loss = notlocal(aux_preds_old, present_classes) 294 | notlocal_new_loss = notlocal(aux_preds_new, present_classes) 295 | 296 | diversity_loss = diversity(aux_preds_old) 297 | 298 | loss_refine = config.divergence_loss_weight * divergence_loss + \ 299 | config.diversity_weight * diversity_loss + \ 300 | config.notlocal_weight * notlocal_loss + config.notlocal_new_weight * \ 301 | notlocal_new_loss 302 | 303 | if config.sharpen_class: 304 | sharpen_class_loss = sharpen_class(aux_preds_old) 305 | loss_refine += config.sharpen_class_weight * sharpen_class_loss 306 | 307 | if config.TV: 308 | TV_loss = TV(aux_x) 309 | loss_refine += config.TV_weight * TV_loss 310 | 311 | if config.L2: 312 | L2_loss = L2(aux_x) 313 | loss_refine += config.L2_weight * L2_loss 314 | 315 | if config.opt_batch_stats: 316 | loss_refine += config.opt_batch_stats_weight * opt_batch_stats_loss 317 | 318 | for metric in refine_sample_metrics: 319 | metrics[metric] += locals()[metric].item() 
320 | 321 | aux_x_grads = \ 322 | torch.autograd.grad(loss_refine, aux_x, only_inputs=True, retain_graph=False)[0] 323 | aux_x = (aux_x - config.refine_sample_lr * aux_x_grads).detach().requires_grad_(True) 324 | 325 | # Get final predictions on recalled data from old model 326 | aux_x.requires_grad_(False) 327 | with torch.no_grad(): 328 | aux_y = old_tasks_model(aux_x) 329 | distill_xs_targets = old_tasks_model(xs) 330 | 331 | if config.opt_batch_stats: 332 | aux_y, _ = aux_y 333 | distill_xs_targets, _ = distill_xs_targets 334 | 335 | if config.render_aux_x and t % config.render_aux_x_freq == 0: 336 | render_aux_x(config, t, r, aux_x_orig, aux_x, aux_y, present_classes) 337 | 338 | # count number whose top class isn't in current classes 339 | aux_y_hard = aux_y.argmax(dim=1) 340 | 341 | if not hasattr(config, "aux_y_hard"): 342 | config.aux_y_hard = OrderedDict() 343 | if t not in config.aux_y_hard: 344 | config.aux_y_hard[t] = aux_y_hard.cpu() 345 | else: 346 | config.aux_y_hard[t] = torch.cat((config.aux_y_hard[t], aux_y_hard.cpu())).unique() 347 | 348 | if not hasattr(config, "aux_y_probs"): 349 | config.aux_y_probs = OrderedDict() 350 | aux_y_probs = F.softmax(aux_y, dim=1).mean(dim=0) 351 | if t not in config.aux_y_probs: 352 | config.aux_y_probs[t] = aux_y_probs.cpu() * (1. / config.refine_theta_steps) # n c 353 | else: 354 | config.aux_y_probs[t] += (aux_y_probs.cpu() * (1. 
/ config.refine_theta_steps)) 355 | 356 | is_present_class = 0 357 | for c in present_classes: 358 | is_present_class += (aux_y_hard == c).sum().item() 359 | metrics["not_present_class"] += aux_y_hard.shape[0] - is_present_class 360 | 361 | # ---------------------------------------------------------------------------------------- 362 | # Train on recalled data 363 | # ---------------------------------------------------------------------------------------- 364 | 365 | preds = tasks_model(aux_x) 366 | if not config.hard_targets: 367 | loss_aux = crossent_logits(preds=preds, targets=aux_y) 368 | else: 369 | loss_aux = F.cross_entropy(preds, aux_y_hard, reduction="mean") 370 | 371 | final_loss_aux = loss_aux * config.aux_weight 372 | 373 | if config.aux_distill: 374 | loss_aux_distill = crossent_logits(tasks_model(xs), distill_xs_targets) 375 | #loss_aux_distill = F.kl_div(F.log_softmax(tasks_model(xs), dim=1), 376 | # F.softmax(distill_xs_targets, dim=1), reduction="batchmean") 377 | final_loss_aux += config.aux_distill_weight * loss_aux_distill 378 | metrics["loss_aux_distill"] += loss_aux_distill.item() 379 | 380 | metrics["loss_aux"] += loss_aux.item() 381 | metrics["final_loss_aux"] += final_loss_aux.item() 382 | 383 | final_loss_aux.backward() 384 | optimizer.step() 385 | 386 | for metric in refine_sample_metrics: 387 | metrics[metric] /= float(config.refine_sample_steps * config.refine_theta_steps) 388 | record_and_check(config, metric, metrics[metric], t) 389 | 390 | for metric in refine_theta_metrics: 391 | metrics[metric] /= float(config.refine_theta_steps) 392 | record_and_check(config, metric, metrics[metric], t) 393 | 394 | if t < config.recall_from_t and config.render_aux_x and t % config.render_aux_x_freq == 0: 395 | render_aux_x(config, t, 0, xs, None, None, present_classes) 396 | 397 | t += 1 398 | if seen_classes is None: 399 | seen_classes = present_classes 400 | else: 401 | seen_classes = torch.cat((seen_classes, present_classes)).unique() 402 | 
last_classes = present_classes 403 | 404 | 405 | if __name__ == "__main__": 406 | ms = range(orig_config.model_ind_start, orig_config.model_ind_start + orig_config.num_runs) 407 | for m in ms: 408 | print("Doing m %d" % m) 409 | c = deepcopy(orig_config) 410 | c.model_ind = m 411 | main(c) 412 | print("Done m %d" % m) 413 | -------------------------------------------------------------------------------- /code/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xu-ji/ARM/704f69d73765f13ca5e2d8aee11f399b53c635e6/code/scripts/__init__.py -------------------------------------------------------------------------------- /code/scripts/distill.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from copy import deepcopy 4 | from datetime import datetime 5 | 6 | import torch.optim as optim 7 | 8 | from code.util.eval import evaluate_basic 9 | from code.util.general import * 10 | from code.util.load import * 11 | from code.util.losses import * 12 | 13 | # -------------------------------------------------------------------------------------------------- 14 | # Settings 15 | # -------------------------------------------------------------------------------------------------- 16 | 17 | orig_config = argparse.ArgumentParser(allow_abbrev=False) 18 | 19 | orig_config.add_argument("--model_ind_start", type=int, required=True) 20 | 21 | orig_config.add_argument("--num_runs", type=int, required=True) 22 | 23 | orig_config.add_argument("--out_root", type=str, required=True) 24 | 25 | # Data and model 26 | 27 | orig_config.add_argument("--data", type=str, default="cifar100") 28 | 29 | orig_config.add_argument("--data_path", type=str, required=True) 30 | 31 | orig_config.add_argument("--stationary", default=False, action="store_true") 32 | 33 | orig_config.add_argument("--classes_per_task", type=int, required=True) 34 | 35 | 
orig_config.add_argument("--max_t", type=int, required=True)

orig_config.add_argument("--tasks_train_batch_sz", type=int, default=10)

orig_config.add_argument("--tasks_eval_batch_sz", type=int, default=10)

orig_config.add_argument("--num_iterations", type=int, default=1)

orig_config.add_argument("--task_model_type", type=str, required=True)

# Distill options

orig_config.add_argument("--lr", type=float, default=0.1)

orig_config.add_argument("--recall_from_t", type=int, required=True)

orig_config.add_argument("--long_window", default=False, action="store_true")

orig_config.add_argument("--long_window_range", type=int, nargs="+", default=[1, 1])  # inclusive

orig_config.add_argument("--use_fixed_window", default=False, action="store_true")

orig_config.add_argument("--fixed_window", type=int, default=1)

orig_config.add_argument("--aux_distill_weight", type=float, default=1.0)

# Admin

orig_config.add_argument("--cuda", default=False, action="store_true")

orig_config.add_argument("--eval_freq", type=int, required=True)

orig_config.add_argument("--store_results_freq", type=int, required=True)

orig_config.add_argument("--store_model_freq", type=int, required=True)

orig_config.add_argument("--specific_torch_seed", default=False, action="store_true")

orig_config.add_argument("--torch_seed", type=int)

orig_config = orig_config.parse_args()


def main(config):
    """Run one training run of the plain-distillation baseline.

    Each step t consumes one batch (xs, ys) from the training loader and:
      1. prints progress and periodically evaluates on the val/test loaders, checkpoints the
         models and stores/renders results (always on the final step, after which it returns);
      2. snapshots the model into ``old_tasks_model`` -- every step, or, with
         ``--long_window``, only when t reaches ``config.next_update_old_model_t``;
      3. takes one SGD step on the cross-entropy of the real batch;
      4. from ``t >= config.recall_from_t`` on, takes a second SGD step on
         ``aux_distill_weight * KL(softmax(old_tasks_model(xs)) || softmax(tasks_model(xs)))``.

    Args:
        config: parsed argparse namespace that additionally carries ``model_ind``.
            ``config.out_dir`` and the ``next_update_old_model_t*`` bookkeeping are written
            onto it here.
    """
    # ------------------------------------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------------------------------------

    reproc_settings(config)

    config.out_dir = osp.join(config.out_root, str(config.model_ind))

    tasks_model, trainloader, testloader, valloader = get_model_and_data(config)

    if not osp.exists(config.out_dir):
        os.makedirs(config.out_dir)

    next_t = 0
    last_classes = None
    seen_classes = None

    if config.long_window:
        old_tasks_model = None
        config.next_update_old_model_t_history = []
        config.next_update_old_model_t = 0  # old model is first snapshotted at t == 0
        assert len(config.long_window_range) == 2, \
            "--long_window_range expects exactly two values (min max), got %r" % (
                config.long_window_range,)

    optimizer = optim.SGD(tasks_model.parameters(), lr=config.lr, momentum=0, dampening=0,
                          weight_decay=0, nesterov=False)

    refine_theta_metrics = ["final_loss_aux", "loss_aux_distill"]

    t = next_t

    # The training loader is re-iterated until t reaches max_t; the final step only evaluates
    # and stores before returning.
    while t <= config.max_t:
        for xs, ys in trainloader:
            present_classes = ys.unique().to(get_device(config.cuda))

            # --------------------------------------------------------------------------------------------
            # Eval
            # --------------------------------------------------------------------------------------------

            if (t - next_t) < 1000 or t % 100 == 0:
                print("m %d t: %d %s, targets %s" % (
                    config.model_ind, t, datetime.now(), str(list(present_classes.cpu().numpy()))))
                sys.stdout.flush()

            save_dict = {"tasks_model": tasks_model, "t": t, "last_classes": last_classes,
                         "seen_classes": seen_classes}
            if config.long_window:  # else no need to save because it's always the one just before
                # current update
                save_dict["old_tasks_model"] = old_tasks_model

            last_step = t == config.max_t
            if (t % config.eval_freq == 0) or (t % config.batches_per_epoch == 0) or last_step or (
                    t == 0):
                evaluate_basic(config, tasks_model, valloader, t, is_val=True,
                               last_classes=last_classes, seen_classes=seen_classes)
                evaluate_basic(config, tasks_model, testloader, t, is_val=False,
                               last_classes=last_classes, seen_classes=seen_classes)

            if (t % config.store_model_freq == 0) or last_step:
                torch.save(save_dict, osp.join(config.out_dir, "latest_models.pytorch"))

            if (t % config.store_results_freq == 0) or last_step:
                render_graphs(config)
                store(config)

            if last_step:
                return

            # --------------------------------------------------------------------------------------------
            # Train
            # --------------------------------------------------------------------------------------------

            tasks_model.train()

            xs = xs.to(get_device(config.cuda))
            ys = ys.to(get_device(config.cuda))

            curr_classes = ys.unique()
            assert (curr_classes.max() < config.task_out_dims[0])

            optimizer.zero_grad()

            # Snapshot the pre-update model: every step, or only at scheduled steps (long window)
            if not config.long_window or (config.long_window and t == config.next_update_old_model_t):
                old_tasks_model = deepcopy(tasks_model)

                if config.long_window:
                    config.next_update_old_model_t_history.append(config.next_update_old_model_t)
                    if config.use_fixed_window:
                        window_offset = config.fixed_window
                    else:
                        # np.random.randint's upper bound is exclusive, so +1 keeps the range
                        # inclusive as documented on --long_window_range
                        window_offset = np.random.randint(config.long_window_range[0],
                                                          high=(config.long_window_range[1] + 1))

                    config.next_update_old_model_t = t + window_offset

            # --------------------------------------------------------------------------------------------
            # Train on real data
            # --------------------------------------------------------------------------------------------

            preds = tasks_model(xs)
            loss_orig = F.cross_entropy(preds, ys, reduction="mean")
            record_and_check(config, "loss_orig", loss_orig.item(), t)
            loss_orig.backward()
            optimizer.step()  # updates tasks_model, which is now \theta'

            # --------------------------------------------------------------------------------------------
            # Distill
            # --------------------------------------------------------------------------------------------
            if t >= config.recall_from_t:
                optimizer.zero_grad()
                metrics = {metric: 0. for metric in refine_theta_metrics}

                with torch.no_grad():
                    distill_xs_targets = old_tasks_model(xs)

                # F.kl_div(log q, p) computes KL(p || q): pulls the updated model's predictive
                # distribution on this batch towards the snapshot's.
                loss_aux_distill = F.kl_div(F.log_softmax(tasks_model(xs), dim=1),
                                            F.softmax(distill_xs_targets, dim=1),
                                            reduction="batchmean")
                final_loss_aux = config.aux_distill_weight * loss_aux_distill

                metrics["loss_aux_distill"] += loss_aux_distill.item()
                metrics["final_loss_aux"] += final_loss_aux.item()

                final_loss_aux.backward()
                optimizer.step()  # updates tasks_model

                for metric in refine_theta_metrics:
                    record_and_check(config, metric, metrics[metric], t)

            t += 1
            if seen_classes is None:
                seen_classes = present_classes
            else:
                seen_classes = torch.cat((seen_classes, present_classes)).unique()
            last_classes = present_classes


if __name__ == "__main__":
    ms = range(orig_config.model_ind_start, orig_config.model_ind_start + orig_config.num_runs)
    for m in ms:
        print("Doing m %d" % m)
        c = deepcopy(orig_config)
        c.model_ind = m
        main(c)
        print("Done m %d" % m)

# --------------------------------------------------------------------------------
# /code/scripts/naive.py:
# --------------------------------------------------------------------------------
import argparse
import os
import sys
from copy import deepcopy  # used by __main__; previously only available via a star import
from datetime import datetime

import torch.optim as optim

from code.models import *
from code.util.eval import evaluate_basic
from code.util.general import *
from code.util.load import *

# --------------------------------------------------------------------------------------------------
# Settings
# --------------------------------------------------------------------------------------------------

orig_config = argparse.ArgumentParser(allow_abbrev=False)

orig_config.add_argument("--model_ind_start", type=int, required=True)

orig_config.add_argument("--num_runs", type=int, required=True)

orig_config.add_argument("--out_root", type=str, required=True)

# Data and model

orig_config.add_argument("--data", type=str, default="cifar100")

orig_config.add_argument("--data_path", type=str, required=True)

orig_config.add_argument("--stationary", default=False, action="store_true")

orig_config.add_argument("--classes_per_task", type=int, required=True)

orig_config.add_argument("--max_t", type=int, required=True)

orig_config.add_argument("--tasks_train_batch_sz", type=int, default=10)

orig_config.add_argument("--tasks_eval_batch_sz", type=int, default=10)

orig_config.add_argument("--num_iterations", type=int, default=1)

orig_config.add_argument("--task_model_type", type=str, required=True)

orig_config.add_argument("--lr", type=float, default=0.1)

# Admin

orig_config.add_argument("--cuda", default=False, action="store_true")

orig_config.add_argument("--eval_freq", type=int, required=True)

orig_config.add_argument("--store_results_freq", type=int, required=True)

orig_config.add_argument("--store_model_freq", type=int, required=True)

orig_config.add_argument("--specific_torch_seed", default=False, action="store_true")

orig_config.add_argument("--torch_seed", type=int)

orig_config = orig_config.parse_args()


def main(config):
    """Run one training run of the naive baseline (plain SGD, no replay or distillation).

    Each step t consumes one batch from the training loader, prints progress, periodically
    evaluates on the val/test loaders and checkpoints/stores results (always on the final step,
    after which it returns), then takes a single SGD step on the batch's cross-entropy.

    Args:
        config: parsed argparse namespace that additionally carries ``model_ind``.
            ``config.out_dir`` is set here.
    """
    # ------------------------------------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------------------------------------

    reproc_settings(config)

    config.out_dir = osp.join(config.out_root, str(config.model_ind))

    tasks_model, trainloader, testloader, valloader = get_model_and_data(config)

    if not osp.exists(config.out_dir):
        os.makedirs(config.out_dir)

    next_t = 0
    last_classes = None
    seen_classes = None

    optimizer = optim.SGD(tasks_model.parameters(), lr=config.lr, momentum=0, dampening=0,
                          weight_decay=0, nesterov=False)  # basic

    t = next_t

    while t <= config.max_t:
        for xs, ys in trainloader:
            present_classes = ys.unique().to(get_device(config.cuda))

            # --------------------------------------------------------------------------------------------
            # Eval
            # --------------------------------------------------------------------------------------------

            if (t - next_t) < 1000 or t % 100 == 0:
                print("m %d t %d %s, fst targets %s" % (
                    config.model_ind, t, datetime.now(), str(list(present_classes.cpu().numpy()))))
                sys.stdout.flush()

            save_dict = {"tasks_model": tasks_model, "t": t, "last_classes": last_classes,
                         "seen_classes": seen_classes}

            last_step = t == config.max_t
            if (t % config.eval_freq == 0) or (t % config.batches_per_epoch == 0) or last_step or (
                    t == 0):
                evaluate_basic(config, tasks_model, valloader, t, is_val=True,
                               last_classes=last_classes, seen_classes=seen_classes)
                evaluate_basic(config, tasks_model, testloader, t, is_val=False,
                               last_classes=last_classes, seen_classes=seen_classes)

            if (t % config.store_model_freq == 0) or last_step:
                torch.save(save_dict, osp.join(config.out_dir, "latest_models.pytorch"))

            if (t % config.store_results_freq == 0) or last_step:
                render_graphs(config)
                store(config)

            if last_step:
                return

            # --------------------------------------------------------------------------------------------
            # Train
            # --------------------------------------------------------------------------------------------

            tasks_model.train()

            optimizer.zero_grad()

            xs = xs.to(get_device(config.cuda))
            ys = ys.to(get_device(config.cuda))

            preds = tasks_model(xs)
            loss_orig = torch.nn.functional.cross_entropy(preds, ys, reduction="mean")

            loss_orig.backward()

            optimizer.step()

            record_and_check(config, "loss_orig", loss_orig.item(), t)

            t += 1
            if seen_classes is None:
                seen_classes = present_classes
            else:
                seen_classes = torch.cat((seen_classes, present_classes)).unique()
            last_classes = present_classes


if __name__ == "__main__":
    ms = range(orig_config.model_ind_start, orig_config.model_ind_start + orig_config.num_runs)
    for m in ms:
        print("Doing m %d" % m)
        c = deepcopy(orig_config)
        c.model_ind = m
        main(c)
        print("Done m %d" % m)

# --------------------------------------------------------------------------------
# /code/scripts/print_results.py:
# --------------------------------------------------------------------------------
import argparse
import os
import pickle
import torch
import numpy as np

args = argparse.ArgumentParser(allow_abbrev=False)
args.add_argument("--root", type=str, required=True)
args.add_argument("--start", type=int, required=True)
args.add_argument("--load_model", default=False, action="store_true")
args.add_argument("--num_runs", type=int, default=5)
args = args.parse_args()


def treat_underscores(x):
    """Return the characters of ``x`` joined into a string, with each ``_`` escaped as ``\\_``."""
    return "".join("\\_" if c == "_" else c for c in x)


def print_results(args):
    """Print final-step val/test accuracy and forgetting per run, then mean/std across runs.

    For each run index m in ``[args.start, args.start + args.num_runs)`` this loads
    ``<args.root>/<m>/config.pickle`` and reads the metrics recorded at ``config.max_t``.
    Runs whose config never loads are skipped. A metric with no values across runs is
    reported as (-1, -1).
    """
    ms_avg = {"val": {"acc": [], "forgetting": []},
              "test": {"acc": [], "forgetting": []}}

    for m in range(args.start, args.start + args.num_runs):
        out_dir = os.path.join(args.root, str(m))
        config_p = os.path.join(out_dir, "config.pickle")

        # Bounded retry -- presumably because the pickle can be read while a job is still
        # writing it. Give up on this run if it never loads.
        config = None
        tries = 0
        while tries < 1000:
            try:
                with open(config_p, "rb") as config_f:
                    config = pickle.load(config_f)
                break
            except Exception:  # not bare: let KeyboardInterrupt / SystemExit propagate
                tries += 1

        if config is None:
            continue

        if args.load_model:
            # Result is discarded; this appears to only check that the checkpoint loads.
            torch.load(os.path.join(config.out_dir, "latest_models.pytorch"))

        actual_t = config.max_t
        for prefix in ["val", "test"]:
            if not config.stationary:
                accs_dict = getattr(config, "%s_accs" % prefix)

                ms_avg[prefix]["acc"].append(accs_dict[actual_t])

                forgetting_dict = getattr(config, "%s_forgetting" % prefix)
                if actual_t in forgetting_dict:
                    forgetting = forgetting_dict[actual_t]
                    ms_avg[prefix]["forgetting"].append(forgetting)
                else:
                    # No forgetting recorded at actual_t: report nan instead of a KeyError.
                    forgetting = float("nan")

                print("model %d, %s: acc %.4f, forgetting %.4f" % (
                    config.model_ind, prefix, accs_dict[actual_t], forgetting))
            else:
                accs_dict = getattr(config, "%s_accs_data" % prefix)
                ms_avg[prefix]["acc"].append(accs_dict[actual_t])
                print("model %d, %s: acc %.4f" % (config.model_ind, prefix, accs_dict[actual_t]))

        print("---")

    for prefix in ["val", "test"]:
        for metric in ["acc", "forgetting"]:
            values = ms_avg[prefix][metric]
            if not values:
                ms_avg[prefix][metric] = (-1, -1)
            else:
                arr = np.array(values)
                ms_avg[prefix][metric] = (arr.mean(), arr.std())

        print("average %s: acc %.4f +- %.4f, forgetting %.4f +- %.4f" % (
            prefix, ms_avg[prefix]["acc"][0], ms_avg[prefix]["acc"][1],
            ms_avg[prefix]["forgetting"][0], ms_avg[prefix]["forgetting"][1]))
82 | 83 | if __name__ == "__main__": 84 | print_results(args) 85 | -------------------------------------------------------------------------------- /code/scripts/print_table.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | import pickle 5 | import numpy as np 6 | 7 | args = argparse.ArgumentParser(allow_abbrev=False) 8 | args.add_argument("--root", type=str, required=True) 9 | args.add_argument("--num_runs", type=int, default=5) 10 | args = args.parse_args() 11 | 12 | experiments = [ 13 | ("ARM MNIST", 6579), 14 | ("ARM Cifar10", 3717), 15 | ("ARM MiniImageNet", 4821), 16 | 17 | ("ADI MNIST", 6232), 18 | ("ADI Cifar10", 2262), 19 | ("ADI MiniImageNet", 6042), 20 | 21 | ("Distill MNIST", 6102), 22 | ("Distill Cifar10", 4082), 23 | ("Distill MiniImageNet", 6092), 24 | 25 | ("Naive MNIST", 4967), 26 | ("Naive Cifar10", 2522), 27 | ("Naive MiniImageNet", 4557), 28 | 29 | ("Naive stationary MNIST", 4947), 30 | ("Naive stationary Cifar10", 6462), 31 | ("Naive stationary MiniImageNet", 4527), 32 | 33 | # table 4 34 | ("Distill Cifar10, unit lag", 6452), 35 | ("ADI Cifar10, unit lag", 6467), 36 | ("ADI Cifar10, no distill", 2327), 37 | ("ARM Cifar10, unit lag", 6602), 38 | ("ARM Cifar10, no distill", 2717), 39 | 40 | # table 7 41 | ("$\lambda_1 = 0, \lambda_2 = 0$", 3982), 42 | ("$\lambda_3 = 0$", 6562), 43 | ("$\lambda_4 = 0$", 3977), 44 | ("$\lambda_5 = 0$", 3967), 45 | ("$\lambda_6 = 0$", 3972), 46 | ("$M = 150$ (+50)", 6624), 47 | ("$M = 50$ (-50)", 5922), 48 | ("$S = 20$ (doubled)", 3957), 49 | ("$S = 5$ (halved)", 3962), 50 | ("Cross-entropy as D", 3617), 51 | ("Random noise init", 4077), 52 | ("Recall 2x per t", 6502), 53 | ("Recall 4x per t", 6507), 54 | ] 55 | 56 | num_runs = 5 57 | 58 | print("LaTeX table:") 59 | print("\\begin{table}[h]") 60 | print("\\centering") 61 | print("\\fontsize{7}{7}\\selectfont") 62 | print("\\begin{tabular}{l c c c c}") 63 | print("\\toprule") 
64 | print("& \\multicolumn{2}{c}{Val} & \\multicolumn{2}{c}{Test} \\\\") 65 | print("& Accuracy & Forgetting & Accuracy & Forgetting \\\\") 66 | print("\\midrule") 67 | for name, m_start in experiments: 68 | ms_avg = {"val": {"acc": [], "forgetting": []}, 69 | "test": {"acc": [], "forgetting": []}} 70 | 71 | counts = 0 72 | for m in range(m_start, m_start + args.num_runs): 73 | out_dir = os.path.join(args.root, str(m)) 74 | config_p = os.path.join(out_dir, "config.pickle") 75 | 76 | config = None 77 | tries = 0 78 | while tries < 1000: 79 | try: 80 | with open(config_p, "rb") as config_f: 81 | config = pickle.load(config_f) 82 | break 83 | except: 84 | tries += 1 85 | 86 | if config is None: 87 | continue 88 | 89 | actual_t = config.max_t 90 | 91 | if not actual_t in config.test_accs: 92 | continue 93 | 94 | for prefix in ["val", "test"]: 95 | if not config.stationary: 96 | accs_dict = getattr(config, "%s_accs" % prefix) 97 | 98 | ms_avg[prefix]["acc"].append(accs_dict[actual_t]) 99 | 100 | forgetting_dict = getattr(config, "%s_forgetting" % prefix) 101 | if actual_t in forgetting_dict: 102 | ms_avg[prefix]["forgetting"].append(forgetting_dict[actual_t]) 103 | else: 104 | accs_dict = getattr(config, "%s_accs_data" % prefix) 105 | ms_avg[prefix]["acc"].append(accs_dict[actual_t]) 106 | 107 | counts += 1 108 | 109 | for prefix in ["val", "test"]: 110 | for metric in ["acc", "forgetting"]: 111 | if len(ms_avg[prefix][metric]) == 0: 112 | ms_avg[prefix][metric] = (-1, -1) 113 | else: 114 | avg = np.array(ms_avg[prefix][metric]).mean() 115 | std = np.array(ms_avg[prefix][metric]).std() 116 | ms_avg[prefix][metric] = (avg, std) 117 | 118 | print("%s (%d) & %.4f $\pm$ %.4f & %.4f $\pm$ %.4f & %.4f $\pm$ %.4f & %.4f $\pm$ %.4f \\\\" % 119 | (name, counts, 120 | 121 | ms_avg["val"]["acc"][0], ms_avg["val"]["acc"][1], 122 | ms_avg["val"]["forgetting"][0], ms_avg["val"]["forgetting"][1], 123 | 124 | ms_avg["test"]["acc"][0], ms_avg["test"]["acc"][1], 125 | 
ms_avg["test"]["forgetting"][0], ms_avg["test"]["forgetting"][1], 126 | )) 127 | 128 | print("\\bottomrule") 129 | print("\\end{tabular}") 130 | print("\\end{table}") 131 | -------------------------------------------------------------------------------- /code/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xu-ji/ARM/704f69d73765f13ca5e2d8aee11f399b53c635e6/code/util/__init__.py -------------------------------------------------------------------------------- /code/util/data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, SequentialSampler 2 | from torchvision import transforms 3 | 4 | from code.data import * 5 | 6 | 7 | def get_data(config): 8 | dataloaders, datasets = globals()["get_%s_loaders" % config.data](config) 9 | num_train_samples = len(datasets[0]) 10 | assert (num_train_samples % config.tasks_train_batch_sz) == 0 11 | config.batches_per_epoch = int(num_train_samples / config.tasks_train_batch_sz) 12 | 13 | print("length of training d, test d, val d: %d %d %d, batches_per_epoch %d" % 14 | (num_train_samples, len(datasets[1]), len(datasets[2]), config.batches_per_epoch)) 15 | assert (config.store_model_freq == config.batches_per_epoch) # once per epoch 16 | 17 | return dataloaders 18 | 19 | 20 | def get_cifar10_loaders(config): 21 | assert (config.data == "cifar10") 22 | #config.task_in_dims = (3, 32, 32) 23 | #config.task_out_dims = (10,) 24 | two_classes_per_block = (config.classes_per_task == 2) 25 | 26 | train_fns = [transforms.ToTensor()] 27 | transform_train = transforms.Compose(train_fns) 28 | 29 | cifar10_training = cifar10(root=config.data_path, train=True, transform=transform_train, 30 | non_stat=(not config.stationary), 31 | two_classes_per_block=two_classes_per_block, 32 | num_iterations=config.num_iterations) 33 | 34 | if not config.stationary: 35 | cifar10_training_loader = 
DataLoader(cifar10_training, 36 | sampler=SequentialSampler(cifar10_training), shuffle=False, 37 | batch_size=config.tasks_train_batch_sz) 38 | else: 39 | cifar10_training_loader = DataLoader(cifar10_training, shuffle=True, 40 | batch_size=config.tasks_train_batch_sz) 41 | 42 | test_fns = [transforms.ToTensor()] 43 | transform_test = transforms.Compose(test_fns) 44 | 45 | evals_non_stat = False # does not make a difference, test time behavior independent of rest of 46 | # batch 47 | 48 | cifar10_val = cifar10val(root=config.data_path, transform=transform_test, 49 | non_stat=evals_non_stat, two_classes_per_block=two_classes_per_block) 50 | cifar10_val_loader = DataLoader(cifar10_val, shuffle=False, batch_size=config.tasks_eval_batch_sz) 51 | 52 | cifar10_test = cifar10(root=config.data_path, train=False, transform=transform_test, 53 | num_iterations=1, 54 | non_stat=evals_non_stat, two_classes_per_block=two_classes_per_block) 55 | cifar10_test_loader = DataLoader(cifar10_test, shuffle=False, 56 | batch_size=config.tasks_eval_batch_sz) 57 | 58 | return (cifar10_training_loader, cifar10_test_loader, cifar10_val_loader), \ 59 | (cifar10_training, cifar10_test, cifar10_val) 60 | 61 | 62 | def get_miniimagenet_loaders(config): 63 | assert (config.data == "miniimagenet") 64 | #config.task_in_dims = (3, 84, 84) 65 | #config.task_out_dims = (100,) 66 | 67 | train_fns = [ 68 | transforms.Resize(84), 69 | transforms.CenterCrop(84), 70 | transforms.ToTensor(), 71 | ] 72 | 73 | transform_train = transforms.Compose(train_fns) 74 | 75 | miniimagenet_training = miniimagenet(root=config.data_path, data_type="train", 76 | transform=transform_train, 77 | non_stat=(not config.stationary), 78 | classes_per_task=config.classes_per_task, 79 | num_iterations=config.num_iterations) 80 | 81 | if not config.stationary: 82 | miniimagenet_training_loader = DataLoader(miniimagenet_training, 83 | sampler=SequentialSampler(miniimagenet_training), 84 | shuffle=False, 
batch_size=config.tasks_train_batch_sz) 85 | else: 86 | miniimagenet_training_loader = DataLoader(miniimagenet_training, shuffle=True, 87 | batch_size=config.tasks_train_batch_sz) 88 | 89 | test_fns = [ 90 | transforms.Resize(84), 91 | transforms.CenterCrop(84), 92 | transforms.ToTensor(), 93 | ] 94 | 95 | transform_test = transforms.Compose(test_fns) 96 | 97 | evals_non_stat = False 98 | 99 | miniimagenet_val = miniimagenetval(root=config.data_path, transform=transform_test, 100 | non_stat=evals_non_stat, 101 | classes_per_task=config.classes_per_task) 102 | miniimagenet_val_loader = DataLoader(miniimagenet_val, shuffle=False, 103 | batch_size=config.tasks_eval_batch_sz) 104 | 105 | miniimagenet_test = miniimagenet(root=config.data_path, data_type="test", 106 | transform=transform_test, num_iterations=1, 107 | non_stat=evals_non_stat, 108 | classes_per_task=config.classes_per_task) 109 | miniimagenet_test_loader = DataLoader(miniimagenet_test, shuffle=False, 110 | batch_size=config.tasks_eval_batch_sz) 111 | 112 | return (miniimagenet_training_loader, miniimagenet_test_loader, miniimagenet_val_loader), \ 113 | (miniimagenet_training, miniimagenet_test, miniimagenet_val) 114 | 115 | 116 | def get_mnist5k_loaders(config): 117 | assert (config.data == "mnist5k") 118 | #config.task_in_dims = (28 * 28,) 119 | #config.task_out_dims = (10,) 120 | 121 | mnist5k_training = mnist5k(root=config.data_path, data_type="train", 122 | non_stat=(not config.stationary), num_iterations=config.num_iterations, 123 | classes_per_task=config.classes_per_task) 124 | 125 | if not config.stationary: 126 | mnist5k_training_loader = DataLoader(mnist5k_training, 127 | sampler=SequentialSampler(mnist5k_training), 128 | shuffle=False, batch_size=config.tasks_train_batch_sz) 129 | else: 130 | mnist5k_training_loader = DataLoader(mnist5k_training, shuffle=True, 131 | batch_size=config.tasks_train_batch_sz) 132 | 133 | evals_non_stat = False 134 | 135 | mnist5k_val = mnist5k(root=config.data_path, 
data_type="val", non_stat=evals_non_stat, 136 | num_iterations=1, classes_per_task=config.classes_per_task) 137 | mnist5k_val_loader = DataLoader(mnist5k_val, shuffle=False, batch_size=config.tasks_eval_batch_sz) 138 | 139 | mnist5k_test = mnist5k(root=config.data_path, data_type="test", non_stat=evals_non_stat, 140 | num_iterations=1, classes_per_task=config.classes_per_task) 141 | mnist5k_test_loader = DataLoader(mnist5k_test, shuffle=False, 142 | batch_size=config.tasks_eval_batch_sz) 143 | 144 | return (mnist5k_training_loader, mnist5k_test_loader, mnist5k_val_loader), \ 145 | (mnist5k_training, mnist5k_test, mnist5k_val) 146 | -------------------------------------------------------------------------------- /code/util/eval.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, defaultdict 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from code.util.general import get_device 7 | 8 | 9 | # accs_data: average over all data (old metric) 10 | # per_label_accs: acc per class 11 | # per_task_accs: acc per task 12 | # accs: average over all seen tasks (after last task, is same as Chaudry def.) 13 | # forgetting: average over all seen tasks 14 | 15 | def evaluate_basic(config, tasks_model, data_loader, t, is_val, last_classes=None, 16 | seen_classes=None, tag=""): 17 | if is_val: 18 | prefix = "val" 19 | else: 20 | prefix = "test" 21 | 22 | prefix = "%s%s" % (tag, prefix) 23 | 24 | tasks_model.eval() 25 | 26 | acc_data = 0. 
27 | counts = 0 28 | 29 | num_out = int(np.prod(config.task_out_dims)) 30 | per_label_acc = np.zeros(num_out) 31 | per_label_counts = np.zeros(num_out) 32 | 33 | for x, y in data_loader: 34 | x = x.to(get_device(config.cuda)) 35 | y = y.to(get_device(config.cuda)) 36 | 37 | with torch.no_grad(): 38 | preds = tasks_model(x) 39 | 40 | preds_flat = torch.argmax(preds, dim=1) 41 | 42 | acc_data += (preds_flat == y).sum().item() 43 | counts += y.shape[0] 44 | 45 | for c in range(num_out): 46 | pos = (y == c) 47 | per_label_acc[c] += (pos * (preds_flat == c)).sum().item() 48 | per_label_counts[c] += pos.sum().item() 49 | 50 | # acc over all data 51 | acc_data /= counts 52 | 53 | # acc per class 54 | per_label_counts = np.maximum(per_label_counts, 1) # avoid div 0 55 | per_label_acc /= per_label_counts 56 | 57 | # acc per seen task and avg 58 | acc = None 59 | if hasattr(config, "%s_accs_data" % prefix) and (not config.stationary): # start from after training starts 60 | last_classes = last_classes.cpu().numpy() 61 | seen_classes = seen_classes.cpu().numpy() 62 | 63 | per_task_acc = defaultdict(list) 64 | for c in seen_classes: # seen tasks only 65 | per_task_acc[config.class_dict_tasks[c]].append(per_label_acc[c]) 66 | 67 | acc = 0. 
68 | for task_i in per_task_acc: 69 | assert (len(per_task_acc[task_i]) == config.classes_per_task) 70 | per_task_acc[task_i] = np.array(per_task_acc[task_i]).mean() 71 | acc += per_task_acc[task_i] 72 | acc /= len(per_task_acc) 73 | 74 | if not hasattr(config, "%s_accs" % prefix): 75 | setattr(config, "%s_accs_data" % prefix, OrderedDict()) 76 | setattr(config, "%s_per_label_accs" % prefix, OrderedDict()) 77 | 78 | setattr(config, "%s_per_task_accs" % prefix, OrderedDict()) 79 | setattr(config, "%s_accs" % prefix, OrderedDict()) 80 | 81 | setattr(config, "%s_forgetting" % prefix, OrderedDict()) 82 | 83 | getattr(config, "%s_accs_data" % prefix)[t] = acc_data 84 | getattr(config, "%s_per_label_accs" % prefix)[t] = per_label_acc 85 | if acc is not None: 86 | getattr(config, "%s_per_task_accs" % prefix)[t] = per_task_acc 87 | getattr(config, "%s_accs" % prefix)[t] = acc 88 | 89 | # for all previous (excl latest) tasks, find the maximum drop to curr acc 90 | if not config.stationary: 91 | if len(getattr(config, "%s_accs_data" % prefix)) >= 3: # at least 1 previous (non pre training) eval 92 | assert (last_classes is not None) 93 | getattr(config, "%s_forgetting" % prefix)[t] = compute_forgetting(config, t, 94 | getattr(config, 95 | "%s_per_task_accs" % prefix), 96 | last_classes) 97 | 98 | 99 | def compute_forgetting(config, t, per_task_accs, last_classes): 100 | # per_task_acc is not equal length per timestep so can't array 101 | 102 | assert (t % config.eval_freq == 0) 103 | 104 | # find task that just finished 105 | last_task_i = None 106 | for c in last_classes: 107 | task_i = config.class_dict_tasks[c] 108 | if last_task_i is None: 109 | last_task_i = task_i 110 | else: 111 | assert (last_task_i == task_i) 112 | 113 | forgetting_per_task = {} 114 | for task_i in range(last_task_i): # excl last (tasks are numbered chronologically) 115 | best_acc = None 116 | for past_t in per_task_accs: 117 | if past_t == 0: continue # not used 118 | if past_t == t: continue 119 | 
120 | if task_i in per_task_accs[past_t]: 121 | if best_acc is None or per_task_accs[past_t][task_i] > best_acc: 122 | best_acc = per_task_accs[past_t][task_i] 123 | assert (best_acc is not None) 124 | 125 | forgetting_per_task[task_i] = best_acc - per_task_accs[t][task_i] 126 | 127 | assert (len(forgetting_per_task) == last_task_i) 128 | return np.array(list(forgetting_per_task.values())).mean() 129 | -------------------------------------------------------------------------------- /code/util/general.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pickle 3 | 4 | import matplotlib 5 | import numpy as np 6 | import torch 7 | 8 | matplotlib.use('Agg') 9 | import matplotlib.pyplot as plt 10 | from collections import OrderedDict 11 | from copy import deepcopy 12 | 13 | 14 | def get_device(cuda): 15 | if cuda: 16 | return torch.device("cuda:0") 17 | else: 18 | return torch.device("cpu") 19 | 20 | 21 | def store(config): 22 | # anonymization 23 | data_path = config.data_path 24 | out_root = config.out_root 25 | out_dir = config.out_dir 26 | 27 | config.data_path = "" 28 | config.out_root = "" 29 | config.out_dir = "" 30 | 31 | with open(osp.join(out_dir, "config.pickle"), 32 | 'wb') as outfile: 33 | pickle.dump(config, outfile) 34 | 35 | with open(osp.join(out_dir, "config.txt"), 36 | "w") as text_file: 37 | text_file.write("%s" % config) 38 | 39 | config.data_path = data_path 40 | config.out_root = out_root 41 | config.out_dir = out_dir 42 | 43 | 44 | def get_avg_grads(model): 45 | total = None 46 | count = 0 47 | for p in model.parameters(): 48 | sz = np.prod(p.grad.shape) 49 | grad_sum = p.grad.abs().sum() 50 | if total is None: 51 | total = grad_sum 52 | count = sz 53 | else: 54 | total += grad_sum 55 | count += sz 56 | 57 | return total / float(count) 58 | 59 | 60 | def record_and_check(config, name, val, t): 61 | if hasattr(val, "item"): 62 | val = val.item() 63 | 64 | record(config, name, 
val, t) 65 | if not np.isfinite(val): 66 | print("value (probably loss) not finite, aborting:") 67 | print(name) 68 | print("t %d" % t) 69 | print(val) 70 | store(config) # to store nan values 71 | exit(1) 72 | 73 | 74 | def check(config, val, t): 75 | if not np.isfinite(val): 76 | print("value (probably loss) not finite, aborting:") 77 | print("t %d" % t) 78 | print(val) 79 | store(config) # to store nan values 80 | exit(1) 81 | 82 | 83 | def record(config, val_name, val, t, abs=False): 84 | if not hasattr(config, val_name): 85 | setattr(config, val_name, OrderedDict()) 86 | 87 | storage = getattr(config, val_name) 88 | 89 | if "torch" in str(val.__class__): 90 | if abs: 91 | val = torch.abs(val) 92 | 93 | if val.dtype == torch.int64: 94 | assert (val.shape == torch.Size([])) 95 | else: 96 | val = torch.mean(val) 97 | 98 | storage[t] = val.item() # either scalar loss or vector of grads 99 | else: 100 | if abs: 101 | val = abs(val) # default abs 102 | 103 | storage[t] = val 104 | 105 | if not hasattr(config, "record_names"): 106 | config.record_names = [] 107 | 108 | if not val_name in config.record_names: 109 | config.record_names.append(val_name) 110 | 111 | 112 | def get_gpu_mem(nvsmi, gpu_ind): 113 | mem_stats = nvsmi.DeviceQuery('memory.free, memory.total')["gpu"][gpu_ind]["fb_memory_usage"] 114 | return mem_stats["total"] - mem_stats["free"] 115 | 116 | 117 | def get_task_in_dims(config): 118 | return config.task_in_dims 119 | 120 | 121 | def get_task_out_dims(config): 122 | return config.task_out_dims 123 | 124 | 125 | def render_graphs(config): 126 | if not hasattr(config, "record_names"): 127 | return 128 | 129 | training_val_names = config.record_names 130 | fig0, axarr0 = plt.subplots(max(len(training_val_names), 2), sharex=False, 131 | figsize=(8, len(training_val_names) * 4)) 132 | 133 | for i, val_name in enumerate(training_val_names): 134 | if hasattr(config, val_name): 135 | storage = getattr(config, val_name) 136 | axarr0[i].clear() 137 | 
axarr0[i].plot(list(storage.keys()), list(storage.values())) # ordereddict 138 | 139 | axarr0[i].set_title(val_name) 140 | 141 | fig0.suptitle("Model %d" % (config.model_ind), fontsize=8) 142 | fig0.savefig(osp.join(config.out_dir, "plots_0.png")) 143 | 144 | if hasattr(config, "val_accs"): 145 | fig1, axarr1 = plt.subplots(4, sharex=False, figsize=(8, 4 * 4)) 146 | 147 | for pi, prefix in enumerate(["val", "test"]): 148 | accs_name = "%s_accs" % prefix 149 | axarr1[pi * 2].clear() 150 | axarr1[pi * 2].plot(list(getattr(config, accs_name).keys()), 151 | list(getattr(config, accs_name).values())) # ordereddict 152 | axarr1[pi * 2].set_title(accs_name) 153 | 154 | per_label_accs_name = "%s_per_label_accs" % prefix 155 | axarr1[pi * 2 + 1].clear() 156 | per_label_accs_t = getattr(config, per_label_accs_name).keys() 157 | per_label_accs_np = np.array(list(getattr(config, per_label_accs_name).values())) 158 | for c in range(int(np.prod(get_task_out_dims(config)))): 159 | axarr1[pi * 2 + 1].plot(list(per_label_accs_t), list(per_label_accs_np[:, c]), label=str(c)) 160 | axarr1[pi * 2 + 1].legend() 161 | axarr1[pi * 2 + 1].set_title(per_label_accs_name) 162 | 163 | fig1.suptitle("Model %d" % (config.model_ind), fontsize=8) 164 | fig1.savefig(osp.join(config.out_dir, "plots_1.png")) 165 | 166 | # render predictions, if exist 167 | if hasattr(config, "aux_y_probs"): 168 | # time along x axis, classes along y axis 169 | fig2, ax2 = plt.subplots(1, figsize=(16, 8)) # width, height 170 | 171 | num_t = len(config.aux_y_probs) 172 | num_classes = int(np.prod(get_task_out_dims(config))) 173 | 174 | aux_y_probs = list(config.aux_y_probs.values()) 175 | aux_y_probs = [aux_y_prob.numpy() for aux_y_prob in aux_y_probs] 176 | aux_y_probs = np.array(aux_y_probs) 177 | 178 | # print(aux_y_probs.shape) 179 | assert (aux_y_probs.shape == (len(config.aux_y_probs), int(np.prod(get_task_out_dims(config))))) 180 | 181 | aux_y_probs = aux_y_probs.transpose() # now num classes, time 182 | 
min_val = aux_y_probs.min() 183 | max_val = aux_y_probs.max() 184 | 185 | # tile along y axis to make each class fatter. Should be same number of pixels altogether as 186 | # current t / 2 187 | scale = int(0.5 * float(num_t) / num_classes) 188 | if scale > 1: 189 | aux_y_probs = np.repeat(aux_y_probs, scale, axis=0) 190 | ax2.set_yticks(np.arange(num_classes) * scale) 191 | ax2.set_yticklabels(np.arange(num_classes)) 192 | 193 | num_thousands = int(num_t / 1000) 194 | ax2.set_xticks(np.arange(num_thousands) * 1000) 195 | ax2.set_xticklabels(np.arange(num_thousands) * 1000 + list(config.aux_y_probs.keys())[0]) 196 | 197 | im = ax2.imshow(aux_y_probs) 198 | fig2.colorbar(im, ax=ax2) 199 | # ax2.colorbar() 200 | 201 | fig2.suptitle("Model %d, max %f min %f" % (config.model_ind, max_val, min_val), fontsize=8) 202 | fig2.savefig(osp.join(config.out_dir, "plots_2.png")) 203 | 204 | plt.close("all") 205 | 206 | 207 | def trim_config(config, next_t): 208 | # trim everything down to next_t numbers 209 | # we are starting at top of loop *before* eval step 210 | 211 | for val_name in config.record_names: 212 | storage = getattr(config, val_name) 213 | if isinstance(storage, list): 214 | assert (len(storage) >= (next_t)) 215 | setattr(config, val_name, storage[:next_t]) 216 | else: 217 | assert (isinstance(storage, OrderedDict)) 218 | storage_copy = deepcopy(storage) 219 | for k, v in storage.items(): 220 | if k >= next_t: 221 | del storage_copy[k] 222 | setattr(config, val_name, storage_copy) 223 | 224 | for prefix in ["val", "test"]: 225 | accs_storage = getattr(config, "%s_accs" % prefix) 226 | per_label_accs_storage = getattr(config, "%s_per_label_accs" % prefix) 227 | 228 | if isinstance(accs_storage, list): 229 | assert (isinstance(per_label_accs_storage, list)) 230 | assert (len(accs_storage) >= (next_t) and len(per_label_accs_storage) >= ( 231 | next_t)) # at least next_t stored 232 | 233 | setattr(config, "%s_accs" % prefix, accs_storage[:next_t]) 234 | 
setattr(config, "%s_per_label_accs" % prefix, per_label_accs_storage[:next_t]) 235 | else: 236 | assert ( 237 | isinstance(accs_storage, OrderedDict) and isinstance(per_label_accs_storage, OrderedDict)) 238 | for dn, d in [("accs", accs_storage), ("per_label_accs", per_label_accs_storage)]: 239 | d_copy = deepcopy(d) 240 | for k, v in d.items(): 241 | if k >= next_t: 242 | del d_copy[k] 243 | setattr(config, dn, d_copy) 244 | 245 | # deal with window 246 | if config.long_window: 247 | # find index of first historical t for update >= next_t 248 | # set config.next_update_old_model_t = that t 249 | # trim history behind it, backing onto nest_t 250 | 251 | next_t_i = None 252 | for i, update_t in enumerate(config.next_update_old_model_t_history): 253 | if update_t > next_t: 254 | next_t_i = i 255 | break 256 | 257 | # there must be a t in update history that is greater than next_t 258 | # unless config.next_update_old_model_t >= next_t and we stopped before it was added to history 259 | # in which case we don't need to trim any update history 260 | 261 | if next_t_i is None: 262 | print("no trimming:") 263 | print(("config.next_update_old_model_t", config.next_update_old_model_t)) 264 | print(("next_t", next_t)) 265 | assert (config.next_update_old_model_t >= next_t) 266 | else: 267 | config.next_update_old_model_t = config.next_update_old_model_t_history[next_t_i] 268 | config.next_update_old_model_t_history = config.next_update_old_model_t_history[:next_t_i] 269 | 270 | 271 | def sum_seq(seq): 272 | res = None 273 | for elem in seq: 274 | if res is None: 275 | res = elem 276 | else: 277 | res += elem 278 | return res 279 | 280 | 281 | def np_rand_seed(): # fixed classes shuffling 282 | return 111 283 | 284 | 285 | def reproc_settings(config): 286 | np.random.seed(0) # set separately when shuffling data too 287 | if config.specific_torch_seed: 288 | torch.manual_seed(config.torch_seed) 289 | else: 290 | torch.manual_seed(config.model_ind) # allow initialisations 
different per model 291 | 292 | torch.backends.cudnn.deterministic = True 293 | torch.backends.cudnn.benchmark = False 294 | 295 | 296 | def copy_parameter_values(from_model, to_model): 297 | to_params = list(to_model.named_parameters()) 298 | assert (isinstance(to_params[0], tuple) and len(to_params[0]) == 2) 299 | 300 | to_params = dict(to_params) 301 | 302 | for n, p in from_model.named_parameters(): 303 | to_params[n].data.copy_(p.data) # not clone 304 | 305 | 306 | def make_valid_from_train(dataset, cut): 307 | tr_ds, val_ds = [], [] 308 | for task_ds in dataset: 309 | x_t, y_t = task_ds 310 | 311 | # shuffle before splitting 312 | perm = torch.randperm(len(x_t)) 313 | x_t, y_t = x_t[perm], y_t[perm] 314 | 315 | split = int(len(x_t) * cut) 316 | x_tr, y_tr = x_t[:split], y_t[:split] 317 | x_val, y_val = x_t[split:], y_t[split:] 318 | 319 | tr_ds += [(x_tr, y_tr)] 320 | val_ds += [(x_val, y_val)] 321 | 322 | return tr_ds, val_ds 323 | 324 | 325 | def invert_dict(dict_to_invert): 326 | new_dict = {} 327 | for k, vs in dict_to_invert.items(): 328 | for v in vs: 329 | new_dict[v] = k 330 | return new_dict -------------------------------------------------------------------------------- /code/util/load.py: -------------------------------------------------------------------------------- 1 | from .data import get_data 2 | from .general import get_device, invert_dict 3 | from code.models import * 4 | 5 | def get_model_and_data(config): 6 | config.task_in_dims = {"mnist5k": (28 * 28,), "miniimagenet": (3, 84, 84), "cifar10": (3, 32, 32)}[config.data] 7 | config.task_out_dims = {"mnist5k": (10,), "miniimagenet": (100,), "cifar10": (10,)}[config.data] 8 | 9 | tasks_model = globals()[config.task_model_type](config).to(get_device(config.cuda)) 10 | 11 | trainloader, testloader, valloader = get_data(config) 12 | 13 | config.class_dict_tasks = invert_dict(trainloader.dataset.task_dict_classes) 14 | 15 | return tasks_model, trainloader, testloader, valloader 16 | 
-------------------------------------------------------------------------------- /code/util/losses.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def TV(x): 8 | if len(x.shape) == 2: # mnist 9 | side = int(x.shape[1] ** 0.5) 10 | assert (side == 28) # sanity 11 | x = x.view(-1, 1, side, side) 12 | 13 | batch_sz = x.shape[0] 14 | h_x = x.shape[2] 15 | w_x = x.shape[3] 16 | count_h = batch_sample_size( 17 | x[:, :, 1:, :]) # num points in 1 image inc channels except 1 row missing 18 | count_w = batch_sample_size(x[:, :, :, 1:]) 19 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 20 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 21 | return 2 * (h_tv / count_h + w_tv / count_w) / batch_sz 22 | 23 | 24 | def batch_sample_size(x): 25 | return x.shape[1] * x.shape[2] * x.shape[3] 26 | 27 | 28 | def L2(x): 29 | if len(x.shape) == 2: # mnist 30 | side = int(x.shape[1] ** 0.5) 31 | assert (side == 28) # sanity 32 | x = x.view(-1, 1, side, side) 33 | 34 | sz = batch_sample_size(x) 35 | x = x.view(x.shape[0], -1) 36 | res = torch.norm(x, p=2, dim=1).sum() # this is not size normalised! 
37 | return res / (x.shape[0] * sz) 38 | 39 | 40 | def sharpen_class(aux_preds): 41 | # minimise cross entropy with its own argmax 42 | hard_targets = aux_preds.argmax(dim=1).detach() 43 | return F.cross_entropy(aux_preds, hard_targets, reduction="mean") 44 | 45 | 46 | def img_distance(aux_x, xs): 47 | aux_x_exp = aux_x.repeat(xs.shape[0], 1, 1, 1) 48 | xs_exp = xs.repeat_interleave(aux_x.shape[0], dim=0) 49 | assert (aux_x_exp.shape[1:] == xs_exp.shape[1:]) 50 | assert (aux_x.shape[1:] == aux_x_exp.shape[1:]) 51 | 52 | assert (aux_x_exp.shape[0] == aux_x.shape[0] * xs.shape[0]) 53 | assert (xs_exp.shape[0] == aux_x.shape[0] * xs.shape[0]) 54 | 55 | # maximise the distance 56 | return -F.mse_loss(aux_x_exp, xs_exp, reduction="mean") 57 | 58 | 59 | def diversity_raw(preds): 60 | # maximise the entropy, use avg of logs 61 | assert (len(preds.shape) == 2) 62 | cross_batch_preds = preds.mean(dim=0, keepdim=True) 63 | entrop = crossent_logits(cross_batch_preds, cross_batch_preds) 64 | return -entrop 65 | 66 | 67 | def diversity(preds, EPS=sys.float_info.epsilon): 68 | # maximise the entropy, use avg after softmax!! 
69 | assert (len(preds.shape) == 2) 70 | 71 | preds = F.softmax(preds, dim=1) 72 | cross_batch_preds = preds.mean(dim=0) 73 | cross_batch_preds[(cross_batch_preds < EPS).detach()] = EPS 74 | loss = (cross_batch_preds * torch.log(cross_batch_preds)).sum() # minimise negative entropy 75 | return loss 76 | 77 | 78 | def notlocal(preds, present_classes): 79 | # by minimizing this, maximise distance with present classes 80 | 81 | num_samples, _ = preds.shape 82 | assert (len(present_classes.shape) == 1) 83 | 84 | cum_losses = preds.new_zeros(num_samples) 85 | for c in present_classes: 86 | c_targets = preds.new_full((num_samples,), c, dtype=torch.long) 87 | y_loss = - F.cross_entropy(preds, c_targets, reduction="none") 88 | cum_losses += y_loss 89 | 90 | loss = cum_losses.mean() / present_classes.shape[0] 91 | assert (loss.requires_grad) 92 | 93 | return loss 94 | 95 | 96 | def crossent_logits(preds, targets): 97 | # targets are not yet softmaxed probabilities 98 | assert (len(targets.shape) == 2) 99 | targets = F.softmax(targets, dim=1) 100 | 101 | num_samples, num_outputs = preds.shape 102 | cum_losses = preds.new_zeros(num_samples) 103 | 104 | for c in range(num_outputs): 105 | # grads should still bp back through pred and target if nec 106 | target_temp = preds.new_full((num_samples,), c, dtype=torch.long) # filled with c 107 | y_loss = F.cross_entropy(preds, target_temp, reduction="none") 108 | cum_losses += targets[:, c] * y_loss # weight each one by prob c given by actual target 109 | 110 | loss = cum_losses.mean() 111 | assert (loss.requires_grad) 112 | return loss 113 | 114 | 115 | def deep_inversion_classes_loss(classes_to_refine, aux_preds_old): 116 | # try to get characteristic images for these classes 117 | assert (classes_to_refine.shape == (aux_preds_old.shape[0],)) 118 | # aux preds old not softmaxed 119 | return F.cross_entropy(aux_preds_old, classes_to_refine, reduction="mean") 120 | 121 | 122 | def neg_symmetric_KL(aux_preds_old, aux_preds_new): 123 | 
assert (len(aux_preds_old.shape) == 2) 124 | assert (len(aux_preds_new.shape) == 2) 125 | assert (aux_preds_new.shape == aux_preds_old.shape) 126 | 127 | # inputs logsoftmax, targets softmax 128 | avg_preds = 0.5 * (F.softmax(aux_preds_old, dim=1) + F.softmax(aux_preds_new, dim=1)) 129 | 130 | return 1 - 0.5 * ( 131 | F.kl_div(F.log_softmax(aux_preds_old, dim=1), avg_preds, reduction="batchmean") + 132 | F.kl_div(F.log_softmax(aux_preds_new, dim=1), avg_preds, reduction="batchmean")) 133 | 134 | 135 | def avoid_unseen_classes(aux_preds_old, seen_classes): 136 | # max KL div with each of seen classes 137 | num_samples, num_outputs = aux_preds_old.shape 138 | 139 | total = None 140 | for c in seen_classes: 141 | target_temp = aux_preds_old.new_full((num_samples,), c, dtype=torch.long) # filled with c 142 | y_loss = F.cross_entropy(aux_preds_old, target_temp, reduction="mean") 143 | 144 | if total is None: 145 | total = y_loss 146 | else: 147 | total += y_loss 148 | 149 | return - total / seen_classes.shape[0] # avg and negate, because encouraging separation 150 | -------------------------------------------------------------------------------- /code/util/render.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib 4 | import numpy as np 5 | import torch 6 | 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | from colorsys import hsv_to_rgb 10 | 11 | 12 | def render_aux_x(config, t, r, aux_x_orig, aux_x, aux_y, present_classes): 13 | if len(aux_x_orig.shape) == 2: 14 | assert (config.data == "mnist5k") 15 | side = int(aux_x_orig.shape[1] ** 0.5) 16 | aux_x_orig = aux_x_orig.view(-1, side, side) 17 | if not aux_x is None: 18 | aux_x = aux_x.view(-1, side, side) 19 | 20 | render_out_dir = os.path.join(config.out_dir, "render") 21 | if not os.path.exists(render_out_dir): 22 | os.makedirs(render_out_dir) 23 | 24 | if not aux_x is None: 25 | # render a random selection of render_aux_x_num 
samples 26 | 27 | if not config.render_separate: 28 | fig0, axarr0 = plt.subplots(2, max(config.render_aux_x_num, 2), sharex=False, 29 | figsize=(config.render_aux_x_num * 4, 2 * 4)) 30 | 31 | num_aux_x = aux_x.shape[0] 32 | assert (num_aux_x == config.M) 33 | assert (config.render_aux_x_num <= num_aux_x) 34 | selected = np.random.choice(num_aux_x, config.render_aux_x_num, replace=False) 35 | 36 | s_aux_x = aux_x[selected] 37 | s_aux_y = aux_y[selected] 38 | s_aux_x_orig = aux_x_orig[selected] 39 | if config.render_aux_x_num == 1: 40 | s_aux_x = s_aux_x.unsqueeze(dim=0) 41 | s_aux_x_orig = s_aux_x_orig.unsqueeze(dim=0) 42 | s_aux_y = s_aux_y.unsqueeze(dim=0) 43 | 44 | s_aux_y = torch.nn.functional.softmax(s_aux_y, dim=1) # pre softmax 45 | 46 | diff_sum = (s_aux_x_orig - s_aux_x).abs().sum().item() / config.render_aux_x_num 47 | 48 | for i in range(config.render_aux_x_num): 49 | img_orig = s_aux_x_orig[i] 50 | img = s_aux_x[i] 51 | 52 | if len(s_aux_x_orig.shape) == 4: 53 | img_orig = img_orig.permute(1, 2, 0) 54 | img = img.permute(1, 2, 0) # h, w, 3 55 | 56 | orig_min, orig_max = img.min(), img.max() 57 | img = img - orig_min 58 | img = img / orig_max 59 | 60 | top_class = s_aux_y[i].argmax() 61 | 62 | if not config.render_separate: 63 | axarr0[0, i].imshow(img.cpu().numpy()) 64 | 65 | vals = ( 66 | top_class.item(), s_aux_y[i, top_class].item(), str(list(present_classes.cpu().numpy())), 67 | orig_min.item(), orig_max.item()) 68 | 69 | axarr0[0, i].set_title("top c %d: val %f, pres %s, img min %f max %f" % vals, fontsize=6) 70 | axarr0[1, i].imshow(img_orig.cpu().numpy()) 71 | else: 72 | fig_curr_orig, ax_curr_orig = plt.subplots(1, figsize=(4, 4)) 73 | ax_curr_orig.imshow(img_orig.cpu().numpy()) 74 | # ax_curr_orig.set_axis_off() # barebones 75 | 76 | ax_curr_orig.axis('off') 77 | fig_curr_orig.patch.set_visible(False) 78 | ax_curr_orig.patch.set_visible(False) 79 | 80 | fig_curr_orig.savefig( 81 | os.path.join(render_out_dir, "m_%d_t_%d_%r_i_%d_orig.png" % 
(config.model_ind, t, r, i)), 82 | bbox_inches=0) 83 | 84 | fig_curr_aux, ax_curr_aux = plt.subplots(1, figsize=(4, 4)) 85 | ax_curr_aux.imshow(img.cpu().numpy()) 86 | # ax_curr_aux.set_axis_off() # barebones 87 | 88 | ax_curr_aux.axis('off') 89 | fig_curr_aux.patch.set_visible(False) 90 | ax_curr_aux.patch.set_visible(False) 91 | 92 | fig_curr_aux.savefig(os.path.join(render_out_dir, "m_%d_t_%d_%r_i_%d_aux_%d_%.3f.png" % 93 | (config.model_ind, t, r, i, top_class.item(), 94 | s_aux_y[i, top_class].item())), bbox_inches=0) 95 | 96 | plt.close("all") 97 | 98 | if not config.render_separate: 99 | fig0.suptitle("Model %d t %d r %d, diffs %f" % (config.model_ind, t, r, diff_sum), fontsize=8) 100 | fig0.savefig( 101 | os.path.join(render_out_dir, "render_m_%d_t_%d_r_%d.png" % (config.model_ind, t, r))) 102 | plt.close("all") 103 | else: 104 | # just render original batch 105 | for i in range(config.render_aux_x_num): 106 | img_orig = aux_x_orig[i] 107 | if len(img_orig.shape) == 3: 108 | img_orig = img_orig.permute(1, 2, 0) # channels last 109 | 110 | fig_curr_orig, ax_curr_orig = plt.subplots(1, figsize=(4, 4)) 111 | ax_curr_orig.imshow(img_orig.cpu().numpy()) 112 | # ax_curr_orig.set_axis_off() # barebones 113 | 114 | ax_curr_orig.axis('off') 115 | fig_curr_orig.patch.set_visible(False) 116 | ax_curr_orig.patch.set_visible(False) 117 | 118 | fig_curr_orig.savefig( 119 | os.path.join(render_out_dir, "m_%d_t_%d_%r_i_%d_orig.png" % (config.model_ind, t, r, i)), 120 | bbox_inches=0) 121 | 122 | 123 | def get_colours(n): 124 | hues = np.linspace(0.0, 1.0, n + 1)[0:-1] # ignore last one 125 | all_colours = [np.array(hsv_to_rgb(hue, 0.75, 0.75)) for hue in hues] 126 | return all_colours 127 | -------------------------------------------------------------------------------- /commands.txt: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | To print results summary after running commands: 3 | # 
------------------------------------------ 4 | Experiments are named by arbitrary integers for easy reference. 5 | 6 | E.g. ARM CIFAR10 7 | python -m code.scripts.print_results --root /scratch/shared/nfs1/xuji/ARM --start 3717 8 | 9 | average val: acc 0.2586 +- 0.0145, forgetting 0.1046 +- 0.0330 10 | average test: acc 0.2687 +- 0.0107, forgetting 0.0959 +- 0.0371 11 | 12 | Print all results: 13 | python -m code.scripts.print_table --root /scratch/shared/nfs1/xuji/ARM 14 | 15 | 16 | # ------------------------------------------ 17 | Tables 1 - 3 18 | # ------------------------------------------ 19 | 20 | ARM 21 | MNIST 22 | nohup python -m code.scripts.ARM --model_ind_start 6579 --num_runs 5 --data mnist5k --lr 0.05 --task_model_type mlp --classes_per_task 2 --recall_from_t 100 --num_iterations 1 --M 10 --refine_sample_steps 10 --refine_sample_lr 25.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 100 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 500 --store_model_freq 500 --store_results_freq 100 --eval_freq 100 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/ > out/m6579s.out & 23 | 24 | CIFAR10 25 | nohup python -m code.scripts.ARM --model_ind_start 3717 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM 
--data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3717s.out & 26 | 27 | MiniImageNet 28 | nohup python -m code.scripts.ARM --model_ind_start 4821 --num_runs 5 --data miniimagenet --lr 0.01 --task_model_type resnet18 --classes_per_task 5 --recall_from_t 684 --num_iterations 3 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 684 --sharpen_class --sharpen_class_weight 1.0 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 1.0 --aux_distill --aux_distill_weight 2.0 --max_t 13680 --store_model_freq 13680 --store_results_freq 684 --eval_freq 684 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/MINIIMAGENET/ > out/m4821.out & 29 | 30 | ADI 31 | MNIST 32 | nohup python -m code.scripts.ADI --model_ind_start 6232 --num_runs 5 --data mnist5k --lr 0.05 --task_model_type mlp --classes_per_task 2 --recall_from_t 100 --num_iterations 1 --M 10 --refine_theta_steps 1 --refine_sample_steps 10 --refine_sample_lr 25.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 100 --classes_loss_weight 1.0 --choose_past_classes --adaptive --adaptive_weight 1.0 --aux_distill_weight 0.5 --max_t 500 --store_model_freq 500 --store_results_freq 100 --eval_freq 100 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/ > out/m6232s.out & 33 | 34 | CIFAR10 35 | nohup python -m code.scripts.ADI --model_ind_start 2262 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18_batch_stats --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_theta_steps 1 --refine_sample_steps 10 --refine_sample_lr 10.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --classes_loss_weight 1.0 --opt_batch_stats --opt_batch_stats_weight 0.1 --choose_past_classes --adaptive 
--adaptive_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m2262s.out & 36 | 37 | MiniImageNet 38 | nohup python -m code.scripts.ADI --model_ind_start 6042 --num_runs 5 --data miniimagenet --lr 0.01 --task_model_type resnet18_batch_stats --classes_per_task 5 --recall_from_t 684 --num_iterations 3 --M 100 --refine_theta_steps 1 --refine_sample_steps 10 --refine_sample_lr 10.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 684 --classes_loss_weight 1.0 --opt_batch_stats --opt_batch_stats_weight 0.1 --choose_past_classes --adaptive --adaptive_weight 1.0 --aux_distill_weight 2.0 --max_t 13680 --store_model_freq 13680 --store_results_freq 684 --eval_freq 684 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/MINIIMAGENET/ > out/m6042s.out & 39 | 40 | Distill 41 | MNIST 42 | nohup python -m code.scripts.distill --model_ind_start 6102 --num_runs 5 --data mnist5k --lr 0.05 --task_model_type mlp --classes_per_task 2 --recall_from_t 100 --num_iterations 1 --long_window --use_fixed_window --fixed_window 100 --aux_distill_weight 1.0 --max_t 500 --store_model_freq 500 --store_results_freq 100 --eval_freq 100 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/ > out/m6102s.out & 43 | 44 | CIFAR10 45 | nohup python -m code.scripts.distill --model_ind_start 4082 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --long_window --use_fixed_window --fixed_window 950 --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m4082s.out & 46 | 47 | MiniImageNet 48 | nohup python -m 
code.scripts.distill --model_ind_start 6092 --num_runs 5 --data miniimagenet --lr 0.01 --task_model_type resnet18 --classes_per_task 5 --recall_from_t 684 --num_iterations 3 --long_window --use_fixed_window --fixed_window 684 --aux_distill_weight 2.0 --max_t 13680 --store_model_freq 13680 --store_results_freq 684 --eval_freq 684 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/MINIIMAGENET/ > out/m6092s.out & 49 | 50 | Naive 51 | MNIST 52 | nohup python -m code.scripts.naive --model_ind_start 4967 --num_runs 5 --data mnist5k --lr 0.05 --task_model_type mlp --classes_per_task 2 --num_iterations 1 --max_t 500 --store_model_freq 500 --store_results_freq 100 --eval_freq 100 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/ > out/m4967s.out & 53 | 54 | CIFAR10 55 | nohup python -m code.scripts.naive --model_ind_start 2522 --num_runs 5 --data cifar10 --lr 0.1 --task_model_type resnet18 --classes_per_task 2 --num_iterations 1 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m2522s.out & 56 | 57 | MiniImageNet 58 | nohup python -m code.scripts.naive --model_ind_start 4557 --num_runs 5 --data miniimagenet --classes_per_task 5 --lr 0.1 --task_model_type resnet18 --num_iterations 3 --max_t 13680 --store_model_freq 13680 --store_results_freq 684 --eval_freq 684 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/MINIIMAGENET/ > out/m4557s.out & 59 | 60 | Naive stationary 61 | MNIST 62 | nohup python -m code.scripts.naive --model_ind_start 4947 --num_runs 5 --data mnist5k --lr 0.05 --task_model_type mlp --classes_per_task 2 --stationary --num_iterations 1 --max_t 500 --store_model_freq 500 --store_results_freq 100 --eval_freq 100 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path 
/scratch/shared/nfs1/xuji/datasets/ > out/m4947s.out & 63 | 64 | CIFAR10 65 | nohup python -m code.scripts.naive --model_ind_start 6462 --num_runs 5 --data cifar10 --lr 0.1 --task_model_type resnet18 --classes_per_task 2 --stationary --num_iterations 1 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6462s.out & 66 | 67 | MiniImageNet 68 | nohup python -m code.scripts.naive --model_ind_start 4527 --num_runs 5 --data miniimagenet --classes_per_task 5 --stationary --lr 0.1 --task_model_type resnet18 --num_iterations 3 --max_t 13680 --store_model_freq 13680 --store_results_freq 684 --eval_freq 684 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/MINIIMAGENET/ > out/m4527s.out & 69 | 70 | # ------------------------------------------ 71 | Table 4 72 | # ------------------------------------------ 73 | 74 | Distill unit lag 75 | nohup python -m code.scripts.distill --model_ind_start 6452 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6452s.out & 76 | 77 | ADI unit lag 78 | nohup python -m code.scripts.ADI --model_ind_start 6467 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18_batch_stats --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_theta_steps 1 --refine_sample_steps 10 --refine_sample_lr 10.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --classes_loss_weight 1.0 --opt_batch_stats --opt_batch_stats_weight 0.1 --choose_past_classes --adaptive --adaptive_weight 8.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root 
/scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6467s.out & 79 | 80 | ADI no distill 81 | nohup python -m code.scripts.ADI --model_ind_start 2327 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18_batch_stats --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_theta_steps 1 --refine_sample_steps 10 --refine_sample_lr 10.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --classes_loss_weight 1.0 --opt_batch_stats --opt_batch_stats_weight 0.1 --choose_past_classes --adaptive --adaptive_weight 1.0 --no_aux_distill --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m2327s.out & 82 | 83 | ARM unit lag 84 | nohup python -m code.scripts.ARM --model_ind_start 6602 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 8.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6602s.out & 85 | 86 | ARM no distill 87 | nohup python -m code.scripts.ARM --model_ind_start 2717 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 1.0 --notlocal_weight 1.0 
--notlocal_new_weight 0.1 --diversity_weight 16.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m2717s.out & 88 | 89 | # ------------------------------------------ 90 | Table 7 91 | # ------------------------------------------ 92 | 93 | \lambda_1 = 0, \lambda_2 = 0 94 | nohup python -m code.scripts.ARM --model_ind_start 3982 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 0.0 --notlocal_new_weight 0.0 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3982s.out & 95 | 96 | \lambda_3 = 0 97 | nohup python -m code.scripts.ARM --model_ind_start 6562 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 0.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6562s.out & 98 | 99 | \lambda_4 = 0 100 | nohup python -m code.scripts.ARM --model_ind_start 3977 --num_runs 5 --data cifar10 --lr 0.01 
--task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3977s.out & 101 | 102 | \lambda_5 = 0 103 | nohup python -m code.scripts.ARM --model_ind_start 3967 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3967s.out & 104 | 105 | \lambda_6 = 0 106 | nohup python -m code.scripts.ARM --model_ind_start 3972 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path 
/scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3972s.out & 107 | 108 | M = 150 (+50) 109 | nohup python -m code.scripts.ARM --model_ind_start 6624 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 150 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6624s.out & 110 | 111 | M = 50 (-50) 112 | nohup python -m code.scripts.ARM --model_ind_start 5922 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 50 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m5922s.out & 113 | 114 | S = 20 (doubled) 115 | nohup python -m code.scripts.ARM --model_ind_start 3957 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 20 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 
--notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3957s.out & 116 | 117 | S = 5 (halved) 118 | nohup python -m code.scripts.ARM --model_ind_start 3962 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 5 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3962s.out & 119 | 120 | Cross-entropy as D 121 | nohup python -m code.scripts.ARM --model_ind_start 3617 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --use_crossent_as_D --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m3617s.out & 122 | 123 | Random noise init \hat{\mathcal{B}}_X 124 | nohup python -m code.scripts.ARM --model_ind_start 4077 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type 
resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_x_random --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m4077s.out & 125 | 126 | Recall 2x per t 127 | nohup python -m code.scripts.ARM --model_ind_start 6502 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_theta_steps 2 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 --store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6502s.out & 128 | 129 | Recall 4x per t 130 | nohup python -m code.scripts.ARM --model_ind_start 6507 --num_runs 5 --data cifar10 --lr 0.01 --task_model_type resnet18 --classes_per_task 2 --recall_from_t 950 --num_iterations 1 --M 100 --refine_theta_steps 4 --refine_sample_steps 10 --refine_sample_lr 10.0 --divergence_loss_weight 1.0 --L2 --L2_weight 1.0 --TV --TV_weight 1.0 --long_window --use_fixed_window --fixed_window 950 --sharpen_class --sharpen_class_weight 0.1 --notlocal_weight 1.0 --notlocal_new_weight 0.1 --diversity_weight 16.0 --aux_distill --aux_distill_weight 1.0 --max_t 4750 --store_model_freq 4750 
--store_results_freq 950 --eval_freq 950 --cuda --out_root /scratch/shared/nfs1/xuji/ARM --data_path /scratch/shared/nfs1/xuji/datasets/CIFAR > out/m6507s.out & 131 | -------------------------------------------------------------------------------- /summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xu-ji/ARM/704f69d73765f13ca5e2d8aee11f399b53c635e6/summary.png --------------------------------------------------------------------------------