├── models ├── __init__.py ├── mlp.py └── lenet.py ├── dataloaders ├── __init__.py ├── datasetGen.py ├── wrapper.py └── base.py ├── img ├── cover_image.png └── hyperparameters_table.jpg ├── utils ├── __pycache__ │ └── utils.cpython-37.pyc └── utils.py ├── algos ├── common.py ├── ewc.py ├── gem.py └── ogd.py ├── requirements.txt ├── README.md ├── LICENSE ├── .gitignore ├── scripts ├── script.py └── commands │ ├── command_split_mnist.sh │ ├── command_cifar.sh │ └── command_rotated_mnist.sh ├── main.py └── trainer.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /img/cover_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tldoan/PCA-OGD/HEAD/img/cover_image.png -------------------------------------------------------------------------------- /img/hyperparameters_table.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tldoan/PCA-OGD/HEAD/img/hyperparameters_table.jpg -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tldoan/PCA-OGD/HEAD/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /algos/common.py: -------------------------------------------------------------------------------- 1 | 2 | from dataloaders.wrapper import Storage 3 | 4 | class Memory(Storage): 5 | def reduce(self, m): 6 | self.storage = self.storage[:m] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.12.5 2 | chardet==4.0.0 3 | cycler==0.10.0 4 | Cython==0.29.21 5 | filelock==3.0.12 6 | gdown==3.12.2 7 | gnureadline==8.0.0 8 | idna==2.10 9 | kiwisolver==1.3.1 10 | matplotlib==3.3.2 11 | numpy==1.19.2 12 | pandas==1.1.3 13 | Pillow==7.2.0 14 | pyparsing==2.4.7 15 | PySocks==1.7.1 16 | python-dateutil==2.8.1 17 | pytz==2020.5 18 | quadprog==0.1.8 19 | requests==2.25.1 20 | six==1.15.0 21 | torch==1.7.1 22 | torchvision==0.8.2 23 | tqdm==4.56.0 24 | typing-extensions==3.7.4.3 25 | urllib3==1.26.2 26 | virtualenv==16.6.2 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PCA-OGD 2 | Official code for [A Theoretical Analysis of Catastrophic Forgetting through the NTK 3 | Overlap Matrix](https://arxiv.org/abs/2010.04003) (AISTATS 2021) 4 | 5 | 6 | 7 |
8 | 9 |
10 | 
 11 | 
 12 | 
 13 | 
 14 | 
 15 | 
 16 | 
 17 | ## Prerequisites 
 18 | - Python packages listed in `requirements.txt`
 19 | 
 20 | ## Instructions
 21 | 
 22 | - Clone the repo
 23 | - Run the scripts in `scripts/commands`
 24 | - Hyperparameters are given in Table 4 of the Appendix of the [paper](https://arxiv.org/pdf/2010.04003.pdf)
 25 | 
 26 | ## Also includes:
 27 | - Implementation of [Orthogonal Gradient Descent (OGD)](https://arxiv.org/abs/1910.07104) (Farajtabar et al., 2019)
 28 | 
 29 | 
 30 | ## To cite:
 31 | 
 32 | ```
 33 | @article{doan2020theoretical,
 34 |   title={A Theoretical Analysis of Catastrophic Forgetting through the NTK Overlap Matrix},
 35 |   author={Doan, Thang and Bennani, Mehdi and Mazoure, Bogdan and Rabusseau, Guillaume and Alquier, Pierre},
 36 |   journal={arXiv preprint arXiv:2010.04003},
 37 |   year={2020}
 38 | }
 39 | ```
 40 | 
-------------------------------------------------------------------------------- 
/LICENSE: 
-------------------------------------------------------------------------------- 
 1 | MIT License
 2 | 
 3 | Copyright (c) 2021 Thang Doan
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
 10 | furnished to do so, subject to the following conditions:
 11 | 
 12 | The above copyright notice and this permission notice shall be included in all
 13 | copies or substantial portions of the Software.
 14 | 
 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class MLP(nn.Module): 5 | def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256): 6 | super(MLP, self).__init__() 7 | self.in_dim = in_channel*img_sz*img_sz 8 | self.linear = nn.Sequential( 9 | nn.Linear(self.in_dim, hidden_dim), 10 | #nn.BatchNorm1d(hidden_dim), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(hidden_dim, hidden_dim), 13 | #nn.BatchNorm1d(hidden_dim), 14 | nn.ReLU(inplace=True), 15 | ) 16 | self.last = nn.Linear(hidden_dim, out_dim) # Subject to be replaced dependent on task 17 | 18 | def features(self, x): 19 | x = self.linear(x.view(-1,self.in_dim)) 20 | return x 21 | 22 | def logits(self, x): 23 | x = self.last(x) 24 | return x 25 | 26 | def forward(self, x): 27 | x = self.features(x) 28 | x = self.logits(x) 29 | return x 30 | 31 | 32 | def MLP50(): 33 | print("\n Using MLP100 \n") 34 | return MLP(hidden_dim=50) 35 | 36 | 37 | def MLP100(): 38 | print("\n Using MLP100 \n") 39 | return MLP(hidden_dim=100) 40 | 41 | 42 | def MLP400(): 43 | return MLP(hidden_dim=400) 44 | 45 | 46 | def MLP1000(): 47 | print("\n Using MLP1000 \n") 48 | return MLP(hidden_dim=1000) 49 | 50 | 51 | def MLP2000(): 52 | return MLP(hidden_dim=2000) 53 | 54 | 55 | def MLP5000(): 56 | return MLP(hidden_dim=5000) -------------------------------------------------------------------------------- /models/lenet.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class LeNet(nn.Module): 7 | 8 | def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=500): 9 | super(LeNet, self).__init__() 10 | feat_map_sz = img_sz//4 11 | self.n_feat = 50 * feat_map_sz * feat_map_sz 12 | self.hidden_dim = hidden_dim 13 | 14 | self.conv = nn.Sequential( 15 | nn.Conv2d(in_channel, 20, 5, padding=2), 16 | nn.BatchNorm2d(20), 17 | nn.ReLU(inplace=True), 18 | nn.MaxPool2d(2, 2), 19 | nn.Conv2d(20, 50, 5, padding=2), 20 | nn.BatchNorm2d(50), 21 | nn.ReLU(inplace=True), 22 | nn.MaxPool2d(2, 2), 23 | nn.Flatten(), 24 | 25 | ) 26 | self.linear = nn.Sequential( 27 | nn.Linear(self.n_feat, hidden_dim), 28 | nn.BatchNorm1d(hidden_dim), 29 | nn.ReLU(inplace=True), 30 | ) 31 | 32 | 33 | 34 | self.last = nn.Linear(hidden_dim, out_dim) # Subject to be replaced dependent on task 35 | 36 | def features(self, x): 37 | x = self.conv(x) 38 | 39 | 40 | x = self.linear(x.view(-1, self.n_feat)) 41 | 42 | # x=self.linear(x) 43 | return x 44 | 45 | def logits(self, x): 46 | x = self.last(x) 47 | return x 48 | 49 | def forward(self, x): 50 | x = self.features(x) 51 | x = self.logits(x) 52 | return x 53 | 54 | 55 | def LeNetC(out_dim=10, hidden_dim=500): # LeNet with color input 56 | return LeNet(out_dim=out_dim, in_channel=3, img_sz=32, hidden_dim=hidden_dim) -------------------------------------------------------------------------------- /algos/ewc.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch.nn.functional as F 3 | ## toolbox for EWC 4 | 5 | def consolidate(trainer,precision_matrices): 6 | for index in trainer._precision_matrices: 7 | trainer._precision_matrices[index]+=precision_matrices[index] 8 | 9 | def update_means(trainer): 10 | 11 | for n, p in deepcopy(trainer.params).items(): 12 | trainer._means[n] = p.data 13 | 14 | def 
penalty(trainer): 15 | loss = 0 16 | 17 | for index in trainer.params: 18 | _loss = trainer._precision_matrices[index] *0.5* (trainer.params[index] - trainer._means[index]) ** 2 19 | loss += _loss.sum() 20 | return loss 21 | 22 | def _diag_fisher(trainer,train_loader): 23 | precision_matrices = {} 24 | 25 | for n, p in deepcopy(trainer.params).items(): 26 | p.data.zero_() 27 | precision_matrices[n] = p.data 28 | 29 | trainer.model.eval() 30 | k=1 31 | for _,element in enumerate(train_loader): 32 | if k>=trainer.config.fisher_sample: 33 | break 34 | trainer.model.zero_grad() 35 | inputs=element[0].to(trainer.config.device) 36 | targets = element[1].long().to(trainer.config.device) 37 | 38 | task=element[2] 39 | out = trainer.forward(inputs,task) 40 | assert out.shape[0] == 1 41 | 42 | pred = out.cpu() 43 | 44 | 45 | loss=F.log_softmax(pred, dim=1)[0][targets.item()] 46 | loss.backward() 47 | 48 | for index in trainer.params: 49 | trainer._precision_matrices[index].data += trainer.params[index].grad.data ** 2 / trainer.config.fisher_sample 50 | 51 | precision_matrices = {n: p for n, p in precision_matrices.items()} 52 | return precision_matrices -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /scripts/script.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | 4 | 5 | prefix_list = ["python main.py "] 6 | 7 | def generate_command(params_dict,prefix_list): 8 | 9 | final_command=[] 10 | for key,value in params_dict.items(): 11 | if not isinstance(value,list): 12 | params_dict[key]=[value] 13 | 14 | 15 | keys, values = zip(*params_dict.items()) 16 | 17 | params_combination=[(keys,v) for v in itertools.product(*values)] 18 | 19 | for i in params_combination: 20 | 21 | command=prefix_list[0]+" " 22 | for element in range(len(i[0])): 23 | command+=" --"+str(i[0][element])+" "+str(i[1][element]) 24 | 25 | final_command.append(command) 26 | 27 | return final_command 28 | 29 | 30 | def replicate_results_split_cifar(): 31 | params_dict={ 32 | "nepoch":10, 33 | "memory_size":100, 34 | "batch_size":32, 35 | "lr":1e-3, 36 | "subset_size":1000, 37 | "memory_size":100, 38 | "n_tasks":20, 39 | "eval_freq":1000, 40 | "agem_mem_batch_size":256, 41 | "pca_sample":3000, 42 | "dataset":"split_cifar", 43 | "fisher_sample":1024, 44 | "hidden_dim": 200, 45 | "ewc_reg":25, 46 | "seed": [0,1,2,3,4], 47 | "method":["pca","ogd","ewc","sgd","agem"] 48 | } 49 | return params_dict 50 | 51 | 52 | ### rotated and permuted mnist 53 | def replicate_results_rotated_mnist(): 54 | params_dict={ 55 | "nepoch":10, 56 | "memory_size":100, 57 | "batch_size":32, 58 | "lr":1e-3, 59 | "subset_size":1000, 60 | "memory_size":100, 61 | "n_tasks":15, 62 | "eval_freq":1000, 63 | "agem_mem_batch_size":256, 64 | "pca_sample":3000, 65 | "dataset":"split_cifar", 66 | "fisher_sample":1024, 67 | "hidden_dim": 100, 68 | "ewc_reg":10, 69 | "seed": [0,1,2,3,4], 70 | "method":["pca","ogd","ewc","sgd","agem"] 71 | } 72 | return params_dict 73 | 74 | 75 | def replicate_results_split_mnist(): 76 | params_dict={ 77 | "nepoch":5, 78 | "memory_size":100, 79 | "batch_size":32, 80 | "lr":1e-3, 81 | "subset_size":2000, 82 | "memory_size":100, 83 | "n_tasks":5, 84 | "eval_freq":1000, 85 | "agem_mem_batch_size":256, 86 | "pca_sample":3000, 87 | "dataset":"split_cifar", 88 | "fisher_sample":1024, 89 | "hidden_dim": 100, 90 | "ewc_reg":10, 91 | "seed": [0,1,2,3,4], 92 | "method":["pca","ogd","ewc","sgd","agem"] 93 | } 94 | return params_dict 95 | 96 | 97 | if __name__ == '__main__': 98 | 99 | 100 | command_split_cifar=generate_command(replicate_results_split_cifar(),prefix_list) 101 | command_split_mnist=generate_command(replicate_results_split_mnist(),prefix_list) 102 | command_rotated_mnist=generate_command(replicate_results_rotated_mnist(),prefix_list) 103 | 104 | carrier=[command_split_cifar,command_split_mnist,command_rotated_mnist] 105 | 
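# Each benchmark's parameter grid above is expanded into one `python main.py ...`
# command per (seed, method) combination, and the commands are written below to
# commands/<name>.sh, one command per line.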
name=["command_cifar","command_split_mnist","command_rotated_mnist"] 106 | if not os.path.exists("commands"): 107 | os.makedirs("commands", exist_ok=True) 108 | for it in range(len(carrier)): 109 | 110 | with open("commands/{}.sh".format(name[it]), "w") as outfile: 111 | outfile.write("\n".join(carrier[it])+"\n") -------------------------------------------------------------------------------- /dataloaders/datasetGen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from random import shuffle 3 | from .wrapper import Subclass, AppendName, Permutation 4 | 5 | 6 | def SplitGen(train_dataset, val_dataset, first_split_sz=2, other_split_sz=2, rand_split=False, remap_class=False): 7 | ''' 8 | Generate the dataset splits based on the labels. 9 | :param train_dataset: (torch.utils.data.dataset) 10 | :param val_dataset: (torch.utils.data.dataset) 11 | :param first_split_sz: (int) 12 | :param other_split_sz: (int) 13 | :param rand_split: (bool) Randomize the set of label in each split 14 | :param remap_class: (bool) Ex: remap classes in a split from [2,4,6 ...] to [0,1,2 ...] 15 | :return: train_loaders {task_name:loader}, val_loaders {task_name:loader}, out_dim {task_name:num_classes} 16 | ''' 17 | assert train_dataset.number_classes==val_dataset.number_classes,'Train/Val has different number of classes' 18 | num_classes = train_dataset.number_classes 19 | 20 | # Calculate the boundary index of classes for splits 21 | # Ex: [0,2,4,6,8,10] or [0,50,60,70,80,90,100] 22 | split_boundaries = [0, first_split_sz] 23 | while split_boundaries[-1] 0 : 71 | new_data.append(x) 72 | new_targets.append(label) 73 | cnt[label] -= 1 74 | 75 | 76 | train_dataset.dataset = torch.stack(new_data) 77 | train_dataset.labels = torch.Tensor(new_targets) 78 | train_dataset.data = torch.stack(new_data) 79 | train_dataset.targets = torch.Tensor(new_targets) 80 | 81 | 82 | ds = torch.utils.data.TensorDataset(train_dataset.dataset, train_dataset.labels) 83 | ds.root=dataroot 84 | train_dataset = CacheClassLabel(ds) 85 | else: 86 | train_dataset = CacheClassLabel(train_dataset) 87 | 88 | val_dataset = torchvision.datasets.MNIST( 89 | dataroot, 90 | train=False, 91 | transform=val_transform 92 | ) 93 | val_dataset = CacheClassLabel(val_dataset) 94 | 95 | return train_dataset, val_dataset 96 | 97 | 98 | def CIFAR10(dataroot, train_aug=False, angle=0): 99 | normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]) 100 | rotate = RotationTransform(angle=angle) 101 | 102 | val_transform = transforms.Compose([ 103 | rotate, 104 | transforms.ToTensor(), 105 | normalize, 106 | ]) 107 | train_transform = val_transform 108 | if train_aug: 109 | train_transform = transforms.Compose([ 110 | transforms.RandomCrop(32, padding=4), 111 | transforms.RandomHorizontalFlip(), 112 | rotate, 113 | transforms.ToTensor(), 114 | normalize, 115 | ]) 116 | 117 | train_dataset = torchvision.datasets.CIFAR10( 118 | root=dataroot, 119 | train=True, 120 | download=True, 121 | transform=train_transform 122 | ) 123 | train_dataset = CacheClassLabel(train_dataset) 124 | 125 | val_dataset = torchvision.datasets.CIFAR10( 126 | root=dataroot, 127 | train=False, 128 | download=True, 129 | transform=val_transform 130 | ) 131 | val_dataset = CacheClassLabel(val_dataset) 132 | 133 | return train_dataset, val_dataset 134 | 135 | 136 | def CIFAR100(dataroot, train_aug=False, angle=0): 137 | normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) 138 | 139 
| val_transform = transforms.Compose([ 140 | transforms.ToTensor(), 141 | normalize, 142 | ]) 143 | train_transform = val_transform 144 | if train_aug: 145 | train_transform = transforms.Compose([ 146 | transforms.RandomCrop(32, padding=4), 147 | transforms.RandomHorizontalFlip(), 148 | transforms.ToTensor(), 149 | normalize, 150 | ]) 151 | 152 | train_dataset = torchvision.datasets.CIFAR100( 153 | root=dataroot, 154 | train=True, 155 | download=True, 156 | transform=train_transform 157 | ) 158 | train_dataset = CacheClassLabel(train_dataset) 159 | 160 | val_dataset = torchvision.datasets.CIFAR100( 161 | root=dataroot, 162 | train=False, 163 | download=True, 164 | transform=val_transform 165 | ) 166 | val_dataset = CacheClassLabel(val_dataset) 167 | 168 | return train_dataset, val_dataset 169 | -------------------------------------------------------------------------------- /algos/gem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.utils import parameters_to_grad_vector 3 | from algos.common import Memory 4 | import numpy as np 5 | ## toolbox for GEM/AGEM methods 6 | 7 | def _get_new_gem_m_basis(self, device,optimizer, model,forward): 8 | new_basis = [] 9 | 10 | 11 | 12 | for t,mem in self.task_memory.items(): 13 | 14 | 15 | for index in range(len(self.task_mem_cache[t]['data'])): 16 | 17 | 18 | inputs=self.task_mem_cache[t]['data'][index].to(device).unsqueeze(0) 19 | 20 | targets=self.task_mem_cache[t]['target'][index].to(device).unsqueeze(0) 21 | 22 | task=self.task_mem_cache[t]['task'][0] 23 | 24 | out = forward(inputs,task) 25 | 26 | mem_loss = self.criterion(out, targets.long().to(self.config.device)) 27 | optimizer.zero_grad() 28 | mem_loss.backward() 29 | new_basis.append(parameters_to_grad_vector(self.get_params_dict(last=False)).cpu()) 30 | 31 | del out,inputs,targets 32 | torch.cuda.empty_cache() 33 | new_basis = torch.stack(new_basis).T 34 | return new_basis 35 | 36 | 37 | def update_agem_memory(trainer, train_loader,task_id): 38 | 39 | trainer.task_count =task_id 40 | num_sample_per_task = trainer.config.memory_size #// (self.config.n_tasks-1) 41 | randind = torch.randperm(len(train_loader.dataset))[:num_sample_per_task] 42 | for ind in randind: # save it to the memory 43 | trainer.agem_mem.append(train_loader.dataset[ind]) 44 | 45 | 46 | 47 | mem_loader_batch_size = min(trainer.config.agem_mem_batch_size, len(trainer.agem_mem)) 48 | trainer.agem_mem_loader = torch.utils.data.DataLoader(trainer.agem_mem, 49 | batch_size=mem_loader_batch_size, 50 | shuffle=True, 51 | num_workers=1) 52 | 53 | 54 | def update_gem_no_transfer_memory(trainer,train_loader,task_id): 55 | 56 | trainer.task_count=task_id 57 | 58 | num_sample_per_task = trainer.config.memory_size 59 | 60 | num_sample_per_task = min(len(train_loader.dataset), num_sample_per_task) 61 | 62 | trainer.task_memory[trainer.task_count] = Memory() 63 | randind = torch.randperm(len(train_loader.dataset))[:num_sample_per_task] # randomly sample some data 64 | for ind in randind: # save it to the memory 65 | trainer.task_memory[trainer.task_count].append(train_loader.dataset[ind]) 66 | 67 | 68 | for t, mem in trainer.task_memory.items(): 69 | 70 | mem_loader = torch.utils.data.DataLoader(mem, 71 | batch_size=len(mem), 72 | shuffle=False, 73 | num_workers=1) 74 | assert len(mem_loader) == 1, 'The length of mem_loader should be 1' 75 | for i, (mem_input, mem_target, mem_task) in enumerate(mem_loader): 76 | pass 77 | 78 | trainer.task_mem_cache[t] = {'data': 
mem_input, 'target': mem_target, 'task': mem_task} 79 | 80 | 81 | 82 | 83 | 84 | def _project_agem_grad(trainer, batch_grad_vec, mem_grad_vec): 85 | 86 | if torch.dot(batch_grad_vec, mem_grad_vec) >= 0: 87 | return batch_grad_vec 88 | else : 89 | trainer.gradient_violation+=1 90 | 91 | 92 | 93 | frac = torch.dot(batch_grad_vec, mem_grad_vec) / torch.dot(mem_grad_vec, mem_grad_vec) 94 | 95 | new_grad = batch_grad_vec - frac * mem_grad_vec 96 | 97 | check = torch.dot(new_grad, mem_grad_vec) 98 | assert torch.abs(check) < 1e-5 99 | return new_grad 100 | 101 | def project2cone2(trainer, gradient, memories): 102 | """ 103 | Solves the GEM dual QP described in the paper given a proposed 104 | gradient "gradient", and a memory of task gradients "memories". 105 | Overwrites "gradient" with the final projected update. 106 | input: gradient, p-vector 107 | input: memories, (t * p)-vector 108 | output: x, p-vector 109 | Modified from: https://github.com/facebookresearch/GradientEpisodicMemory/blob/master/model/gem.py#L70 110 | """ 111 | 112 | margin = trainer.config.margin_gem 113 | memories_np = memories.cpu().contiguous().double().numpy() 114 | gradient_np = gradient.cpu().contiguous().view(-1).double().numpy() 115 | t = memories_np.shape[0] 116 | 117 | P = np.dot(memories_np, memories_np.transpose()) 118 | P = 0.5 * (P + P.transpose()) 119 | q = np.dot(memories_np, gradient_np) * -1 120 | G = np.eye(t) 121 | P = P + G * 0.001 122 | h = np.zeros(t) + margin 123 | # print(P) 124 | v = trainer.quadprog.solve_qp(P, q, G, h)[0] 125 | x = np.dot(v, memories_np) + gradient_np 126 | new_grad = torch.Tensor(x).view(-1) 127 | 128 | new_grad = new_grad.to(trainer.config.device) 129 | return new_grad -------------------------------------------------------------------------------- /scripts/commands/command_split_mnist.sh: -------------------------------------------------------------------------------- 1 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method pca 2 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method ogd 3 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method ewc 4 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method sgd 5 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method agem 6 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method pca 7 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 
--lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method ogd 8 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method ewc 9 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method sgd 10 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method agem 11 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method pca 12 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method ogd 13 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method ewc 14 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method sgd 15 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method agem 16 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method pca 17 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method ogd 18 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method ewc 19 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method sgd 20 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar 
--fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method agem 21 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method pca 22 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method ogd 23 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method ewc 24 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method sgd 25 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method agem 26 | -------------------------------------------------------------------------------- /scripts/commands/command_cifar.sh: -------------------------------------------------------------------------------- 1 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 0 --method pca 2 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 0 --method ogd 3 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 0 --method ewc 4 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 0 --method sgd 5 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 0 --method agem 6 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 1 --method pca 7 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 1 --method ogd 8 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 
--eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 1 --method ewc 9 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 1 --method sgd 10 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 1 --method agem 11 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 2 --method pca 12 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 2 --method ogd 13 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 2 --method ewc 14 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 2 --method sgd 15 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 2 --method agem 16 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 3 --method pca 17 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 3 --method ogd 18 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 3 --method ewc 19 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 3 --method sgd 20 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 3 --method agem 21 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 
--hidden_dim 200 --ewc_reg 25 --seed 4 --method pca 22 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 4 --method ogd 23 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 4 --method ewc 24 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 4 --method sgd 25 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 4 --method agem 26 | -------------------------------------------------------------------------------- /scripts/commands/command_rotated_mnist.sh: -------------------------------------------------------------------------------- 1 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method pca 2 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method ogd 3 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method ewc 4 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method sgd 5 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method agem 6 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method pca 7 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method ogd 8 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method ewc 9 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 
--eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method sgd 10 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method agem 11 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method pca 12 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method ogd 13 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method ewc 14 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method sgd 15 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method agem 16 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method pca 17 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method ogd 18 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method ewc 19 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method sgd 20 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method agem 21 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method pca 22 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 
--hidden_dim 100 --ewc_reg 10 --seed 4 --method ogd 23 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method ewc 24 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method sgd 25 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method agem 26 | -------------------------------------------------------------------------------- /algos/ogd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from algos.common import Memory 3 | from utils.utils import parameters_to_grad_vector, count_parameter 4 | 5 | def _get_new_ogd_basis(trainer,train_loader, device,optimizer, model,forward): 6 | new_basis = [] 7 | 8 | 9 | for _,element in enumerate(train_loader): 10 | 11 | inputs=element[0].to(device) 12 | 13 | targets = element[1].to(device) 14 | 15 | task=element[2] 16 | 17 | out = forward(inputs,task) 18 | 19 | assert out.shape[0] == 1 20 | 21 | pred = out[0,int(targets.item())].cpu() 22 | 23 | optimizer.zero_grad() 24 | pred.backward() 25 | 26 | ### retrieve \nabla f(x) wrt theta 27 | new_basis.append(parameters_to_grad_vector(trainer.get_params_dict(last=False)).cpu()) 28 | 29 | del out,inputs,targets 30 | torch.cuda.empty_cache() 31 | new_basis = torch.stack(new_basis).T 32 | 33 | return new_basis 34 | 35 | def project_vec(vec, proj_basis): 36 | if proj_basis.shape[1] > 0 : # param x basis_size 37 | dots = torch.matmul(vec, proj_basis) # basis_size dots= [ for i in proj_basis ] 38 | out = torch.matmul(proj_basis, dots) # out = [ i for i in proj_basis ] 39 | return out 40 | else: 41 | return torch.zeros_like(vec) 42 | 43 | def update_mem(trainer,train_loader,task_count): 44 | 45 | trainer.task_count =task_count 46 | 47 | 48 | num_sample_per_task = trainer.config.memory_size # // (self.config.n_tasks-1) 49 | num_sample_per_task = min(len(train_loader.dataset), num_sample_per_task) 50 | 51 | memory_length=[] 52 | for i in range(task_count): 53 | memory_length.append(num_sample_per_task) 54 | 55 | 56 | for storage in trainer.task_memory.values(): 57 | ## reduce the size of the stored elements 58 | storage.reduce(num_sample_per_task) 59 | 60 | 61 | trainer.task_memory[0] = Memory() # Initialize the memory slot 62 | 63 | if trainer.config.method=="pca": 64 | randind = torch.randperm(len(train_loader.dataset))[:trainer.config.pca_sample] ## for pca method we samples pca_samples > num_sample_per_task before applying pca and keeping (num_sample_per_task) elements 65 | else: 66 | randind = torch.randperm(len(train_loader.dataset))[:num_sample_per_task] # randomly sample some data 67 | for ind in randind: # save it to the memory 68 | 69 | trainer.task_memory[0].append(train_loader.dataset[ind]) 70 | 71 | 72 | ####################################### Grads MEM ########################### 73 | 74 | for storage in trainer.task_grad_memory.values(): 75 | storage.reduce(num_sample_per_task) 76 | 77 | 78 | 79 | if trainer.config.method in ['ogd','pca']: 80 | 
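# OGD / PCA-OGD: the sampled memory is replayed one example at a time
# (batch_size=1, no shuffling); _get_new_ogd_basis() backpropagates the predicted
# logit of each sample, so parameters_to_grad_vector() yields one gradient column
# of the new basis.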
ogd_train_loader = torch.utils.data.DataLoader(trainer.task_memory[0], 81 | batch_size=1, 82 | shuffle=False, 83 | num_workers=1) 84 | 85 | 86 | 87 | trainer.task_memory[0] = Memory() 88 | 89 | new_basis_tensor = _get_new_ogd_basis(trainer, 90 | ogd_train_loader, 91 | trainer.config.device, 92 | trainer.optimizer, 93 | trainer.model, 94 | trainer.forward).cpu() 95 | 96 | 97 | 98 | if trainer.config.method=="pca": 99 | 100 | try: 101 | _,_,v1=torch.pca_lowrank(new_basis_tensor.T.cpu(), q=num_sample_per_task, center=True, niter=2) 102 | 103 | except: 104 | _,_,v1=torch.svd_lowrank((new_basis_tensor.T+1e-4*new_basis_tensor.T.mean()*torch.rand(new_basis_tensor.T.size(0), new_basis_tensor.T.size(1))).cpu(), q=num_sample_per_task, niter=2, M=None) 105 | 106 | 107 | 108 | del new_basis_tensor 109 | new_basis_tensor=v1.cpu() 110 | torch.cuda.empty_cache() 111 | 112 | if trainer.config.is_split: 113 | if trainer.config.all_features: 114 | if hasattr(trainer.model,"conv"): 115 | n_params = count_parameter(trainer.model.linear)+count_parameter(trainer.model.conv) 116 | else: 117 | n_params = count_parameter(trainer.model.linear) 118 | else: 119 | 120 | n_params = count_parameter(trainer.model.linear) 121 | 122 | else: 123 | n_params = count_parameter(trainer.model) 124 | 125 | 126 | 127 | trainer.ogd_basis = torch.empty(n_params, 0).cpu() 128 | 129 | 130 | for t, mem in trainer.task_grad_memory.items(): 131 | 132 | task_ogd_basis_tensor=torch.stack(mem.storage,axis=1).cpu() 133 | 134 | trainer.ogd_basis = torch.cat([trainer.ogd_basis, task_ogd_basis_tensor], axis=1).cpu() 135 | 136 | 137 | trainer.ogd_basis=orthonormalize(trainer.ogd_basis,new_basis_tensor,trainer.config.device,normalize=True) 138 | 139 | 140 | # (g) Store in the new basis 141 | ptr = 0 142 | 143 | for t in range(len(memory_length)): 144 | 145 | 146 | task_mem_size=memory_length[t] 147 | idxs_list = [i + ptr for i in range(task_mem_size)] 148 | 149 | trainer.ogd_basis_ids[t] = torch.LongTensor(idxs_list).to(trainer.config.device) 150 | 151 | 152 | trainer.task_grad_memory[t] = Memory() # Initialize the memory slot 153 | 154 | 155 | 156 | 157 | if trainer.config.method=="pca": 158 | length=num_sample_per_task 159 | else: 160 | length=task_mem_size 161 | for ind in range(length): # save it to the memory 162 | trainer.task_grad_memory[t].append(trainer.ogd_basis[:, ptr].cpu()) 163 | ptr += 1 164 | 165 | 166 | def orthonormalize(main_vectors, additional_vectors,device,normalize=True): 167 | ## orthnormalize the basis (graham schmidt) 168 | for element in range(additional_vectors.size()[1]): 169 | 170 | 171 | coeff=torch.mv(main_vectors.t(),additional_vectors[:,element]) ## x - y/ |||| 172 | pv=torch.mv(main_vectors, coeff) 173 | d=(additional_vectors[:,element]-pv)/torch.norm(additional_vectors[:,element]-pv,p=2) 174 | main_vectors=torch.cat((main_vectors,d.view(-1,1)),dim=1) 175 | del pv 176 | del d 177 | return main_vectors.to(device) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import dataloaders.base 2 | from dataloaders.datasetGen import SplitGen, PermutedGen,RotatedGen 3 | from torch.nn.utils.convert_parameters import _check_param_device, parameters_to_vector, vector_to_parameters 4 | import torch 5 | 6 | 7 | 8 | def get_benchmark_data_loader(config): 9 | ## example: config.dataset_root_path='/home/usr/dataset/CIFAR100' 10 | config.dataset_root_path='' 11 | if config.dataset=="permuted": 12 | 
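# permuted MNIST: all tasks share one 10-way output head (a non-zero
# force_out_dim overrides the per-task output dimension, see the out_dim
# assignment at the end of this function).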
config.force_out_dim=10 13 | train_dataset, val_dataset = dataloaders.base.__dict__['MNIST'](config.dataset_root_path, False ,subset_size=config.subset_size) 14 | 15 | train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(train_dataset, val_dataset,config.n_tasks,remap_class= False) 16 | 17 | elif config.dataset=="rotated": 18 | config.force_out_dim=10 19 | import dataloaders.base 20 | 21 | Dataset = dataloaders.base.__dict__["MNIST"] 22 | n_rotate=config.n_tasks 23 | 24 | rotate_step=5 25 | 26 | 27 | train_dataset_splits, val_dataset_splits, task_output_space = RotatedGen(Dataset=Dataset, 28 | dataroot=config.dataset_root_path, 29 | train_aug=False, 30 | n_rotate=n_rotate, 31 | rotate_step=rotate_step, 32 | remap_class=False 33 | ,subset_size=config.subset_size) 34 | 35 | elif config.dataset=="split_mnist": 36 | config.first_split_size=2 37 | config.other_split_size=2 38 | config.force_out_dim=0 39 | config.is_split=True 40 | import dataloaders.base 41 | Dataset = dataloaders.base.__dict__["MNIST"] 42 | 43 | 44 | 45 | if config.subset_size<50000: 46 | train_dataset, val_dataset = Dataset(config.dataset_root_path,False, angle=0,noise=None,subset_size=config.subset_size) 47 | else: 48 | train_dataset, val_dataset = Dataset(config.dataset_root_path,False, angle=0,noise=None) 49 | train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(train_dataset, val_dataset, 50 | first_split_sz=config.first_split_size, 51 | other_split_sz=config.other_split_size, 52 | rand_split=config.rand_split, 53 | remap_class=True) 54 | 55 | config.n_tasks = len(task_output_space.items()) 56 | 57 | 58 | elif config.dataset=="split_cifar": 59 | config.force_out_dim=0 60 | config.first_split_size=5 61 | config.other_split_size=5 62 | config.is_split=True 63 | import dataloaders.base 64 | Dataset = dataloaders.base.__dict__["CIFAR100"] 65 | # assert config.model_type == "lenet" # CIFAR100 is trained with lenet only 66 | 67 | train_dataset, val_dataset = Dataset(config.dataset_root_path,False, angle=0) 68 | 69 | train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(train_dataset, val_dataset, 70 | first_split_sz=config.first_split_size, 71 | other_split_sz=config.other_split_size, 72 | rand_split=config.rand_split, 73 | remap_class=True) 74 | config.n_tasks=len(train_dataset_splits) 75 | 76 | config.out_dim = {'All': config.force_out_dim} if config.force_out_dim > 0 else task_output_space 77 | 78 | val_loaders = [torch.utils.data.DataLoader(val_dataset_splits[str(task_id)], 79 | batch_size=256,shuffle=False, 80 | num_workers=config.workers) 81 | for task_id in range(1, config.n_tasks + 1)] 82 | 83 | return train_dataset_splits,val_loaders,task_output_space 84 | 85 | 86 | 87 | def test_error(trainer,task_idx): 88 | trainer.model.eval() 89 | acc = 0 90 | acc_cnt = 0 91 | with torch.no_grad(): 92 | for idx, data in enumerate(trainer.val_loaders[task_idx]): 93 | 94 | data, target, task = data 95 | 96 | data = data.to(trainer.config.device) 97 | target = target.to(trainer.config.device) 98 | 99 | outputs = trainer.forward(data,task) 100 | 101 | acc += accuracy(outputs, target) 102 | acc_cnt += float(target.shape[0]) 103 | return acc/acc_cnt 104 | 105 | 106 | def accuracy(outputs,target): 107 | topk=(1,) 108 | with torch.no_grad(): 109 | maxk = max(topk) 110 | 111 | _, pred = outputs.topk(maxk, 1, True, True) 112 | pred = pred.t() 113 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 114 | 115 | res = [] 116 | for k in topk: 117 | correct_k = 
correct[:k].view(-1).float().sum().item() 118 | res.append(correct_k) 119 | 120 | if len(res)==1: 121 | return res[0] 122 | else: 123 | return res 124 | 125 | 126 | 127 | def parameters_to_grad_vector(parameters): 128 | # Flag for the device where the parameter is located 129 | param_device = None 130 | vec = [] 131 | for param in parameters: 132 | # Ensure the parameters are located in the same device 133 | param_device = _check_param_device(param, param_device) 134 | vec.append(param.grad.view(-1)) 135 | 136 | return torch.cat(vec) 137 | 138 | def count_parameter(model): 139 | return sum(p.numel() for p in model.parameters()) 140 | 141 | 142 | 143 | def grad_vector_to_parameters(vec, parameters): 144 | # Ensure vec of type Tensor 145 | if not isinstance(vec, torch.Tensor): 146 | raise TypeError('expected torch.Tensor, but got: {}' 147 | .format(torch.typename(vec))) 148 | # Flag for the device where the parameter is located 149 | param_device = None 150 | # Pointer for slicing the vector for each parameter 151 | pointer = 0 152 | for param in parameters: 153 | # Ensure the parameters are located in the same device 154 | param_device = _check_param_device(param, param_device) 155 | # The length of the parameter 156 | num_param = param.numel() 157 | # Slice the vector, reshape it, and replace the old data of the parameter 158 | # param.data = vec[pointer:pointer + num_param].view_as(param).data 159 | param.grad = vec[pointer:pointer + num_param].view_as(param).clone() 160 | # Increment the pointer 161 | pointer += num_param 162 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | mpl.use('Agg') 3 | import argparse 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import random 8 | import trainer 9 | import pickle 10 | import os 11 | from utils.utils import get_benchmark_data_loader, test_error 12 | from algos import ogd,ewc,gem 13 | from tqdm.auto import tqdm 14 | 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | 19 | ### Algo parameters 20 | parser.add_argument("--seed", default=1, type=int) # Sets Gym, PyTorch and Numpy seeds 21 | parser.add_argument("--val_size", default=256, type=int) 22 | parser.add_argument("--nepoch", default=10, type=int) # Number of epoches 23 | parser.add_argument("--batch_size", default=32, type=int) # Batch size 24 | parser.add_argument("--memory_size", default=100, type=int) # size of the memory 25 | parser.add_argument("--hidden_dim", default=100, type=int) # size of the hidden layer 26 | parser.add_argument('--lr',default=1e-3, type=float) 27 | parser.add_argument('--n_tasks',default=15, type=int) 28 | parser.add_argument('--workers',default=2, type=int) 29 | parser.add_argument('--eval_freq',default=1000, type=int) 30 | 31 | ## Methods parameters 32 | parser.add_argument("--all_features",default=0, type=int) # Leave it to 0, this is for the case when using Lenet, projecting orthogonally only against the linear layers seems to work better 33 | 34 | ## Dataset 35 | parser.add_argument('--dataset_root_path',default=" ", type=str,help="path to your dataset ex: /home/usr/datasets/") 36 | parser.add_argument('--subset_size',default=1000, type=int, help="number of samples per class, ex: for MNIST, \ 37 | subset_size=1000 wil results in a dataset of total size 10,000") 38 | parser.add_argument('--dataset',default="split_cifar", type=str) 39 | parser.add_argument("--is_split", 
action="store_true") 40 | parser.add_argument('--first_split_size',default=5, type=int) 41 | parser.add_argument('--other_split_size',default=5, type=int) 42 | parser.add_argument("--rand_split",default=False, action="store_true") 43 | parser.add_argument('--force_out_dim', type=int, default=10, 44 | help="Set 0 to let the task decide the required output dimension", required=False) 45 | 46 | ## Method 47 | parser.add_argument('--method',default="ogd", type=str,help="sgd,ogd,pca,agem,gem-nt") 48 | 49 | ## PCA-OGD 50 | parser.add_argument('--pca_sample',default=3000, type=int) 51 | ## agem 52 | parser.add_argument("--agem_mem_batch_size", default=256, type=int) # batch size sampled from the A-GEM memory 53 | parser.add_argument('--margin_gem',default=0.5, type=float) 54 | ## EWC 55 | parser.add_argument('--ewc_reg',default=10, type=float) 56 | parser.add_argument('--fisher_sample',default=1024, type=int) 57 | 58 | ## Folder / Logging results 59 | parser.add_argument('--save_name',default="result", type=str, help="name of the results file") 60 | 61 | config = parser.parse_args() 62 | config.device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 63 | 64 | 65 | 66 | np.set_printoptions(suppress=True) 67 | 68 | config_dict=vars(config) 69 | 70 | ### setting seeds 71 | torch.manual_seed(config.seed) 72 | np.random.seed(config.seed) 73 | random.seed(config.seed) 74 | 75 | torch.backends.cudnn.benchmark=True 76 | torch.backends.cudnn.enabled=True 77 | 78 | config.folder="method_{}_dataset_{}_memory_size_{}_bs_{}_lr_{}_epochs_per_task_{}".format(config.method, \ 79 | config.dataset,config.memory_size,config.batch_size, config.lr,config.nepoch) 80 | 81 | 82 | 83 | ## create folder to log results 84 | if not os.path.exists(config.folder): 85 | os.makedirs(config.folder, exist_ok=True) 86 | 87 | ### name of the file 88 | config.save_name=config.save_name+'_seed_'+str(config.seed) 89 | 90 | 91 | ### dataset path 92 | # config.dataset_root_path="..." 
93 | 94 | 95 | ######################################################################################## 96 | ### dataset ############################################################################ 97 | print('loading dataset') 98 | train_dataset_splits,val_loaders,task_output_space=get_benchmark_data_loader(config) 99 | config.out_dim = {'All': config.force_out_dim} if config.force_out_dim > 0 else task_output_space 100 | 101 | 102 | ### loading trainer module 103 | trainer=trainer.Trainer(config,val_loaders) 104 | 105 | 106 | 107 | 108 | 109 | t=0 110 | print('start training') 111 | ######################################################################################## 112 | ### start training ##################################################################### 113 | for task_in in range(config.n_tasks): 114 | rr=0 115 | 116 | train_loader = torch.utils.data.DataLoader(train_dataset_splits[str(task_in+1)], 117 | batch_size=config.batch_size, 118 | shuffle=True, 119 | num_workers=config.workers) 120 | ### train for config.nepoch epochs 121 | print("================== TASK {} / {} =================".format(task_in+1, config.n_tasks)) 122 | for epoch in tqdm(range(config.nepoch), desc="Train task"): 123 | 124 | 125 | trainer.ogd_basis = trainer.ogd_basis.to(trainer.config.device) # Tensor.to is not in-place; reassign to actually move the basis 126 | 127 | for i, (input, target, task) in enumerate(train_loader): 128 | 129 | trainer.task_id = int(task[0]) 130 | t+=1 131 | rr+=1 132 | inputs = input.to(trainer.config.device) 133 | target = target.long().to(trainer.config.device) 134 | 135 | out = trainer.forward(inputs,task).to(trainer.config.device) 136 | loss = trainer.criterion(out, target) 137 | 138 | if config.method=="ewc" and (task_in+1)>1: 139 | loss+=config.ewc_reg*ewc.penalty(trainer) 140 | 141 | loss.backward() 142 | trainer.optimizer_step() 143 | ### validation accuracy 144 | 145 | if rr%trainer.config.eval_freq==0: 146 | for element in range(task_in+1): 147 | trainer.acc[element]['test_acc'].append(test_error(trainer,element)) 148 | trainer.acc[element]['training_steps'].append(t) 149 | 150 | 151 | for element in range(task_in+1): 152 | trainer.acc[element]['test_acc'].append(test_error(trainer,element)) 153 | trainer.acc[element]['training_steps'].append(t) 154 | print(" task {} / accuracy: {} ".format(element+1, trainer.acc[element]['test_acc'][-1])) 155 | 156 | 157 | ## update memory at the end of each task, depending on the method 158 | if config.method in ['ogd','pca']: 159 | trainer.ogd_basis = trainer.ogd_basis.to(trainer.config.device) 160 | ogd.update_mem(trainer,train_loader,task_in+1) 161 | 162 | if config.method=="agem": 163 | gem.update_agem_memory(trainer,train_loader,task_in+1) 164 | 165 | if config.method=="gem-nt": ## GEM-NT 166 | gem.update_gem_no_transfer_memory(trainer,train_loader,task_in+1) 167 | 168 | if config.method=="ewc": 169 | ewc.update_means(trainer) 170 | ewc.consolidate(trainer,ewc._diag_fisher(trainer,train_loader)) 171 | 172 | 173 | 174 | 175 | ### Plotting accuracies 176 | print('plotting accuracies') 177 | plt.close('all') 178 | for tasks_id in range(len(trainer.acc)): 179 | plt.plot(trainer.acc[tasks_id]['training_steps'],trainer.acc[tasks_id]['test_acc']) 180 | plt.grid() 181 | plt.savefig(config.folder+'/'+config.save_name+".png",dpi=72) 182 | 183 | 184 | 185 | 186 | print('Saving results') 187 | output = open(config.folder+'/'+config.save_name+'.p', 'wb') 188 | pickle.dump(trainer.acc, output) 189 | output.close() 190 | 191 | 192 | 193 | 194 | 195 | -------------------------------------------------------------------------------- /trainer.py: 
-------------------------------------------------------------------------------- 1 | from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters 2 | import torch.nn as nn 3 | import torch 4 | from collections import defaultdict 5 | 6 | 7 | from types import MethodType 8 | from importlib import import_module 9 | 10 | from utils.utils import parameters_to_grad_vector,count_parameter 11 | from copy import deepcopy 12 | from algos import gem,ogd 13 | 14 | 15 | 16 | 17 | class Trainer(object): 18 | def __init__(self, config,val_loaders): 19 | self.config=config 20 | 21 | 22 | self.model = self.create_model() 23 | 24 | self.optimizer = torch.optim.SGD(params=self.model.parameters(),lr=self.config.lr,momentum=0,weight_decay=0) 25 | self.criterion = nn.CrossEntropyLoss() 26 | 27 | 28 | 29 | n_params = count_parameter(self.model) 30 | self.ogd_basis = torch.empty(n_params, 0).to(self.config.device) 31 | 32 | 33 | self.mem_dataset = None 34 | 35 | 36 | self.ogd_basis_ids = defaultdict(lambda: torch.LongTensor([]).to(self.config.device)) 37 | 38 | self.val_loaders = val_loaders 39 | 40 | self.task_count = 0 41 | self.task_memory = {} 42 | 43 | 44 | ### FOR GEM no transfer 45 | self.task_mem_cache = {} 46 | 47 | self.task_grad_memory = {} 48 | self.task_grad_mem_cache = {} 49 | 50 | ### Dictionary to store accuracies (losses could also be stored here, but this is not implemented) 51 | self.acc={} 52 | for element in range(self.config.n_tasks): 53 | self.acc[element]={} 54 | self.acc[element]['test_acc']=[] 55 | self.acc[element]['training_acc']=[] 56 | self.acc[element]['training_steps']=[] 57 | 58 | if self.config.method=="gem-nt": 59 | self.quadprog = import_module('quadprog') 60 | self.grad_to_be_saved={} 61 | if self.config.method=="agem": 62 | self.agem_mem = list() 63 | self.agem_mem_loader = None 64 | 65 | 66 | self.gradient_count=0 67 | self.gradient_violation=0 68 | self.eval_freq=config.eval_freq 69 | 70 | if self.config.method=="ewc": 71 | ## split datasets (split MNIST / split CIFAR) 72 | if self.config.is_split: 73 | if self.config.all_features: 74 | if hasattr(self.model,"conv"): 75 | r=list(self.model.linear.named_parameters())+list(self.model.conv.named_parameters()) 76 | self.params = {n: p for n, p in r if p.requires_grad} 77 | else: 78 | self.params = {n: p for n, p in self.model.linear.named_parameters() if p.requires_grad} 79 | else: 80 | self.params = {n: p for n, p in self.model.linear.named_parameters() if p.requires_grad} 81 | 82 | ### rotated / permuted 83 | else: 84 | self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad} 85 | 86 | self._means = {} 87 | 88 | for n, p in deepcopy(self.params).items(): 89 | self._means[n] = p.data 90 | 91 | self._precision_matrices ={} 92 | for n, p in deepcopy(self.params).items(): 93 | p.data.zero_() 94 | self._precision_matrices[n] = p.data 95 | 96 | 97 | 98 | 99 | def create_model(self): 100 | 101 | cfg = self.config 102 | 103 | 104 | if "cifar" not in cfg.dataset: 105 | import models.mlp 106 | 107 | model = models.mlp.MLP(hidden_dim=cfg.hidden_dim) 108 | 109 | else: 110 | import models.lenet 111 | 112 | model = models.lenet.LeNetC(hidden_dim=cfg.hidden_dim) 113 | 114 | 115 | n_feat = model.last.in_features 116 | 117 | model.last = nn.ModuleDict() 118 | for task,out_dim in cfg.out_dim.items(): 119 | 120 | model.last[task] = nn.Linear(n_feat,out_dim) 121 | 122 | 123 | 124 | # Redefine the task-dependent function 125 | def new_logits(self, x): 126 | outputs = {} 127 | for task, func in self.last.items(): 128 | outputs[task] = 
func(x) 129 | return outputs 130 | 131 | # Replace the task-dependent function 132 | model.logits = MethodType(new_logits, model) 133 | model.to(self.config.device) 134 | return model 135 | 136 | def forward(self, x, task): 137 | 138 | task_key = task[0] 139 | out = self.model.forward(x) 140 | # print(out) 141 | if self.config.is_split : 142 | try: 143 | return out[task_key] 144 | except: 145 | return out[int(task_key)] 146 | else : 147 | # return out 148 | return out["All"] 149 | 150 | def get_params_dict(self, last, task_key=None): 151 | if self.config.is_split : 152 | if last: 153 | return self.model.last[task_key].parameters() 154 | else: 155 | 156 | if self.config.all_features: 157 | ## take the conv parameters into account 158 | if hasattr(self.model,"conv"): 159 | return list(self.model.linear.parameters())+list(self.model.conv.parameters()) 160 | else: 161 | return self.model.linear.parameters() 162 | 163 | 164 | else: 165 | return self.model.linear.parameters() 166 | # return self.model.linear.parameters() 167 | else: 168 | return self.model.parameters() 169 | 170 | 171 | 172 | 173 | 174 | def optimizer_step(self): 175 | 176 | task_key = str(self.task_id) 177 | 178 | ### flatten the current gradients and parameters into vectors 179 | grad_vec = parameters_to_grad_vector(self.get_params_dict(last=False)) 180 | cur_param = parameters_to_vector(self.get_params_dict(last=False)) 181 | 182 | if self.config.method in ['ogd','pca']: 183 | 184 | proj_grad_vec = ogd.project_vec(grad_vec, 185 | proj_basis=self.ogd_basis) 186 | ## keep only the component orthogonal to the stored basis 187 | new_grad_vec = grad_vec - proj_grad_vec 188 | 189 | elif self.config.method=="agem" and self.agem_mem_loader is not None : 190 | self.optimizer.zero_grad() 191 | data, target, task = next(iter(self.agem_mem_loader)) 192 | # data = self.to_device(data) 193 | data=data.to(self.config.device) 194 | target = target.long().to(self.config.device) 195 | 196 | output = self.forward(data, task) 197 | mem_loss = self.criterion(output, target) 198 | mem_loss.backward() 199 | mem_grad_vec = parameters_to_grad_vector(self.get_params_dict(last=False)) 200 | 201 | self.gradient_count+=1 202 | 203 | 204 | new_grad_vec = gem._project_agem_grad(self,batch_grad_vec=grad_vec, 205 | mem_grad_vec=mem_grad_vec) 206 | elif self.config.method=="gem-nt": 207 | 208 | if self.task_count >= 1: 209 | 210 | for t,mem in self.task_memory.items(): 211 | self.optimizer.zero_grad() 212 | 213 | mem_out = self.forward(self.task_mem_cache[t]['data'].to(self.config.device),self.task_mem_cache[t]['task']) 214 | 215 | mem_loss = self.criterion(mem_out, self.task_mem_cache[t]['target'].long().to(self.config.device)) 216 | 217 | mem_loss.backward() 218 | 219 | self.task_grad_memory[t]=parameters_to_grad_vector(self.get_params_dict(last=False)) 220 | 221 | mem_grad_vec = torch.stack(list(self.task_grad_memory.values())) 222 | 223 | new_grad_vec = gem.project2cone2(self,grad_vec, mem_grad_vec) 224 | 225 | else: 226 | new_grad_vec = grad_vec 227 | 228 | else: 229 | new_grad_vec = grad_vec 230 | 231 | ### SGD update: new_theta = old_theta - learning_rate * (projected) gradient 232 | cur_param -= self.config.lr * new_grad_vec#.to(self.config.device) 233 | 234 | vector_to_parameters(cur_param, self.get_params_dict(last=False)) 235 | 236 | if self.config.is_split : 237 | # Update the parameters of the last layer without projection (when there are multiple heads) 238 | cur_param = parameters_to_vector(self.get_params_dict(last=True, task_key=task_key)) 239 | 
grad_vec = parameters_to_grad_vector(self.get_params_dict(last=True, task_key=task_key)) 240 | cur_param -= self.config.lr * grad_vec 241 | vector_to_parameters(cur_param, self.get_params_dict(last=True, task_key=task_key)) 242 | 243 | ### zero grad 244 | self.optimizer.zero_grad() 245 | 246 | 247 | 248 | --------------------------------------------------------------------------------
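For readers skimming `Trainer.optimizer_step` above: the OGD/PCA-OGD branch subtracts from the flattened gradient its projection onto the stored basis `self.ogd_basis`, then takes a plain SGD step on the flattened parameters. The snippet below is a minimal, self-contained sketch of that projection step; `project_vec` here is an illustrative stand-in that assumes the basis columns are orthonormal, not the repository's `algos/ogd.py` implementation.

```
import torch

def project_vec(grad_vec, proj_basis):
    # Component of grad_vec lying in the span of the basis columns
    # (assumes the columns of proj_basis are orthonormal).
    if proj_basis.shape[1] == 0:               # empty basis: nothing to project onto
        return torch.zeros_like(grad_vec)
    coords = proj_basis.t() @ grad_vec         # coordinates of grad_vec in the basis
    return proj_basis @ coords                 # projection back into parameter space

# Toy example: a 2-parameter model with one stored gradient direction.
basis = torch.tensor([[1.0], [0.0]])           # single orthonormal column
grad = torch.tensor([3.0, 4.0])
new_grad = grad - project_vec(grad, basis)     # component orthogonal to the basis
theta = torch.tensor([0.5, 0.5])
theta -= 0.1 * new_grad                        # plain SGD step with the projected gradient
print(new_grad, theta)                         # tensor([0., 4.]) tensor([0.5000, 0.1000])
```

Subtracting the projection keeps each update (approximately) orthogonal to the stored gradient directions of earlier tasks, which is what limits interference with what the model has already learned.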