├── models
│   ├── __init__.py
│   ├── mlp.py
│   └── lenet.py
├── dataloaders
│   ├── __init__.py
│   ├── datasetGen.py
│   ├── wrapper.py
│   └── base.py
├── img
│   ├── cover_image.png
│   └── hyperparameters_table.jpg
├── utils
│   ├── __pycache__
│   │   └── utils.cpython-37.pyc
│   └── utils.py
├── algos
│   ├── common.py
│   ├── ewc.py
│   ├── gem.py
│   └── ogd.py
├── requirements.txt
├── README.md
├── LICENSE
├── .gitignore
├── scripts
│   ├── script.py
│   └── commands
│       ├── command_split_mnist.sh
│       ├── command_cifar.sh
│       └── command_rotated_mnist.sh
├── main.py
└── trainer.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/img/cover_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tldoan/PCA-OGD/HEAD/img/cover_image.png
--------------------------------------------------------------------------------
/img/hyperparameters_table.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tldoan/PCA-OGD/HEAD/img/hyperparameters_table.jpg
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tldoan/PCA-OGD/HEAD/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/algos/common.py:
--------------------------------------------------------------------------------
1 |
2 | from dataloaders.wrapper import Storage
3 |
4 | class Memory(Storage):
5 | def reduce(self, m):
6 | self.storage = self.storage[:m]
--------------------------------------------------------------------------------
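`Storage` comes from `dataloaders/wrapper.py`, which is not reproduced in this listing; judging from how `Memory` is used in `algos/ogd.py` and `algos/gem.py`, it keeps its items in a list attribute `self.storage` and exposes `append`. Under that assumption, `reduce(m)` simply truncates the buffer to its first `m` entries; a stand-alone illustration:

```python
# illustration only: a list-backed stand-in for dataloaders.wrapper.Storage
class Storage:
    def __init__(self):
        self.storage = []
    def append(self, item):
        self.storage.append(item)

class Memory(Storage):
    def reduce(self, m):
        self.storage = self.storage[:m]

mem = Memory()
for sample in range(5):
    mem.append(sample)
mem.reduce(3)
print(mem.storage)   # [0, 1, 2]
```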
/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2020.12.5
2 | chardet==4.0.0
3 | cycler==0.10.0
4 | Cython==0.29.21
5 | filelock==3.0.12
6 | gdown==3.12.2
7 | gnureadline==8.0.0
8 | idna==2.10
9 | kiwisolver==1.3.1
10 | matplotlib==3.3.2
11 | numpy==1.19.2
12 | pandas==1.1.3
13 | Pillow==7.2.0
14 | pyparsing==2.4.7
15 | PySocks==1.7.1
16 | python-dateutil==2.8.1
17 | pytz==2020.5
18 | quadprog==0.1.8
19 | requests==2.25.1
20 | six==1.15.0
21 | torch==1.7.1
22 | torchvision==0.8.2
23 | tqdm==4.56.0
24 | typing-extensions==3.7.4.3
25 | urllib3==1.26.2
26 | virtualenv==16.6.2
27 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PCA-OGD
2 | Official code for [A Theoretical Analysis of Catastrophic Forgetting through the NTK
3 | Overlap Matrix](https://arxiv.org/abs/2010.04003) (AISTATS 2021)
4 |
5 | ![PCA-OGD](img/cover_image.png)
6 |
17 | ## Prerequisites
18 | - Install the packages in `requirements.txt` (e.g. `pip install -r requirements.txt`)
19 |
20 | ## Instructions
21 |
22 | - Clone the repo
23 | - Run the scripts in `scripts/commands` (see the example below)
24 | - Hyperparameters are listed in Table 4 of the Appendix of the [paper](https://arxiv.org/pdf/2010.04003.pdf)
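
For example, to set up the repository and launch a run (repository URL as above; adjust the flags to the benchmark you want to reproduce):

```bash
git clone https://github.com/tldoan/PCA-OGD.git
cd PCA-OGD
pip install -r requirements.txt

# single run (flags defined in main.py)
python main.py --dataset rotated --n_tasks 15 --method pca --seed 0

# or reproduce a full sweep
bash scripts/commands/command_rotated_mnist.sh
```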
25 |
26 | ## Also includes:
27 | - Implementation of [Orthogonal Gradient Descent (OGD)](https://arxiv.org/abs/1910.07104) (Farajtabar et al., 2019)
28 |
29 |
30 | ## To cite:
31 |
32 | ```
33 | @article{doan2020theoretical,
34 | title={A Theoretical Analysis of Catastrophic Forgetting through the NTK Overlap Matrix},
35 | author={Doan, Thang and Bennani, Mehdi and Mazoure, Bogdan and Rabusseau, Guillaume and Alquier, Pierre},
36 | journal={arXiv preprint arXiv:2010.04003},
37 | year={2020}
38 | }
39 | ```
40 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Thang Doan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/models/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class MLP(nn.Module):
5 | def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256):
6 | super(MLP, self).__init__()
7 | self.in_dim = in_channel*img_sz*img_sz
8 | self.linear = nn.Sequential(
9 | nn.Linear(self.in_dim, hidden_dim),
10 | #nn.BatchNorm1d(hidden_dim),
11 | nn.ReLU(inplace=True),
12 | nn.Linear(hidden_dim, hidden_dim),
13 | #nn.BatchNorm1d(hidden_dim),
14 | nn.ReLU(inplace=True),
15 | )
16 | self.last = nn.Linear(hidden_dim, out_dim) # Subject to be replaced dependent on task
17 |
18 | def features(self, x):
19 | x = self.linear(x.view(-1,self.in_dim))
20 | return x
21 |
22 | def logits(self, x):
23 | x = self.last(x)
24 | return x
25 |
26 | def forward(self, x):
27 | x = self.features(x)
28 | x = self.logits(x)
29 | return x
30 |
31 |
32 | def MLP50():
33 | print("\n Using MLP50 \n")
34 | return MLP(hidden_dim=50)
35 |
36 |
37 | def MLP100():
38 | print("\n Using MLP100 \n")
39 | return MLP(hidden_dim=100)
40 |
41 |
42 | def MLP400():
43 | return MLP(hidden_dim=400)
44 |
45 |
46 | def MLP1000():
47 | print("\n Using MLP1000 \n")
48 | return MLP(hidden_dim=1000)
49 |
50 |
51 | def MLP2000():
52 | return MLP(hidden_dim=2000)
53 |
54 |
55 | def MLP5000():
56 | return MLP(hidden_dim=5000)
--------------------------------------------------------------------------------
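A small usage sketch (not part of the repository; run from the repository root): the network flattens inputs to `in_channel*img_sz*img_sz`, so with the defaults (`in_channel=1`, `img_sz=32`) it expects 1x32x32 tensors and returns `out_dim=10` logits.

```python
import torch
from models.mlp import MLP100

model = MLP100()                 # MLP with two hidden layers of width 100
x = torch.randn(8, 1, 32, 32)    # batch of 8 single-channel 32x32 inputs
logits = model(x)                # forward() flattens to (8, 1*32*32) before the linear layers
print(logits.shape)              # torch.Size([8, 10])
```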
/models/lenet.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch.nn as nn
4 |
5 |
6 | class LeNet(nn.Module):
7 |
8 | def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=500):
9 | super(LeNet, self).__init__()
10 | feat_map_sz = img_sz//4
11 | self.n_feat = 50 * feat_map_sz * feat_map_sz
12 | self.hidden_dim = hidden_dim
13 |
14 | self.conv = nn.Sequential(
15 | nn.Conv2d(in_channel, 20, 5, padding=2),
16 | nn.BatchNorm2d(20),
17 | nn.ReLU(inplace=True),
18 | nn.MaxPool2d(2, 2),
19 | nn.Conv2d(20, 50, 5, padding=2),
20 | nn.BatchNorm2d(50),
21 | nn.ReLU(inplace=True),
22 | nn.MaxPool2d(2, 2),
23 | nn.Flatten(),
24 |
25 | )
26 | self.linear = nn.Sequential(
27 | nn.Linear(self.n_feat, hidden_dim),
28 | nn.BatchNorm1d(hidden_dim),
29 | nn.ReLU(inplace=True),
30 | )
31 |
32 |
33 |
34 | self.last = nn.Linear(hidden_dim, out_dim) # Subject to be replaced dependent on task
35 |
36 | def features(self, x):
37 | x = self.conv(x)
38 |
39 |
40 | x = self.linear(x.view(-1, self.n_feat))
41 |
42 | # x=self.linear(x)
43 | return x
44 |
45 | def logits(self, x):
46 | x = self.last(x)
47 | return x
48 |
49 | def forward(self, x):
50 | x = self.features(x)
51 | x = self.logits(x)
52 | return x
53 |
54 |
55 | def LeNetC(out_dim=10, hidden_dim=500): # LeNet with color input
56 | return LeNet(out_dim=out_dim, in_channel=3, img_sz=32, hidden_dim=hidden_dim)
--------------------------------------------------------------------------------
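As a sketch of the shape arithmetic (illustrative, not part of the repository): the two 2x2 max-pools reduce a 32x32 input to 8x8, so `n_feat = 50 * (32//4)**2 = 3200`, which is what the linear head consumes. `LeNetC` is the three-channel variant for CIFAR-style inputs.

```python
import torch
from models.lenet import LeNetC

model = LeNetC(out_dim=5, hidden_dim=200)   # e.g. a 5-way head, as in a split-CIFAR task
model.eval()                                # avoid BatchNorm batch-statistics issues on tiny batches
x = torch.randn(2, 3, 32, 32)               # two RGB 32x32 images
print(model(x).shape)                       # torch.Size([2, 5])
print(model.n_feat)                         # 3200 = 50 * 8 * 8
```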
/algos/ewc.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import torch.nn.functional as F
3 | ## toolbox for EWC
4 |
5 | def consolidate(trainer,precision_matrices):
6 | for index in trainer._precision_matrices:
7 | trainer._precision_matrices[index]+=precision_matrices[index]
8 |
9 | def update_means(trainer):
10 |
11 | for n, p in deepcopy(trainer.params).items():
12 | trainer._means[n] = p.data
13 |
14 | def penalty(trainer):
15 | loss = 0
16 |
17 | for index in trainer.params:
18 | _loss = trainer._precision_matrices[index] *0.5* (trainer.params[index] - trainer._means[index]) ** 2
19 | loss += _loss.sum()
20 | return loss
21 |
22 | def _diag_fisher(trainer,train_loader):
23 | precision_matrices = {}
24 |
25 | for n, p in deepcopy(trainer.params).items():
26 | p.data.zero_()
27 | precision_matrices[n] = p.data
28 |
29 | trainer.model.eval()
30 |
31 | for k,element in enumerate(train_loader):
32 | if k>=trainer.config.fisher_sample:
33 | break
34 | trainer.model.zero_grad()
35 | inputs=element[0].to(trainer.config.device)
36 | targets = element[1].long().to(trainer.config.device)
37 |
38 | task=element[2]
39 | out = trainer.forward(inputs,task)
40 | assert out.shape[0] == 1
41 |
42 | pred = out.cpu()
43 |
44 |
45 | loss=F.log_softmax(pred, dim=1)[0][targets.item()]
46 | loss.backward()
47 |
48 | for index in trainer.params:
49 | trainer._precision_matrices[index].data += trainer.params[index].grad.data ** 2 / trainer.config.fisher_sample
50 |
51 | precision_matrices = {n: p for n, p in precision_matrices.items()}
52 | return precision_matrices
--------------------------------------------------------------------------------
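These helpers assume a `trainer` object (built elsewhere, e.g. in `trainer.py`, which is not reproduced above) carrying `params`, `_means`, `_precision_matrices` and `config`. The quantity `penalty()` computes is the standard diagonal-Fisher EWC term, 0.5 * sum_i F_ii (theta_i - theta*_i)^2; a self-contained numeric sketch of the same computation:

```python
import torch

theta      = torch.tensor([0.8, -0.3, 1.2], requires_grad=True)  # current weights
theta_star = torch.tensor([1.0,  0.0, 1.0])                      # weights recorded after the previous task
fisher     = torch.tensor([2.0,  0.1, 0.5])                      # diagonal Fisher estimate for those weights

task_loss   = (theta ** 2).sum()                                 # stand-in for the current-task loss
ewc_penalty = (0.5 * fisher * (theta - theta_star) ** 2).sum()   # same form as penalty()
ewc_reg     = 10.0                                               # corresponds to the --ewc_reg flag
(task_loss + ewc_reg * ewc_penalty).backward()                   # gradient now pulls theta toward theta_star
```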
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
--------------------------------------------------------------------------------
/scripts/script.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import os
3 |
4 |
5 | prefix_list = ["python main.py "]
6 |
7 | def generate_command(params_dict,prefix_list):
8 |
9 | final_command=[]
10 | for key,value in params_dict.items():
11 | if not isinstance(value,list):
12 | params_dict[key]=[value]
13 |
14 |
15 | keys, values = zip(*params_dict.items())
16 |
17 | params_combination=[(keys,v) for v in itertools.product(*values)]
18 |
19 | for i in params_combination:
20 |
21 | command=prefix_list[0]+" "
22 | for element in range(len(i[0])):
23 | command+=" --"+str(i[0][element])+" "+str(i[1][element])
24 |
25 | final_command.append(command)
26 |
27 | return final_command
28 |
29 |
30 | def replicate_results_split_cifar():
31 | params_dict={
32 | "nepoch":10,
33 | "memory_size":100,
34 | "batch_size":32,
35 | "lr":1e-3,
36 | "subset_size":1000,
38 | "n_tasks":20,
39 | "eval_freq":1000,
40 | "agem_mem_batch_size":256,
41 | "pca_sample":3000,
42 | "dataset":"split_cifar",
43 | "fisher_sample":1024,
44 | "hidden_dim": 200,
45 | "ewc_reg":25,
46 | "seed": [0,1,2,3,4],
47 | "method":["pca","ogd","ewc","sgd","agem"]
48 | }
49 | return params_dict
50 |
51 |
52 | ### rotated and permuted mnist
53 | def replicate_results_rotated_mnist():
54 | params_dict={
55 | "nepoch":10,
56 | "memory_size":100,
57 | "batch_size":32,
58 | "lr":1e-3,
59 | "subset_size":1000,
61 | "n_tasks":15,
62 | "eval_freq":1000,
63 | "agem_mem_batch_size":256,
64 | "pca_sample":3000,
65 | "dataset":"rotated",
66 | "fisher_sample":1024,
67 | "hidden_dim": 100,
68 | "ewc_reg":10,
69 | "seed": [0,1,2,3,4],
70 | "method":["pca","ogd","ewc","sgd","agem"]
71 | }
72 | return params_dict
73 |
74 |
75 | def replicate_results_split_mnist():
76 | params_dict={
77 | "nepoch":5,
78 | "memory_size":100,
79 | "batch_size":32,
80 | "lr":1e-3,
81 | "subset_size":2000,
83 | "n_tasks":5,
84 | "eval_freq":1000,
85 | "agem_mem_batch_size":256,
86 | "pca_sample":3000,
87 | "dataset":"split_mnist",
88 | "fisher_sample":1024,
89 | "hidden_dim": 100,
90 | "ewc_reg":10,
91 | "seed": [0,1,2,3,4],
92 | "method":["pca","ogd","ewc","sgd","agem"]
93 | }
94 | return params_dict
95 |
96 |
97 | if __name__ == '__main__':
98 |
99 |
100 | command_split_cifar=generate_command(replicate_results_split_cifar(),prefix_list)
101 | command_split_mnist=generate_command(replicate_results_split_mnist(),prefix_list)
102 | command_rotated_mnist=generate_command(replicate_results_rotated_mnist(),prefix_list)
103 |
104 | carrier=[command_split_cifar,command_split_mnist,command_rotated_mnist]
105 | name=["command_cifar","command_split_mnist","command_rotated_mnist"]
106 | if not os.path.exists("commands"):
107 | os.makedirs("commands", exist_ok=True)
108 | for it in range(len(carrier)):
109 |
110 | with open("commands/{}.sh".format(name[it]), "w") as outfile:
111 | outfile.write("\n".join(carrier[it])+"\n")
--------------------------------------------------------------------------------
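`generate_command` expands every list-valued entry into a cartesian product, one command line per combination; a small illustration (parameter values chosen only for the example, exact spacing of the generated strings may differ):

```python
from scripts.script import generate_command

params = {"dataset": "rotated", "seed": [0, 1], "method": ["pca", "ogd"]}
commands = generate_command(params, ["python main.py "])

print(len(commands))   # 4: one command per (seed, method) combination
print(commands[0])     # e.g. "python main.py  --dataset rotated --seed 0 --method pca"
```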
/dataloaders/datasetGen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from random import shuffle
3 | from .wrapper import Subclass, AppendName, Permutation
4 |
5 |
6 | def SplitGen(train_dataset, val_dataset, first_split_sz=2, other_split_sz=2, rand_split=False, remap_class=False):
7 | '''
8 | Generate the dataset splits based on the labels.
9 | :param train_dataset: (torch.utils.data.dataset)
10 | :param val_dataset: (torch.utils.data.dataset)
11 | :param first_split_sz: (int)
12 | :param other_split_sz: (int)
13 | :param rand_split: (bool) Randomize the set of label in each split
14 | :param remap_class: (bool) Ex: remap classes in a split from [2,4,6 ...] to [0,1,2 ...]
15 | :return: train_loaders {task_name:loader}, val_loaders {task_name:loader}, out_dim {task_name:num_classes}
16 | '''
17 | assert train_dataset.number_classes==val_dataset.number_classes,'Train/Val has different number of classes'
18 | num_classes = train_dataset.number_classes
19 |
20 | # Calculate the boundary index of classes for splits
21 | # Ex: [0,2,4,6,8,10] or [0,50,60,70,80,90,100]
22 | split_boundaries = [0, first_split_sz]
23 | while split_boundaries[-1] < num_classes:
--------------------------------------------------------------------------------
/dataloaders/base.py:
--------------------------------------------------------------------------------
70 | if cnt[label] > 0 :
71 | new_data.append(x)
72 | new_targets.append(label)
73 | cnt[label] -= 1
74 |
75 |
76 | train_dataset.dataset = torch.stack(new_data)
77 | train_dataset.labels = torch.Tensor(new_targets)
78 | train_dataset.data = torch.stack(new_data)
79 | train_dataset.targets = torch.Tensor(new_targets)
80 |
81 |
82 | ds = torch.utils.data.TensorDataset(train_dataset.dataset, train_dataset.labels)
83 | ds.root=dataroot
84 | train_dataset = CacheClassLabel(ds)
85 | else:
86 | train_dataset = CacheClassLabel(train_dataset)
87 |
88 | val_dataset = torchvision.datasets.MNIST(
89 | dataroot,
90 | train=False,
91 | transform=val_transform
92 | )
93 | val_dataset = CacheClassLabel(val_dataset)
94 |
95 | return train_dataset, val_dataset
96 |
97 |
98 | def CIFAR10(dataroot, train_aug=False, angle=0):
99 | normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])
100 | rotate = RotationTransform(angle=angle)
101 |
102 | val_transform = transforms.Compose([
103 | rotate,
104 | transforms.ToTensor(),
105 | normalize,
106 | ])
107 | train_transform = val_transform
108 | if train_aug:
109 | train_transform = transforms.Compose([
110 | transforms.RandomCrop(32, padding=4),
111 | transforms.RandomHorizontalFlip(),
112 | rotate,
113 | transforms.ToTensor(),
114 | normalize,
115 | ])
116 |
117 | train_dataset = torchvision.datasets.CIFAR10(
118 | root=dataroot,
119 | train=True,
120 | download=True,
121 | transform=train_transform
122 | )
123 | train_dataset = CacheClassLabel(train_dataset)
124 |
125 | val_dataset = torchvision.datasets.CIFAR10(
126 | root=dataroot,
127 | train=False,
128 | download=True,
129 | transform=val_transform
130 | )
131 | val_dataset = CacheClassLabel(val_dataset)
132 |
133 | return train_dataset, val_dataset
134 |
135 |
136 | def CIFAR100(dataroot, train_aug=False, angle=0):
137 | normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
138 |
139 | val_transform = transforms.Compose([
140 | transforms.ToTensor(),
141 | normalize,
142 | ])
143 | train_transform = val_transform
144 | if train_aug:
145 | train_transform = transforms.Compose([
146 | transforms.RandomCrop(32, padding=4),
147 | transforms.RandomHorizontalFlip(),
148 | transforms.ToTensor(),
149 | normalize,
150 | ])
151 |
152 | train_dataset = torchvision.datasets.CIFAR100(
153 | root=dataroot,
154 | train=True,
155 | download=True,
156 | transform=train_transform
157 | )
158 | train_dataset = CacheClassLabel(train_dataset)
159 |
160 | val_dataset = torchvision.datasets.CIFAR100(
161 | root=dataroot,
162 | train=False,
163 | download=True,
164 | transform=val_transform
165 | )
166 | val_dataset = CacheClassLabel(val_dataset)
167 |
168 | return train_dataset, val_dataset
169 |
--------------------------------------------------------------------------------
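For the split benchmarks, `SplitGen` grows the class boundaries by `other_split_sz` until every class is covered; with CIFAR-100 and `first_split_sz = other_split_sz = 5`, as set in `utils/utils.py`, this yields 20 tasks of 5 classes each. A stand-alone sketch of that boundary computation:

```python
# boundary computation used by SplitGen (illustration only)
num_classes, first_split_sz, other_split_sz = 100, 5, 5

split_boundaries = [0, first_split_sz]
while split_boundaries[-1] < num_classes:
    split_boundaries.append(split_boundaries[-1] + other_split_sz)

print(split_boundaries[:5])        # [0, 5, 10, 15, 20]
print(len(split_boundaries) - 1)   # 20 tasks
```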
/algos/gem.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.utils import parameters_to_grad_vector
3 | from algos.common import Memory
4 | import numpy as np
5 | ## toolbox for GEM/AGEM methods
6 |
7 | def _get_new_gem_m_basis(self, device,optimizer, model,forward):
8 | new_basis = []
9 |
10 |
11 |
12 | for t,mem in self.task_memory.items():
13 |
14 |
15 | for index in range(len(self.task_mem_cache[t]['data'])):
16 |
17 |
18 | inputs=self.task_mem_cache[t]['data'][index].to(device).unsqueeze(0)
19 |
20 | targets=self.task_mem_cache[t]['target'][index].to(device).unsqueeze(0)
21 |
22 | task=self.task_mem_cache[t]['task'][0]
23 |
24 | out = forward(inputs,task)
25 |
26 | mem_loss = self.criterion(out, targets.long().to(self.config.device))
27 | optimizer.zero_grad()
28 | mem_loss.backward()
29 | new_basis.append(parameters_to_grad_vector(self.get_params_dict(last=False)).cpu())
30 |
31 | del out,inputs,targets
32 | torch.cuda.empty_cache()
33 | new_basis = torch.stack(new_basis).T
34 | return new_basis
35 |
36 |
37 | def update_agem_memory(trainer, train_loader,task_id):
38 |
39 | trainer.task_count =task_id
40 | num_sample_per_task = trainer.config.memory_size #// (self.config.n_tasks-1)
41 | randind = torch.randperm(len(train_loader.dataset))[:num_sample_per_task]
42 | for ind in randind: # save it to the memory
43 | trainer.agem_mem.append(train_loader.dataset[ind])
44 |
45 |
46 |
47 | mem_loader_batch_size = min(trainer.config.agem_mem_batch_size, len(trainer.agem_mem))
48 | trainer.agem_mem_loader = torch.utils.data.DataLoader(trainer.agem_mem,
49 | batch_size=mem_loader_batch_size,
50 | shuffle=True,
51 | num_workers=1)
52 |
53 |
54 | def update_gem_no_transfer_memory(trainer,train_loader,task_id):
55 |
56 | trainer.task_count=task_id
57 |
58 | num_sample_per_task = trainer.config.memory_size
59 |
60 | num_sample_per_task = min(len(train_loader.dataset), num_sample_per_task)
61 |
62 | trainer.task_memory[trainer.task_count] = Memory()
63 | randind = torch.randperm(len(train_loader.dataset))[:num_sample_per_task] # randomly sample some data
64 | for ind in randind: # save it to the memory
65 | trainer.task_memory[trainer.task_count].append(train_loader.dataset[ind])
66 |
67 |
68 | for t, mem in trainer.task_memory.items():
69 |
70 | mem_loader = torch.utils.data.DataLoader(mem,
71 | batch_size=len(mem),
72 | shuffle=False,
73 | num_workers=1)
74 | assert len(mem_loader) == 1, 'The length of mem_loader should be 1'
75 | for i, (mem_input, mem_target, mem_task) in enumerate(mem_loader):
76 | pass
77 |
78 | trainer.task_mem_cache[t] = {'data': mem_input, 'target': mem_target, 'task': mem_task}
79 |
80 |
81 |
82 |
83 |
84 | def _project_agem_grad(trainer, batch_grad_vec, mem_grad_vec):
85 |
86 | if torch.dot(batch_grad_vec, mem_grad_vec) >= 0:
87 | return batch_grad_vec
88 | else :
89 | trainer.gradient_violation+=1
90 |
91 |
92 |
93 | frac = torch.dot(batch_grad_vec, mem_grad_vec) / torch.dot(mem_grad_vec, mem_grad_vec)
94 |
95 | new_grad = batch_grad_vec - frac * mem_grad_vec
96 |
97 | check = torch.dot(new_grad, mem_grad_vec)
98 | assert torch.abs(check) < 1e-5
99 | return new_grad
100 |
101 | def project2cone2(trainer, gradient, memories):
102 | """
103 | Solves the GEM dual QP described in the paper given a proposed
104 | gradient "gradient", and a memory of task gradients "memories".
105 | Overwrites "gradient" with the final projected update.
106 | input: gradient, p-vector
107 | input: memories, (t * p)-vector
108 | output: x, p-vector
109 | Modified from: https://github.com/facebookresearch/GradientEpisodicMemory/blob/master/model/gem.py#L70
110 | """
111 |
112 | margin = trainer.config.margin_gem
113 | memories_np = memories.cpu().contiguous().double().numpy()
114 | gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
115 | t = memories_np.shape[0]
116 |
117 | P = np.dot(memories_np, memories_np.transpose())
118 | P = 0.5 * (P + P.transpose())
119 | q = np.dot(memories_np, gradient_np) * -1
120 | G = np.eye(t)
121 | P = P + G * 0.001
122 | h = np.zeros(t) + margin
123 | # print(P)
124 | v = trainer.quadprog.solve_qp(P, q, G, h)[0]
125 | x = np.dot(v, memories_np) + gradient_np
126 | new_grad = torch.Tensor(x).view(-1)
127 |
128 | new_grad = new_grad.to(trainer.config.device)
129 | return new_grad
--------------------------------------------------------------------------------
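The projection in `_project_agem_grad` is the A-GEM rule: when the proposed gradient g conflicts with the averaged memory gradient g_ref (negative dot product), it is replaced by g - (<g, g_ref> / <g_ref, g_ref>) * g_ref. A self-contained numeric check:

```python
import torch

g     = torch.tensor([1.0, -2.0, 0.5])   # gradient on the current batch
g_ref = torch.tensor([0.5,  1.0, 0.0])   # gradient on the episodic memory

if torch.dot(g, g_ref) < 0:              # conflict with the memory tasks
    g = g - (torch.dot(g, g_ref) / torch.dot(g_ref, g_ref)) * g_ref

print(torch.dot(g, g_ref).item())        # ~0.0: the projected update no longer opposes the memory
```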
/scripts/commands/command_split_mnist.sh:
--------------------------------------------------------------------------------
1 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method pca
2 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method ogd
3 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method ewc
4 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method sgd
5 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method agem
6 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method pca
7 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method ogd
8 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method ewc
9 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method sgd
10 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method agem
11 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method pca
12 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method ogd
13 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method ewc
14 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method sgd
15 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method agem
16 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method pca
17 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method ogd
18 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method ewc
19 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method sgd
20 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method agem
21 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method pca
22 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method ogd
23 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method ewc
24 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method sgd
25 | python main.py --nepoch 5 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 2000 --n_tasks 5 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_mnist --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method agem
26 |
--------------------------------------------------------------------------------
/scripts/commands/command_cifar.sh:
--------------------------------------------------------------------------------
1 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 0 --method pca
2 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 0 --method ogd
3 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 0 --method ewc
4 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 0 --method sgd
5 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 0 --method agem
6 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 1 --method pca
7 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 1 --method ogd
8 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 1 --method ewc
9 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 1 --method sgd
10 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 1 --method agem
11 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 2 --method pca
12 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 2 --method ogd
13 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 2 --method ewc
14 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 2 --method sgd
15 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 2 --method agem
16 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 3 --method pca
17 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 3 --method ogd
18 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 3 --method ewc
19 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 3 --method sgd
20 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 3 --method agem
21 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 4 --method pca
22 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 4 --method ogd
23 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 4 --method ewc
24 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 4 --method sgd
25 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 20 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset split_cifar --fisher_sample 1024 --hidden_dim 200 --ewc_reg 25 --seed 4 --method agem
26 |
--------------------------------------------------------------------------------
/scripts/commands/command_rotated_mnist.sh:
--------------------------------------------------------------------------------
1 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method pca
2 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method ogd
3 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method ewc
4 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method sgd
5 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 0 --method agem
6 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method pca
7 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method ogd
8 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method ewc
9 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method sgd
10 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 1 --method agem
11 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method pca
12 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method ogd
13 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method ewc
14 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method sgd
15 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 2 --method agem
16 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method pca
17 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method ogd
18 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method ewc
19 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method sgd
20 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 3 --method agem
21 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method pca
22 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method ogd
23 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method ewc
24 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method sgd
25 | python main.py --nepoch 10 --memory_size 100 --batch_size 32 --lr 0.001 --subset_size 1000 --n_tasks 15 --eval_freq 1000 --agem_mem_batch_size 256 --pca_sample 3000 --dataset rotated --fisher_sample 1024 --hidden_dim 100 --ewc_reg 10 --seed 4 --method agem
26 |
--------------------------------------------------------------------------------
/algos/ogd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from algos.common import Memory
3 | from utils.utils import parameters_to_grad_vector, count_parameter
4 |
5 | def _get_new_ogd_basis(trainer,train_loader, device,optimizer, model,forward):
6 | new_basis = []
7 |
8 |
9 | for _,element in enumerate(train_loader):
10 |
11 | inputs=element[0].to(device)
12 |
13 | targets = element[1].to(device)
14 |
15 | task=element[2]
16 |
17 | out = forward(inputs,task)
18 |
19 | assert out.shape[0] == 1
20 |
21 | pred = out[0,int(targets.item())].cpu()
22 |
23 | optimizer.zero_grad()
24 | pred.backward()
25 |
26 | ### retrieve \nabla f(x) wrt theta
27 | new_basis.append(parameters_to_grad_vector(trainer.get_params_dict(last=False)).cpu())
28 |
29 | del out,inputs,targets
30 | torch.cuda.empty_cache()
31 | new_basis = torch.stack(new_basis).T
32 |
33 | return new_basis
34 |
35 | def project_vec(vec, proj_basis):
36 | if proj_basis.shape[1] > 0 : # param x basis_size
37 | dots = torch.matmul(vec, proj_basis) # dots = [<vec, b> for each basis vector b], shape: basis_size
38 | out = torch.matmul(proj_basis, dots) # out = sum_b <vec, b> * b, the component of vec inside the span
39 | return out
40 | else:
41 | return torch.zeros_like(vec)
42 |
43 | def update_mem(trainer,train_loader,task_count):
44 |
45 | trainer.task_count =task_count
46 |
47 |
48 | num_sample_per_task = trainer.config.memory_size # // (self.config.n_tasks-1)
49 | num_sample_per_task = min(len(train_loader.dataset), num_sample_per_task)
50 |
51 | memory_length=[]
52 | for i in range(task_count):
53 | memory_length.append(num_sample_per_task)
54 |
55 |
56 | for storage in trainer.task_memory.values():
57 | ## reduce the size of the stored elements
58 | storage.reduce(num_sample_per_task)
59 |
60 |
61 | trainer.task_memory[0] = Memory() # Initialize the memory slot
62 |
63 | if trainer.config.method=="pca":
64 | randind = torch.randperm(len(train_loader.dataset))[:trainer.config.pca_sample] ## for pca method we samples pca_samples > num_sample_per_task before applying pca and keeping (num_sample_per_task) elements
65 | else:
66 | randind = torch.randperm(len(train_loader.dataset))[:num_sample_per_task] # randomly sample some data
67 | for ind in randind: # save it to the memory
68 |
69 | trainer.task_memory[0].append(train_loader.dataset[ind])
70 |
71 |
72 | ####################################### Grads MEM ###########################
73 |
74 | for storage in trainer.task_grad_memory.values():
75 | storage.reduce(num_sample_per_task)
76 |
77 |
78 |
79 | if trainer.config.method in ['ogd','pca']:
80 | ogd_train_loader = torch.utils.data.DataLoader(trainer.task_memory[0],
81 | batch_size=1,
82 | shuffle=False,
83 | num_workers=1)
84 |
85 |
86 |
87 | trainer.task_memory[0] = Memory()
88 |
89 | new_basis_tensor = _get_new_ogd_basis(trainer,
90 | ogd_train_loader,
91 | trainer.config.device,
92 | trainer.optimizer,
93 | trainer.model,
94 | trainer.forward).cpu()
95 |
96 |
97 |
98 | if trainer.config.method=="pca":
99 |
100 | try:
101 | _,_,v1=torch.pca_lowrank(new_basis_tensor.T.cpu(), q=num_sample_per_task, center=True, niter=2)
102 |
103 | except:
104 | _,_,v1=torch.svd_lowrank((new_basis_tensor.T+1e-4*new_basis_tensor.T.mean()*torch.rand(new_basis_tensor.T.size(0), new_basis_tensor.T.size(1))).cpu(), q=num_sample_per_task, niter=2, M=None)
105 |
106 |
107 |
108 | del new_basis_tensor
109 | new_basis_tensor=v1.cpu()
110 | torch.cuda.empty_cache()
111 |
112 | if trainer.config.is_split:
113 | if trainer.config.all_features:
114 | if hasattr(trainer.model,"conv"):
115 | n_params = count_parameter(trainer.model.linear)+count_parameter(trainer.model.conv)
116 | else:
117 | n_params = count_parameter(trainer.model.linear)
118 | else:
119 |
120 | n_params = count_parameter(trainer.model.linear)
121 |
122 | else:
123 | n_params = count_parameter(trainer.model)
124 |
125 |
126 |
127 | trainer.ogd_basis = torch.empty(n_params, 0).cpu()
128 |
129 |
130 | for t, mem in trainer.task_grad_memory.items():
131 |
132 | task_ogd_basis_tensor=torch.stack(mem.storage,axis=1).cpu()
133 |
134 | trainer.ogd_basis = torch.cat([trainer.ogd_basis, task_ogd_basis_tensor], axis=1).cpu()
135 |
136 |
137 | trainer.ogd_basis=orthonormalize(trainer.ogd_basis,new_basis_tensor,trainer.config.device,normalize=True)
138 |
139 |
140 | # (g) Store in the new basis
141 | ptr = 0
142 |
143 | for t in range(len(memory_length)):
144 |
145 |
146 | task_mem_size=memory_length[t]
147 | idxs_list = [i + ptr for i in range(task_mem_size)]
148 |
149 | trainer.ogd_basis_ids[t] = torch.LongTensor(idxs_list).to(trainer.config.device)
150 |
151 |
152 | trainer.task_grad_memory[t] = Memory() # Initialize the memory slot
153 |
154 |
155 |
156 |
157 | if trainer.config.method=="pca":
158 | length=num_sample_per_task
159 | else:
160 | length=task_mem_size
161 | for ind in range(length): # save it to the memory
162 | trainer.task_grad_memory[t].append(trainer.ogd_basis[:, ptr].cpu())
163 | ptr += 1
164 |
165 |
166 | def orthonormalize(main_vectors, additional_vectors,device,normalize=True):
167 | ## orthonormalize the basis (Gram-Schmidt)
168 | for element in range(additional_vectors.size()[1]):
169 |
170 |
171 | coeff=torch.mv(main_vectors.t(),additional_vectors[:,element]) ## projection coefficients <v, b> of the new vector onto the current basis
172 | pv=torch.mv(main_vectors, coeff)
173 | d=(additional_vectors[:,element]-pv)/torch.norm(additional_vectors[:,element]-pv,p=2)
174 | main_vectors=torch.cat((main_vectors,d.view(-1,1)),dim=1)
175 | del pv
176 | del d
177 | return main_vectors.to(device)
--------------------------------------------------------------------------------
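`project_vec` and `orthonormalize` implement the core OGD step: keep an orthonormal basis of stored gradient directions and subtract from each new gradient its component inside that span. A stand-alone sketch of the projection:

```python
import torch

# orthonormal basis of two stored gradient directions (columns), parameter dimension 3
basis = torch.tensor([[1.0, 0.0],
                      [0.0, 1.0],
                      [0.0, 0.0]])

g = torch.tensor([2.0, -1.0, 3.0])   # new gradient

proj  = basis @ (basis.T @ g)        # component of g inside the span (what project_vec returns)
g_ogd = g - proj                     # the update direction OGD actually applies

print(g_ogd)                         # tensor([0., 0., 3.]) -- orthogonal to every stored direction
```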
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import dataloaders.base
2 | from dataloaders.datasetGen import SplitGen, PermutedGen,RotatedGen
3 | from torch.nn.utils.convert_parameters import _check_param_device, parameters_to_vector, vector_to_parameters
4 | import torch
5 |
6 |
7 |
8 | def get_benchmark_data_loader(config):
9 | ## example: config.dataset_root_path='/home/usr/dataset/CIFAR100'
10 | config.dataset_root_path=''
11 | if config.dataset=="permuted":
12 | config.force_out_dim=10
13 | train_dataset, val_dataset = dataloaders.base.__dict__['MNIST'](config.dataset_root_path, False ,subset_size=config.subset_size)
14 |
15 | train_dataset_splits, val_dataset_splits, task_output_space = PermutedGen(train_dataset, val_dataset,config.n_tasks,remap_class= False)
16 |
17 | elif config.dataset=="rotated":
18 | config.force_out_dim=10
19 | import dataloaders.base
20 |
21 | Dataset = dataloaders.base.__dict__["MNIST"]
22 | n_rotate=config.n_tasks
23 |
24 | rotate_step=5
25 |
26 |
27 | train_dataset_splits, val_dataset_splits, task_output_space = RotatedGen(Dataset=Dataset,
28 | dataroot=config.dataset_root_path,
29 | train_aug=False,
30 | n_rotate=n_rotate,
31 | rotate_step=rotate_step,
32 | remap_class=False
33 | ,subset_size=config.subset_size)
34 |
35 | elif config.dataset=="split_mnist":
36 | config.first_split_size=2
37 | config.other_split_size=2
38 | config.force_out_dim=0
39 | config.is_split=True
40 | import dataloaders.base
41 | Dataset = dataloaders.base.__dict__["MNIST"]
42 |
43 |
44 |
45 | if config.subset_size<50000:
46 | train_dataset, val_dataset = Dataset(config.dataset_root_path,False, angle=0,noise=None,subset_size=config.subset_size)
47 | else:
48 | train_dataset, val_dataset = Dataset(config.dataset_root_path,False, angle=0,noise=None)
49 | train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(train_dataset, val_dataset,
50 | first_split_sz=config.first_split_size,
51 | other_split_sz=config.other_split_size,
52 | rand_split=config.rand_split,
53 | remap_class=True)
54 |
55 | config.n_tasks = len(task_output_space.items())
56 |
57 |
58 | elif config.dataset=="split_cifar":
59 | config.force_out_dim=0
60 | config.first_split_size=5
61 | config.other_split_size=5
62 | config.is_split=True
63 | import dataloaders.base
64 | Dataset = dataloaders.base.__dict__["CIFAR100"]
65 | # assert config.model_type == "lenet" # CIFAR100 is trained with lenet only
66 |
67 | train_dataset, val_dataset = Dataset(config.dataset_root_path,False, angle=0)
68 |
69 | train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(train_dataset, val_dataset,
70 | first_split_sz=config.first_split_size,
71 | other_split_sz=config.other_split_size,
72 | rand_split=config.rand_split,
73 | remap_class=True)
74 | config.n_tasks=len(train_dataset_splits)
75 |
76 | config.out_dim = {'All': config.force_out_dim} if config.force_out_dim > 0 else task_output_space
77 |
78 | val_loaders = [torch.utils.data.DataLoader(val_dataset_splits[str(task_id)],
79 | batch_size=256,shuffle=False,
80 | num_workers=config.workers)
81 | for task_id in range(1, config.n_tasks + 1)]
82 |
83 | return train_dataset_splits,val_loaders,task_output_space
84 |
85 |
86 |
87 | def test_error(trainer,task_idx):
88 | trainer.model.eval()
89 | acc = 0
90 | acc_cnt = 0
91 | with torch.no_grad():
92 | for idx, data in enumerate(trainer.val_loaders[task_idx]):
93 |
94 | data, target, task = data
95 |
96 | data = data.to(trainer.config.device)
97 | target = target.to(trainer.config.device)
98 |
99 | outputs = trainer.forward(data,task)
100 |
101 | acc += accuracy(outputs, target)
102 | acc_cnt += float(target.shape[0])
103 | return acc/acc_cnt
104 |
105 |
106 | def accuracy(outputs,target):
107 | topk=(1,)
108 | with torch.no_grad():
109 | maxk = max(topk)
110 |
111 | _, pred = outputs.topk(maxk, 1, True, True)
112 | pred = pred.t()
113 | correct = pred.eq(target.view(1, -1).expand_as(pred))
114 |
115 | res = []
116 | for k in topk:
117 | correct_k = correct[:k].view(-1).float().sum().item()
118 | res.append(correct_k)
119 |
120 | if len(res)==1:
121 | return res[0]
122 | else:
123 | return res
124 |
125 |
126 |
127 | def parameters_to_grad_vector(parameters):
128 | # Flag for the device where the parameter is located
129 | param_device = None
130 | vec = []
131 | for param in parameters:
132 | # Ensure the parameters are located in the same device
133 | param_device = _check_param_device(param, param_device)
134 | vec.append(param.grad.view(-1))
135 |
136 | return torch.cat(vec)
137 |
138 | def count_parameter(model):
139 | return sum(p.numel() for p in model.parameters())
140 |
141 |
142 |
143 | def grad_vector_to_parameters(vec, parameters):
144 | # Ensure vec of type Tensor
145 | if not isinstance(vec, torch.Tensor):
146 | raise TypeError('expected torch.Tensor, but got: {}'
147 | .format(torch.typename(vec)))
148 | # Flag for the device where the parameter is located
149 | param_device = None
150 | # Pointer for slicing the vector for each parameter
151 | pointer = 0
152 | for param in parameters:
153 | # Ensure the parameters are located in the same device
154 | param_device = _check_param_device(param, param_device)
155 | # The length of the parameter
156 | num_param = param.numel()
157 | # Slice the vector, reshape it, and replace the old data of the parameter
158 | # param.data = vec[pointer:pointer + num_param].view_as(param).data
159 | param.grad = vec[pointer:pointer + num_param].view_as(param).clone()
160 | # Increment the pointer
161 | pointer += num_param
162 |
--------------------------------------------------------------------------------
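`parameters_to_grad_vector` and `grad_vector_to_parameters` convert between per-parameter gradients and one flat vector, which is the representation the OGD/A-GEM projections operate on. A minimal round-trip sketch (assumes the repository root is on `PYTHONPATH`):

```python
import torch
import torch.nn as nn
from utils.utils import parameters_to_grad_vector, grad_vector_to_parameters

model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()

vec = parameters_to_grad_vector(model.parameters())    # flat vector of length 4*2 + 2 = 10
vec = -vec                                              # pretend some projection modified the gradient
grad_vector_to_parameters(vec, model.parameters())      # write the modified gradient back
print(model.weight.grad.shape, model.bias.grad.shape)   # torch.Size([2, 4]) torch.Size([2])
```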
/main.py:
--------------------------------------------------------------------------------
1 | import matplotlib as mpl
2 | mpl.use('Agg')
3 | import argparse
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import torch
7 | import random
8 | import trainer
9 | import pickle
10 | import os
11 | from utils.utils import get_benchmark_data_loader, test_error
12 | from algos import ogd,ewc,gem
13 | from tqdm.auto import tqdm
14 |
15 |
16 |
17 | parser = argparse.ArgumentParser()
18 |
19 | ### Algo parameters
20 | parser.add_argument("--seed", default=1, type=int) # Sets PyTorch, NumPy and random seeds
21 | parser.add_argument("--val_size", default=256, type=int)
22 | parser.add_argument("--nepoch", default=10, type=int) # Number of epoches
23 | parser.add_argument("--batch_size", default=32, type=int) # Batch size
24 | parser.add_argument("--memory_size", default=100, type=int) # size of the memory
25 | parser.add_argument("--hidden_dim", default=100, type=int) # size of the hidden layer
26 | parser.add_argument('--lr',default=1e-3, type=float)
27 | parser.add_argument('--n_tasks',default=15, type=int)
28 | parser.add_argument('--workers',default=2, type=int)
29 | parser.add_argument('--eval_freq',default=1000, type=int)
30 |
31 | ## Methods parameters
32 | parser.add_argument("--all_features",default=0, type=int) # Leave it to 0, this is for the case when using Lenet, projecting orthogonally only against the linear layers seems to work better
33 |
34 | ## Dataset
35 | parser.add_argument('--dataset_root_path',default=" ", type=str,help="path to your dataset ex: /home/usr/datasets/")
36 | parser.add_argument('--subset_size',default=1000, type=int, help="number of samples per class, ex: for MNIST, \
37 |                     subset_size=1000 will result in a dataset of total size 10,000")
38 | parser.add_argument('--dataset',default="split_cifar", type=str)
39 | parser.add_argument("--is_split", action="store_true")
40 | parser.add_argument('--first_split_size',default=5, type=int)
41 | parser.add_argument('--other_split_size',default=5, type=int)
42 | parser.add_argument("--rand_split",default=False, action="store_true")
43 | parser.add_argument('--force_out_dim', type=int, default=10,
44 | help="Set 0 to let the task decide the required output dimension", required=False)
45 |
46 | ## Method
47 | parser.add_argument('--method',default="ogd", type=str,help="sgd,ogd,pca,agem,gem-nt,ewc")
48 |
49 | ## PCA-OGD
50 | parser.add_argument('--pca_sample',default=3000, type=int)
51 | ## agem
52 | parser.add_argument("--agem_mem_batch_size", default=256, type=int) # size of the memory
53 | parser.add_argument('--margin_gem',default=0.5, type=float)
54 | ## EWC
55 | parser.add_argument('--ewc_reg',default=10, type=float)
56 | parser.add_argument('--fisher_sample',default=1024, type=int)
57 |
58 | ## Folder / Logging results
59 | parser.add_argument('--save_name',default="result", type=str, help="name of the file")
60 |
61 | config = parser.parse_args()
62 | config.device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
63 |
64 |
65 |
66 | np.set_printoptions(suppress=True)
67 |
68 | config_dict=vars(config)
69 |
70 | ### setting seeds
71 | torch.manual_seed(config.seed)
72 | np.random.seed(config.seed)
73 | random.seed(config.seed)
74 |
75 | torch.backends.cudnn.benchmark=True
76 | torch.backends.cudnn.enabled=True
77 |
78 | config.folder="method_{}_dataset_{}_memory_size_{}_bs_{}_lr_{}_epochs_per_task_{}".format(config.method, \
79 | config.dataset,config.memory_size,config.batch_size, config.lr,config.nepoch)
80 |
81 |
82 |
83 | ## create folder to log results
84 | if not os.path.exists(config.folder):
85 | os.makedirs(config.folder, exist_ok=True)
86 |
87 | ### name of the file
88 | config.save_name=config.save_name+'_seed_'+str(config.seed)
89 |
90 |
91 | ### dataset path
92 | # config.dataset_root_path="..."
93 |
94 |
95 | ########################################################################################
96 | ### dataset ############################################################################
97 | print('loading dataset')
98 | train_dataset_splits,val_loaders,task_output_space=get_benchmark_data_loader(config)
99 | config.out_dim = {'All': config.force_out_dim} if config.force_out_dim > 0 else task_output_space
100 |
101 |
102 | ### loading trainer module
103 | trainer=trainer.Trainer(config,val_loaders)
104 |
105 |
106 |
107 |
108 |
109 | t=0
110 | print('start training')
111 | ########################################################################################
112 | ### start training #####################################################################
113 | for task_in in range(config.n_tasks):
114 | rr=0
115 |
116 | train_loader = torch.utils.data.DataLoader(train_dataset_splits[str(task_in+1)],
117 | batch_size=config.batch_size,
118 | shuffle=True,
119 | num_workers=config.workers)
120 |     ### train on the current task for config.nepoch epochs
121 | print("================== TASK {} / {} =================".format(task_in+1, config.n_tasks))
122 | for epoch in tqdm(range( config.nepoch), desc="Train task"):
123 |
124 |
125 |         trainer.ogd_basis = trainer.ogd_basis.to(trainer.config.device)  ## re-assign: Tensor.to() is not in-place
126 |
127 | for i, (input, target, task) in enumerate(train_loader):
128 |
129 | trainer.task_id = int(task[0])
130 | t+=1
131 | rr+=1
132 | inputs = input.to(trainer.config.device)
133 | target = target.long().to(trainer.config.device)
134 |
135 | out = trainer.forward(inputs,task).to(trainer.config.device)
136 | loss = trainer.criterion(out, target)
137 |
138 | if config.method=="ewc" and (task_in+1)>1:
139 | loss+=config.ewc_reg*ewc.penalty(trainer)
140 |
141 | loss.backward()
142 | trainer.optimizer_step()
143 | ### validation accuracy
144 |
145 | if rr%trainer.config.eval_freq==0:
146 | for element in range(task_in+1):
147 | trainer.acc[element]['test_acc'].append(test_error(trainer,element))
148 | trainer.acc[element]['training_steps'].append(t)
149 |
150 |
151 | for element in range(task_in+1):
152 | trainer.acc[element]['test_acc'].append(test_error(trainer,element))
153 | trainer.acc[element]['training_steps'].append(t)
154 | print(" task {} / accuracy: {} ".format(element+1, trainer.acc[element]['test_acc'][-1]))
155 |
156 |
157 |     ## update the memory at the end of each task, depending on the method
158 | if config.method in ['ogd','pca']:
159 |         trainer.ogd_basis = trainer.ogd_basis.to(trainer.config.device)  ## re-assign: Tensor.to() is not in-place
160 | ogd.update_mem(trainer,train_loader,task_in+1)
161 |
162 | if config.method=="agem":
163 | gem.update_agem_memory(trainer,train_loader,task_in+1)
164 |
165 | if config.method=="gem-nt": ## GEM-NT
166 | gem.update_gem_no_transfer_memory(trainer,train_loader,task_in+1)
167 |
168 | if config.method=="ewc":
169 | ewc.update_means(trainer)
170 | ewc.consolidate(trainer,ewc._diag_fisher(trainer,train_loader))
171 |
172 |
173 |
174 |
175 | ### Plotting accuracies
176 | print('plotting accuracies')
177 | plt.close('all')
178 | for tasks_id in range(len(trainer.acc)):
179 | plt.plot(trainer.acc[tasks_id]['training_steps'],trainer.acc[tasks_id]['test_acc'])
180 | plt.grid()
181 | plt.savefig(config.folder+'/'+config.save_name+".png",dpi=72)
182 |
183 |
184 |
185 |
186 | print('Saving results')
187 | output = open(config.folder+'/'+config.save_name+'.p', 'wb')
188 | pickle.dump(trainer.acc, output)
189 | output.close()
190 |
191 |
192 |
193 |
194 |
195 |
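196 | 
197 | ## Example invocation (illustrative values and a placeholder dataset path; see scripts/commands
198 | ## for ready-made run commands):
199 | #
200 | #   python main.py --dataset split_cifar --is_split --force_out_dim 0 --method pca \
201 | #          --memory_size 100 --nepoch 10 --seed 1 --dataset_root_path /path/to/datasets/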
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters
2 | import torch.nn as nn
3 | import torch
4 | from collections import defaultdict
5 |
6 |
7 | from types import MethodType
8 | from importlib import import_module
9 |
10 | from utils.utils import parameters_to_grad_vector,count_parameter
11 | from copy import deepcopy
12 | from algos import gem,ogd
13 |
14 |
15 |
16 |
17 | class Trainer(object):
18 | def __init__(self, config,val_loaders):
19 | self.config=config
20 |
21 |
22 | self.model = self.create_model()
23 |
24 | self.optimizer = torch.optim.SGD(params=self.model.parameters(),lr=self.config.lr,momentum=0,weight_decay=0)
25 | self.criterion = nn.CrossEntropyLoss()
26 |
27 |
28 |
29 | n_params = count_parameter(self.model)
30 | self.ogd_basis = torch.empty(n_params, 0).to(self.config.device)
31 |
32 |
33 | self.mem_dataset = None
34 |
35 |
36 | self.ogd_basis_ids = defaultdict(lambda: torch.LongTensor([]).to(self.config.device))
37 |
38 | self.val_loaders = val_loaders
39 |
40 | self.task_count = 0
41 | self.task_memory = {}
42 |
43 |
44 | ### FOR GEM no transfer
45 | self.task_mem_cache = {}
46 |
47 | self.task_grad_memory = {}
48 | self.task_grad_mem_cache = {}
49 |
50 |         ### Dictionary for logging per-task accuracy; losses could be logged the same way but this is not implemented
51 | self.acc={}
52 | for element in range(self.config.n_tasks):
53 | self.acc[element]={}
54 | self.acc[element]['test_acc']=[]
55 | self.acc[element]['training_acc']=[]
56 | self.acc[element]['training_steps']=[]
57 |
58 | if self.config.method=="gem-nt":
59 | self.quadprog = import_module('quadprog')
60 | self.grad_to_be_saved={}
61 | if self.config.method=="agem":
62 | self.agem_mem = list()
63 | self.agem_mem_loader = None
64 |
65 |
66 | self.gradient_count=0
67 | self.gradient_violation=0
68 | self.eval_freq=config.eval_freq
69 |
70 | if self.config.method=="ewc":
71 |             ## split benchmarks (e.g. split MNIST / split CIFAR)
72 | if self.config.is_split:
73 | if self.config.all_features:
74 | if hasattr(self.model,"conv"):
75 | r=list(self.model.linear.named_parameters())+list(self.model.conv.named_parameters())
76 | self.params = {n: p for n, p in r if p.requires_grad}
77 | else:
78 | self.params = {n: p for n, p in self.model.linear.named_parameters() if p.requires_grad}
79 | else:
80 | self.params = {n: p for n, p in self.model.linear.named_parameters() if p.requires_grad}
81 |
82 |             ### non-split benchmarks (e.g. rotated MNIST): use all model parameters
83 | else:
84 | self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
85 |
86 | self._means = {}
87 |
88 | for n, p in deepcopy(self.params).items():
89 | self._means[n] = p.data
90 |
91 | self._precision_matrices ={}
92 | for n, p in deepcopy(self.params).items():
93 | p.data.zero_()
94 | self._precision_matrices[n] = p.data
95 |
96 |
97 |
98 |
99 | def create_model(self):
100 |
101 | cfg = self.config
102 |
103 |
104 | if "cifar" not in cfg.dataset:
105 | import models.mlp
106 |
107 | model = models.mlp.MLP(hidden_dim=cfg.hidden_dim)
108 |
109 | else:
110 | import models.lenet
111 |
112 | model = models.lenet.LeNetC(hidden_dim=cfg.hidden_dim)
113 |
114 |
115 | n_feat = model.last.in_features
116 |
117 | model.last = nn.ModuleDict()
118 | for task,out_dim in cfg.out_dim.items():
119 |
120 | model.last[task] = nn.Linear(n_feat,out_dim)
121 |
122 |
123 |
124 | # Redefine the task-dependent function
125 | def new_logits(self, x):
126 | outputs = {}
127 | for task, func in self.last.items():
128 | outputs[task] = func(x)
129 | return outputs
130 |
131 | # Replace the task-dependent function
132 | model.logits = MethodType(new_logits, model)
133 | model.to(self.config.device)
134 | return model
135 |
136 | def forward(self, x, task):
137 |
138 | task_key = task[0]
139 | out = self.model.forward(x)
140 | # print(out)
141 | if self.config.is_split :
142 | try:
143 | return out[task_key]
144 |             except (KeyError, TypeError):  ## fall back to an int key if the string lookup fails
145 | return out[int(task_key)]
146 | else :
147 | # return out
148 | return out["All"]
149 |
150 | def get_params_dict(self, last, task_key=None):
151 | if self.config.is_split :
152 | if last:
153 | return self.model.last[task_key].parameters()
154 | else:
155 |
156 | if self.config.all_features:
157 | ## take the conv parameters into account
158 | if hasattr(self.model,"conv"):
159 | return list(self.model.linear.parameters())+list(self.model.conv.parameters())
160 | else:
161 | return self.model.linear.parameters()
162 |
163 |
164 | else:
165 | return self.model.linear.parameters()
166 | # return self.model.linear.parameters()
167 | else:
168 | return self.model.parameters()
169 |
170 |
171 |
172 |
173 |
174 | def optimizer_step(self):
175 |
176 | task_key = str(self.task_id)
177 |
178 |         ### flatten the gradients and the current parameter values into single vectors
179 | grad_vec = parameters_to_grad_vector(self.get_params_dict(last=False))
180 | cur_param = parameters_to_vector(self.get_params_dict(last=False))
181 |
182 | if self.config.method in ['ogd','pca']:
183 |
184 | proj_grad_vec = ogd.project_vec(grad_vec,
185 | proj_basis=self.ogd_basis)
186 |             ## keep only the gradient component orthogonal to the OGD basis
187 | new_grad_vec = grad_vec - proj_grad_vec
188 |
189 | elif self.config.method=="agem" and self.agem_mem_loader is not None :
190 | self.optimizer.zero_grad()
191 | data, target, task = next(iter(self.agem_mem_loader))
192 | # data = self.to_device(data)
193 | data=data.to(self.config.device)
194 | target = target.long().to(self.config.device)
195 |
196 | output = self.forward(data, task)
197 | mem_loss = self.criterion(output, target)
198 | mem_loss.backward()
199 | mem_grad_vec = parameters_to_grad_vector(self.get_params_dict(last=False))
200 |
201 | self.gradient_count+=1
202 |
203 |
204 | new_grad_vec = gem._project_agem_grad(self,batch_grad_vec=grad_vec,
205 | mem_grad_vec=mem_grad_vec)
206 | elif self.config.method=="gem-nt":
207 |
208 | if self.task_count >= 1:
209 |
210 | for t,mem in self.task_memory.items():
211 | self.optimizer.zero_grad()
212 |
213 | mem_out = self.forward(self.task_mem_cache[t]['data'].to(self.config.device),self.task_mem_cache[t]['task'])
214 |
215 | mem_loss = self.criterion(mem_out, self.task_mem_cache[t]['target'].long().to(self.config.device))
216 |
217 | mem_loss.backward()
218 |
219 | self.task_grad_memory[t]=parameters_to_grad_vector(self.get_params_dict(last=False))
220 |
221 | mem_grad_vec = torch.stack(list(self.task_grad_memory.values()))
222 |
223 | new_grad_vec = gem.project2cone2(self,grad_vec, mem_grad_vec)
224 |
225 | else:
226 | new_grad_vec = grad_vec
227 |
228 | else:
229 | new_grad_vec = grad_vec
230 |
231 |         ### manual SGD update: new_theta = old_theta - learning_rate * (projected) gradient
232 |         cur_param -= self.config.lr * new_grad_vec
233 |
234 | vector_to_parameters(cur_param, self.get_params_dict(last=False))
235 |
236 | if self.config.is_split :
237 |             # Update the parameters of the current task's head without projection (multi-head setting)
238 | cur_param = parameters_to_vector(self.get_params_dict(last=True, task_key=task_key))
239 | grad_vec = parameters_to_grad_vector(self.get_params_dict(last=True, task_key=task_key))
240 | cur_param -= self.config.lr * grad_vec
241 | vector_to_parameters(cur_param, self.get_params_dict(last=True, task_key=task_key))
242 |
243 | ### zero grad
244 | self.optimizer.zero_grad()
245 |
246 |
247 |
248 |
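249 | 
250 | ## Note on the 'ogd'/'pca' branch of optimizer_step(): the gradient is stripped of its component
251 | ## lying in the span of the stored basis before the manual SGD step.  A minimal sketch, assuming
252 | ## an orthonormal basis B of shape (n_params, k) (B is illustrative; this is not the actual code
253 | ## in algos/ogd.py):
254 | #
255 | #   proj_grad_vec = B @ (B.t() @ grad_vec)     # component of the gradient inside span(B)
256 | #   new_grad_vec  = grad_vec - proj_grad_vec   # orthogonal component that is actually applied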
--------------------------------------------------------------------------------