3 |
4 | # [ScrollNet: Dynamic Weight Importance for Continual Learning](http://arxiv.org/abs/2308.16567)
5 |
6 |
7 |
8 | # Introduction
9 | The official PyTorch implementation of ScrollNet: Dynamic Weight Importance for Continual Learning, presented at the Visual Continual Learning workshop, ICCV 2023.
10 |
11 | # Installation
12 | ## Clone this github repository
13 | ```
14 | git clone https://github.com/FireFYF/ScrollNet.git
15 | cd ScrollNet
16 | ```
17 | ## Create a conda environment
18 | ```
19 | conda env create --file env.yml --name ScrollNet
20 | ```
21 | *Notice:* set the appropriate version of your CUDA driver for `cudatoolkit` in `env.yml`.
22 | ## Environment activation/deactivation
23 | ```
24 | conda activate ScrollNet
25 | conda deactivate
26 | ```
27 |
28 | # Launch experiments
29 |
30 | ## Run with ScrollNet-FT
31 | ```
32 | python -u src/main_incremental.py --gpu 0 --approach finetuning --results-path ./results/5splits/scrollnet_ft --num-tasks 5
33 | ```
34 | ## Run with ScrollNet-LWF
35 | ```
36 | python -u src/main_incremental.py --gpu 0 --approach lwf --results-path ./results/5splits/scrollnet_lwf --num-tasks 5
37 | ```
38 | ## Run with ScrollNet-EWC
39 | ```
40 | python -u src/main_incremental.py --gpu 0 --approach ewc --results-path ./results/5splits/scrollnet_ewc --num-tasks 5
41 | ```
42 |
43 | # Tune the number of subnetworks
44 | Please modify the file 'SizeOfSubnetworks.yml'. The default setting is for 4 subnetworks with equal splitting (ScrollNet-4).
45 |
46 | # Acknowledgement
47 | The implementation is based on [FACIL](https://github.com/mmasana/FACIL), a framework for class-incremental learning. We suggest referring to it if you want to incorporate more CL methods into ScrollNet.
48 |
49 | # Cite
50 | If you find this work useful for your research, please cite:
51 | ```bibtex
52 | @misc{yang2023scrollnet,
53 | title={ScrollNet: Dynamic Weight Importance for Continual Learning},
54 | author={Fei Yang and Kai Wang and Joost van de Weijer},
55 | year={2023},
56 | eprint={2308.16567},
57 | archivePrefix={arXiv},
58 | primaryClass={cs.CV}
59 | }
60 | ```
61 |
--------------------------------------------------------------------------------
/SizeOfSubnetworks.yml:
--------------------------------------------------------------------------------
1 | # =========================== Size of each subnetwork in ScrollNet ===========================
2 | width_mult: 1.0
3 | width_mult_list: [0.25, 0.5, 0.75, 1.0]
4 |
5 |
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: scrollnet
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2023.01.10=h06a4308_0
8 | - certifi=2022.12.7=py38h06a4308_0
9 | - ld_impl_linux-64=2.38=h1181459_1
10 | - libffi=3.4.2=h6a678d5_6
11 | - libgcc-ng=11.2.0=h1234567_1
12 | - libgomp=11.2.0=h1234567_1
13 | - libstdcxx-ng=11.2.0=h1234567_1
14 | - ncurses=6.4=h6a678d5_0
15 | - openssl=1.1.1t=h7f8727e_0
16 | - pip=23.0.1=py38h06a4308_0
17 | - python=3.8.16=h7a1cb2a_3
18 | - readline=8.2=h5eee18b_0
19 | - setuptools=65.6.3=py38h06a4308_0
20 | - sqlite=3.41.1=h5eee18b_0
21 | - tk=8.6.12=h1ccaba5_0
22 | - wheel=0.38.4=py38h06a4308_0
23 | - xz=5.2.10=h5eee18b_1
24 | - zlib=1.2.13=h5eee18b_0
25 | - pip:
26 | - charset-normalizer==3.1.0
27 | - cmake==3.26.0
28 | - contourpy==1.0.7
29 | - cycler==0.11.0
30 | - filelock==3.10.0
31 | - fonttools==4.39.2
32 | - idna==3.4
33 | - importlib-resources==5.12.0
34 | - jinja2==3.1.2
35 | - kiwisolver==1.4.4
36 | - lit==15.0.7
37 | - markupsafe==2.1.2
38 | - matplotlib==3.7.1
39 | - mpmath==1.3.0
40 | - networkx==3.0
41 | - numpy==1.24.2
42 | - nvidia-cublas-cu11==11.10.3.66
43 | - nvidia-cuda-cupti-cu11==11.7.101
44 | - nvidia-cuda-nvrtc-cu11==11.7.99
45 | - nvidia-cuda-runtime-cu11==11.7.99
46 | - nvidia-cudnn-cu11==8.5.0.96
47 | - nvidia-cufft-cu11==10.9.0.58
48 | - nvidia-curand-cu11==10.2.10.91
49 | - nvidia-cusolver-cu11==11.4.0.1
50 | - nvidia-cusparse-cu11==11.7.4.91
51 | - nvidia-nccl-cu11==2.14.3
52 | - nvidia-nvtx-cu11==11.7.91
53 | - packaging==23.0
54 | - pillow==9.4.0
55 | - ptflops==0.7
56 | - pyparsing==3.0.9
57 | - python-dateutil==2.8.2
58 | - pyyaml==6.0
59 | - requests==2.28.2
60 | - six==1.16.0
61 | - sympy==1.11.1
62 | - torch==2.0.0
63 | - torchaudio==2.0.1
64 | - torchvision==0.15.1
65 | - triton==2.0.0
66 | - typing-extensions==4.5.0
67 | - urllib3==1.26.15
68 | - zipp==3.15.0
69 |
--------------------------------------------------------------------------------
/scrollnet.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/scrollnet.gif
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/src/.DS_Store
--------------------------------------------------------------------------------
/src/approach/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # list all approaches available
4 | __all__ = list(
5 | map(lambda x: x[:-3],
6 | filter(lambda x: x not in ['__init__.py', 'incremental_learning.py'] and x.endswith('.py'),
7 | os.listdir(os.path.dirname(__file__))
8 | )
9 | )
10 | )
11 |
--------------------------------------------------------------------------------
/src/approach/ewc.py:
--------------------------------------------------------------------------------
1 | from turtle import width
2 | import torch
3 | import itertools
4 | from argparse import ArgumentParser
5 |
6 | from datasets.exemplars_dataset import ExemplarsDataset
7 | from .incremental_learning import Inc_Learning_Appr
8 | from widths.config import FLAGS
9 |
class Appr(Inc_Learning_Appr):
    """Class implementing the Elastic Weight Consolidation (EWC) approach
    described in http://arxiv.org/abs/1612.00796

    The Fisher information diagonal is estimated for the feature extractor
    (``self.model.model``) only; the task heads are excluded from the penalty.
    """

    def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, scroll_step=1, fix_bn=False, eval_on_train=False,
                 logger=None, exemplars_dataset=None, lamb=5000, alpha=0.5, fi_sampling_type='max_pred',
                 fi_num_samples=-1):
        super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd,
                                   multi_softmax, scroll_step, fix_bn, eval_on_train, logger,
                                   exemplars_dataset)
        # Eq. 3 trade-off between the EWC penalty (old tasks) and the CE loss (new task)
        self.lamb = lamb
        # Fusion coefficient between old and new Fisher estimates (-1 selects the growing-alpha scheme)
        self.alpha = alpha
        # How pseudo-targets are drawn when estimating the Fisher diagonal ('true'/'max_pred'/'multinomial')
        self.sampling_type = fi_sampling_type
        # Number of samples used for the Fisher estimate (-1: the whole training loader)
        self.num_samples = fi_num_samples
        # NOTE(review): already assigned in the base-class __init__; kept here redundantly
        self.scroll_step = scroll_step

        # In all cases, we only keep importance weights for the model, but not for the heads.
        feat_ext = self.model.model
        # Store current parameters as the initial parameters before first task starts
        self.older_params = {n: p.clone().detach() for n, p in feat_ext.named_parameters() if p.requires_grad}
        # Store fisher information weight importance
        self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in feat_ext.named_parameters()
                       if p.requires_grad}

    @staticmethod
    def exemplars_dataset_class():
        """Exemplar dataset class used when a memory of old-task samples is kept."""
        return ExemplarsDataset

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the approach specific parameters"""
        parser = ArgumentParser()
        # Eq. 3: "lambda sets how important the old task is compared to the new one"
        parser.add_argument('--lamb', default=5000, type=float, required=False,
                            help='Forgetting-intransigence trade-off (default=%(default)s)')
        # Define how old and new fisher is fused, by default it is a 50-50 fusion
        parser.add_argument('--alpha', default=0.5, type=float, required=False,
                            help='EWC alpha (default=%(default)s)')
        parser.add_argument('--fi-sampling-type', default='max_pred', type=str, required=False,
                            choices=['true', 'max_pred', 'multinomial'],
                            help='Sampling type for Fisher information (default=%(default)s)')
        parser.add_argument('--fi-num-samples', default=-1, type=int, required=False,
                            help='Number of samples for Fisher information (-1: all available) (default=%(default)s)')

        return parser.parse_known_args(args)

    def _get_optimizer(self):
        """Returns the optimizer"""
        if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
            # if there are no exemplars, previous heads are not modified
            params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
        else:
            params = self.model.parameters()
        return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)

    def compute_fisher_matrix_diag(self, trn_loader, t):
        """Estimate the diagonal of the Fisher information matrix for the
        feature extractor after training task ``t``.

        The scroll offset is pinned to ``scroll_step * t`` and the full-width
        network (width_mult=1.0) is used for the estimate.
        """
        # Store Fisher Information
        fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.model.model.named_parameters() if p.requires_grad}
        # Compute fisher information for specified number of samples -- rounded to the batch size
        n_samples_batches = (self.num_samples // trn_loader.batch_size + 1) if self.num_samples > 0 \
            else (len(trn_loader.dataset) // trn_loader.batch_size)
        # Do forward and backward pass to compute the fisher information
        self.model.train()
        for images, targets in itertools.islice(trn_loader, n_samples_batches):
            # fix scroll offset and use the widest sub-network for the estimate
            self.model.apply(lambda m: setattr(m, 'scroll', self.scroll_step * t))
            self.model.apply(lambda m: setattr(m, 'width_mult', 1.0))

            outputs = self.model.forward(images.to(self.device))

            if self.sampling_type == 'true':
                # Use the labels to compute the gradients based on the CE-loss with the ground truth
                preds = targets.to(self.device)
            elif self.sampling_type == 'max_pred':
                # Not use labels and compute the gradients related to the prediction the model has learned
                preds = torch.cat(outputs, dim=1).argmax(1).flatten()
            elif self.sampling_type == 'multinomial':
                # Use a multinomial sampling to compute the gradients
                probs = torch.nn.functional.softmax(torch.cat(outputs, dim=1), dim=1)
                preds = torch.multinomial(probs, len(targets)).flatten()

            loss = torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), preds)
            self.optimizer.zero_grad()
            loss.backward()
            # Accumulate all gradients from loss with regularization
            for n, p in self.model.model.named_parameters():
                if p.grad is not None:
                    # squared gradients weighted by the batch size
                    fisher[n] += p.grad.pow(2) * len(targets)
        # Apply mean across all samples
        # NOTE(review): assumes every sliced batch holds exactly trn_loader.batch_size samples;
        # a smaller final batch makes this denominator slightly overestimate the true count.
        n_samples = n_samples_batches * trn_loader.batch_size
        fisher = {n: (p / n_samples) for n, p in fisher.items()}
        return fisher

    def train_loop(self, t, trn_loader, val_loader):
        """Contains the epochs loop"""

        # add exemplars to train_loader
        if len(self.exemplars_dataset) > 0 and t > 0:
            trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
                                                     batch_size=trn_loader.batch_size,
                                                     shuffle=True,
                                                     num_workers=trn_loader.num_workers,
                                                     pin_memory=trn_loader.pin_memory)

        # FINETUNING TRAINING -- contains the epochs loop
        super().train_loop(t, trn_loader, val_loader)

        # EXEMPLAR MANAGEMENT -- select training subset
        self.exemplars_dataset.collect_exemplars(self.model, t, trn_loader, val_loader.dataset.transform)

    def post_train_process(self, t, trn_loader):
        """Runs after training all the epochs of the task (after the train session)"""

        # Store current parameters for the next task
        self.older_params = {n: p.clone().detach() for n, p in self.model.model.named_parameters() if p.requires_grad}

        # calculate Fisher information
        curr_fisher = self.compute_fisher_matrix_diag(trn_loader, t)
        # merge fisher information, we do not want to keep fisher information for each task in memory
        for n in self.fisher.keys():
            # Added option to accumulate fisher over time with a pre-fixed growing alpha
            if self.alpha == -1:
                # weight the old Fisher by the fraction of classes seen before this task
                alpha = (sum(self.model.task_cls[:t]) / sum(self.model.task_cls)).to(self.device)
                self.fisher[n] = alpha * self.fisher[n] + (1 - alpha) * curr_fisher[n]
            else:
                # fixed fusion between old and newly computed Fisher
                self.fisher[n] = (self.alpha * self.fisher[n] + (1 - self.alpha) * curr_fisher[n])

    def criterion(self, t, outputs, targets):
        """Returns the loss value: cross-entropy plus the lamb-weighted EWC penalty (for t > 0)."""
        loss = 0
        if t > 0:
            loss_reg = 0
            # Eq. 3: elastic weight consolidation quadratic penalty
            for n, p in self.model.model.named_parameters():
                if n in self.fisher.keys():
                    loss_reg += torch.sum(self.fisher[n] * (p - self.older_params[n]).pow(2)) / 2
            loss += self.lamb * loss_reg
        # Current cross-entropy loss -- with exemplars use all heads
        if len(self.exemplars_dataset) > 0:
            return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
        return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
153 |
--------------------------------------------------------------------------------
/src/approach/finetuning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from argparse import ArgumentParser
3 |
4 | from .incremental_learning import Inc_Learning_Appr
5 | from datasets.exemplars_dataset import ExemplarsDataset
6 |
7 |
class Appr(Inc_Learning_Appr):
    """Finetuning baseline: trains on each task in turn with no explicit
    forgetting-prevention term (optionally mixing in stored exemplars)."""

    def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, scroll_step=1, fix_bn=False,
                 eval_on_train=False, logger=None, exemplars_dataset=None, all_outputs=False):
        super().__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd,
                         multi_softmax, scroll_step, fix_bn, eval_on_train, logger,
                         exemplars_dataset)
        # when True, the cross-entropy is computed over all heads instead of the current one
        self.all_out = all_outputs

    @staticmethod
    def exemplars_dataset_class():
        """Exemplar dataset class used when a memory of old-task samples is kept."""
        return ExemplarsDataset

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the approach specific parameters"""
        parser = ArgumentParser()
        parser.add_argument('--all-outputs', action='store_true', required=False,
                            help='Allow all weights related to all outputs to be modified (default=%(default)s)')
        return parser.parse_known_args(args)

    def _get_optimizer(self):
        """Returns the optimizer"""
        freeze_old_heads = (not self.all_out
                            and len(self.exemplars_dataset) == 0
                            and len(self.model.heads) > 1)
        if freeze_old_heads:
            # without exemplars, only the backbone and the newest head are optimized
            params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
        else:
            params = self.model.parameters()
        return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)

    def train_loop(self, t, trn_loader, val_loader):
        """Epoch loop for one task; mixes stored exemplars into the loader when available."""
        if t > 0 and len(self.exemplars_dataset) > 0:
            # rebuild the loader over current data plus the exemplar memory
            combined = trn_loader.dataset + self.exemplars_dataset
            trn_loader = torch.utils.data.DataLoader(combined,
                                                     batch_size=trn_loader.batch_size,
                                                     shuffle=True,
                                                     num_workers=trn_loader.num_workers,
                                                     pin_memory=trn_loader.pin_memory)

        # delegate the actual epochs loop to the base class
        super().train_loop(t, trn_loader, val_loader)

        # refresh the exemplar memory from the just-trained task
        self.exemplars_dataset.collect_exemplars(self.model, t, trn_loader, val_loader.dataset.transform)

    def criterion(self, t, outputs, targets):
        """Cross-entropy over all heads (all-outputs / exemplar mode) or the current head only."""
        use_all_heads = self.all_out or len(self.exemplars_dataset) > 0
        if use_all_heads:
            return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
        shifted_targets = targets - self.model.task_offset[t]
        return torch.nn.functional.cross_entropy(outputs[t], shifted_targets)
62 |
--------------------------------------------------------------------------------
/src/approach/incremental_learning.py:
--------------------------------------------------------------------------------
1 | from sched import scheduler
2 | import time
3 | import torch
4 | import numpy as np
5 | from argparse import ArgumentParser
6 | from widths.config import FLAGS
7 | from loggers.exp_logger import ExperimentLogger
8 | from datasets.exemplars_dataset import ExemplarsDataset
9 |
class Inc_Learning_Appr:
    """Basic class for implementing incremental learning approaches.

    Subclasses override ``train_epoch``/``criterion``/``post_train_process``
    to implement specific continual-learning methods; the base class provides
    the epoch loop, SGD optimizer, LR scheduling, evaluation and metrics.
    """

    def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, scroll_step=1, fix_bn=False,
                 eval_on_train=False, logger: ExperimentLogger = None, exemplars_dataset: ExemplarsDataset = None):
        self.model = model
        self.device = device
        # number of training epochs per task
        self.nepochs = nepochs
        # initial learning rate; multiplied by lr_decay at each milestone in decay_mile_stone
        self.lr = lr
        self.decay_mile_stone = decay_mile_stone
        self.lr_decay = lr_decay
        # gradient-norm clipping threshold
        self.clipgrad = clipgrad
        self.momentum = momentum
        self.wd = wd
        # if True, apply a per-head log-softmax before the task-agnostic argmax
        self.multi_softmax = multi_softmax
        self.logger = logger
        self.exemplars_dataset = exemplars_dataset
        # if True, freeze batch-norm statistics after the first task
        self.fix_bn = fix_bn
        self.eval_on_train = eval_on_train
        # ScrollNet: how far the channel "scroll" offset advances per task
        self.scroll_step = scroll_step
        # created in train_loop via _get_optimizer() at the start of each task
        self.optimizer = None

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the approach specific parameters"""
        parser = ArgumentParser()
        return parser.parse_known_args(args)

    @staticmethod
    def exemplars_dataset_class():
        """Returns a exemplar dataset to use during the training if the approach needs it
        :return: ExemplarDataset class or None
        """
        return None

    def _get_optimizer(self):
        """Returns the optimizer"""
        return torch.optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.wd, momentum=self.momentum)

    def train(self, t, trn_loader, val_loader):
        """Main train structure"""
        self.train_loop(t, trn_loader, val_loader)
        self.post_train_process(t, trn_loader)

    def train_loop(self, t, trn_loader, val_loader):
        """Contains the epochs loop"""
        # NOTE(review): 'lr' is only used for the logger below and is never updated;
        # the actually applied rate is read from the optimizer's param groups.
        lr = self.lr
        self.optimizer = self._get_optimizer()
        scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.decay_mile_stone, gamma=self.lr_decay)

        # Loop epochs
        for e in range(self.nepochs):
            # Train
            clock0 = time.time()
            self.train_epoch(t, trn_loader)
            clock1 = time.time()
            if self.eval_on_train:
                # optionally evaluate on the training set as well (slower)
                train_loss, train_acc, _ = self.eval(t, trn_loader)
                clock2 = time.time()
                print('| Epoch {:3d}, time={:5.1f}s/{:5.1f}s | Train: loss={:.3f}, TAw acc={:5.1f}% |'.format(
                    e + 1, clock1 - clock0, clock2 - clock1, train_loss, 100 * train_acc), end='')
                self.logger.log_scalar(task=t, iter=e + 1, name="loss", value=train_loss, group="train")
                self.logger.log_scalar(task=t, iter=e + 1, name="acc", value=100 * train_acc, group="train")
            else:
                print('| Epoch {:3d}, time={:5.1f}s | Train: skip eval |'.format(e + 1, clock1 - clock0), end='')

            # Valid
            clock3 = time.time()
            # evaluate the current task (head t) with the scroll offset of task t
            valid_loss, valid_acc, _ = self.eval(t, val_loader, t)
            clock4 = time.time()
            print(' Valid: time={:5.1f}s loss={:.3f} TAw acc={:5.1f}% |'.format(
                clock4 - clock3, valid_loss, 100 * valid_acc), end='')
            self.logger.log_scalar(task=t, iter=e + 1, name="loss", value=valid_loss, group="valid")
            self.logger.log_scalar(task=t, iter=e + 1, name="acc", value=100 * valid_acc, group="valid")
            scheduler.step()
            print(' lr={:.1e}'.format(self.optimizer.param_groups[0]['lr']), end='')
            # NOTE(review): this logs the initial lr, not the scheduler-updated one printed above
            self.logger.log_scalar(task=t, iter=e + 1, name="lr", value=lr, group="train")
            print()

    def post_train_process(self, t, trn_loader):
        """Runs after training all the epochs of the task (after the train session)"""
        pass

    def train_epoch(self, t, trn_loader):
        """Runs a single epoch: every batch is forwarded once per sub-network width
        and the losses are summed before a single backward/step."""
        self.model.train()
        if self.fix_bn and t > 0:
            self.model.freeze_bn()
        for images, targets in trn_loader:
            # Forward current model
            total_loss = 0.0
            self.model.apply(lambda m: setattr(m, 'scroll', self.scroll_step * t))  # change the perception of channels to shuffle
            for width_mult in sorted(FLAGS.width_mult_list, reverse=True):
                # widest sub-network first, then progressively narrower ones
                self.model.apply(lambda m: setattr(m, 'width_mult', width_mult))
                outputs = self.model(images.to(self.device))
                loss = self.criterion(t, outputs, targets.to(self.device))
                total_loss += loss
            # Backward
            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()

    def eval(self, t, val_loader, real_t):
        """Contains the evaluation code.

        ``t`` selects the head for the task-aware loss; ``real_t`` sets the
        scroll offset of the network. Evaluation always uses the widest width.
        """
        width_max = max(FLAGS.width_mult_list)
        with torch.no_grad():
            total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0
            self.model.eval()
            self.model.apply(lambda m: setattr(m, 'scroll', self.scroll_step * real_t))
            for images, targets in val_loader:
                # Forward current model
                self.model.apply(lambda m: setattr(m, 'width_mult', width_max))
                outputs = self.model(images.to(self.device))
                loss = self.criterion(t, outputs, targets.to(self.device))
                hits_taw, hits_tag = self.calculate_metrics(outputs, targets)
                # Log
                total_loss += loss.item() * len(targets)
                total_acc_taw += hits_taw.sum().item()
                total_acc_tag += hits_tag.sum().item()
                total_num += len(targets)
        return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num

    def calculate_metrics(self, outputs, targets):
        """Contains the main Task-Aware and Task-Agnostic metrics"""
        pred = torch.zeros_like(targets.to(self.device))
        # Task-Aware Multi-Head: pick the head that owns each target's class range
        for m in range(len(pred)):
            this_task = (self.model.task_cls.cumsum(0) <= targets[m]).sum()
            pred[m] = outputs[this_task][m].argmax() + self.model.task_offset[this_task]
        hits_taw = (pred == targets.to(self.device)).float()
        # Task-Agnostic Multi-Head: argmax over all heads concatenated
        if self.multi_softmax:
            # per-head log-softmax rescales logits before the cross-head argmax
            outputs = [torch.nn.functional.log_softmax(output, dim=1) for output in outputs]
            pred = torch.cat(outputs, dim=1).argmax(1)
        else:
            pred = torch.cat(outputs, dim=1).argmax(1)
        hits_tag = (pred == targets.to(self.device)).float()
        return hits_taw, hits_tag

    def criterion(self, t, outputs, targets):
        """Returns the loss value: cross-entropy on the current head with task-local labels."""
        return torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
154 |
--------------------------------------------------------------------------------
/src/approach/lwf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 | from argparse import ArgumentParser
4 |
5 | from .incremental_learning import Inc_Learning_Appr
6 | from datasets.exemplars_dataset import ExemplarsDataset
7 | from widths.config import FLAGS
8 |
class Appr(Inc_Learning_Appr):
    """Class implementing the Learning Without Forgetting (LwF) approach
    described in https://arxiv.org/abs/1606.09282

    A frozen copy of the model from the previous task serves as the
    distillation teacher for all previously seen heads.
    """

    # Weight decay of 0.0005 is used in the original article (page 4).
    # Page 4: "The warm-up step greatly enhances fine-tuning’s old-task performance, but is not so crucial to either our
    # method or the compared Less Forgetting Learning (see Table 2(b))."
    def __init__(self, model, device, nepochs=100, lr=0.05, decay_mile_stone=[80,120], lr_decay=0.1, clipgrad=10000,
                 momentum=0, wd=0, multi_softmax=False, scroll_step=1, fix_bn=False, eval_on_train=False,
                 logger=None, exemplars_dataset=None, lamb=1, T=2):
        super(Appr, self).__init__(model, device, nepochs, lr, decay_mile_stone, lr_decay, clipgrad, momentum, wd,
                                   multi_softmax, scroll_step, fix_bn, eval_on_train, logger,
                                   exemplars_dataset)
        # frozen teacher snapshot taken after each task in post_train_process
        self.model_old = None
        # distillation loss weight
        self.lamb = lamb
        # distillation temperature
        self.T = T

    @staticmethod
    def exemplars_dataset_class():
        """Exemplar dataset class used when a memory of old-task samples is kept."""
        return ExemplarsDataset

    @staticmethod
    def extra_parser(args):
        """Returns a parser containing the approach specific parameters"""
        parser = ArgumentParser()
        # Page 5: "lambda is a loss balance weight, set to 1 for most our experiments. Making lambda larger will favor
        # the old task performance over the new task’s, so we can obtain a old-task-new-task performance line by
        # changing lambda."
        parser.add_argument('--lamb', default=1, type=float, required=False,
                            help='Forgetting-intransigence trade-off (default=%(default)s)')
        # Page 5: "We use T=2 according to a grid search on a held out set, which aligns with the authors’
        # recommendations." -- Using a higher value for T produces a softer probability distribution over classes.
        parser.add_argument('--T', default=2, type=int, required=False,
                            help='Temperature scaling (default=%(default)s)')
        return parser.parse_known_args(args)

    def _get_optimizer(self):
        """Returns the optimizer"""
        if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
            # if there are no exemplars, previous heads are not modified
            params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
        else:
            params = self.model.parameters()
        return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)

    def train_loop(self, t, trn_loader, val_loader):
        """Contains the epochs loop"""

        # add exemplars to train_loader
        if len(self.exemplars_dataset) > 0 and t > 0:
            trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
                                                     batch_size=trn_loader.batch_size,
                                                     shuffle=True,
                                                     num_workers=trn_loader.num_workers,
                                                     pin_memory=trn_loader.pin_memory)

        # FINETUNING TRAINING -- contains the epochs loop
        super().train_loop(t, trn_loader, val_loader)

        # EXEMPLAR MANAGEMENT -- select training subset
        self.exemplars_dataset.collect_exemplars(self.model, t, trn_loader, val_loader.dataset.transform)

    def post_train_process(self, t, trn_loader):
        """Runs after training all the epochs of the task (after the train session)"""

        # Restore best and save model for future tasks
        self.model_old = deepcopy(self.model)
        self.model_old.eval()
        self.model_old.freeze_all()

    def train_epoch(self, t, trn_loader):
        """Runs a single epoch; the distillation term is added only for the
        widest sub-network (i == 0), using teacher outputs from the previous
        task's scroll offset at full width."""
        self.model.train()

        if self.fix_bn and t > 0:
            self.model.freeze_bn()
        for images, targets in trn_loader:
            # Forward old model
            targets_old = None
            if t > 0:
                # teacher keeps the previous task's scroll offset and full width
                self.model_old.apply(lambda m: setattr(m, 'scroll', self.scroll_step * (t-1)))
                width_mult = max(FLAGS.width_mult_list)
                self.model_old.apply(lambda m: setattr(m, 'width_mult', width_mult))
                targets_old = self.model_old(images.to(self.device))
            # Forward current model
            total_loss = 0.0
            self.model.apply(lambda m: setattr(m, 'scroll', self.scroll_step * t))
            for i, width_mult in enumerate(sorted(FLAGS.width_mult_list, reverse=True)):
                self.model.apply(lambda m: setattr(m, 'width_mult', width_mult))
                outputs = self.model(images.to(self.device))
                if i == 0:
                    # widest sub-network: CE plus knowledge distillation
                    loss = self.criterion(t, outputs, targets.to(self.device), targets_old)
                else:
                    # narrower sub-networks: CE only
                    loss = self.criterion(t, outputs, targets.to(self.device))
                total_loss += loss
            # Backward
            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()

    def eval(self, t, val_loader, real_t):
        """Contains the evaluation code.

        ``t`` selects the head for the task-aware loss; ``real_t`` sets the
        scroll offsets of both current and old model. Uses the widest width.
        """
        width_max = max(FLAGS.width_mult_list)
        with torch.no_grad():
            total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0
            self.model.eval()
            for images, targets in val_loader:
                # Forward old model
                targets_old = None
                if t > 0:
                    self.model_old.apply(lambda m: setattr(m, 'scroll', self.scroll_step * (real_t-1)))
                    self.model_old.apply(lambda m: setattr(m, 'width_mult', width_max))
                    targets_old = self.model_old(images.to(self.device))
                # Forward current model
                self.model.apply(lambda m: setattr(m, 'scroll', self.scroll_step * (real_t)))
                self.model.apply(lambda m: setattr(m, 'width_mult', width_max))
                outputs = self.model(images.to(self.device))
                loss = self.criterion(t, outputs, targets.to(self.device), targets_old)
                hits_taw, hits_tag = self.calculate_metrics(outputs, targets)
                # Log
                total_loss += loss.item() * len(targets)
                total_acc_taw += hits_taw.sum().item()
                total_acc_tag += hits_tag.sum().item()
                total_num += len(targets)
        return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num

    def cross_entropy(self, outputs, targets, exp=1.0, size_average=True, eps=1e-5):
        """Calculates cross-entropy with temperature scaling.

        ``exp`` raises the softmax distributions to a power (1/T for
        distillation); ``eps`` avoids log(0) on the student side.
        """
        out = torch.nn.functional.softmax(outputs, dim=1)
        tar = torch.nn.functional.softmax(targets, dim=1)
        if exp != 1:
            # temperature scaling: sharpen/soften both distributions, then renormalize
            out = out.pow(exp)
            out = out / out.sum(1).view(-1, 1).expand_as(out)
            tar = tar.pow(exp)
            tar = tar / tar.sum(1).view(-1, 1).expand_as(tar)
        # numerical stabilization before the log
        out = out + eps / out.size(1)
        out = out / out.sum(1).view(-1, 1).expand_as(out)
        ce = -(tar * out.log()).sum(1)
        if size_average:
            ce = ce.mean()
        return ce

    def criterion(self, t, outputs, targets, outputs_old=None):
        """Returns the loss value: CE plus the lamb-weighted distillation term over old heads."""
        loss = 0
        if t > 0 and outputs_old is not None:
            # Knowledge distillation loss for all previous tasks
            loss += self.lamb * self.cross_entropy(torch.cat(outputs[:t], dim=1),
                                                   torch.cat(outputs_old[:t], dim=1), exp=1.0 / self.T)
        # Current cross-entropy loss -- with exemplars use all heads
        if len(self.exemplars_dataset) > 0:
            return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
        return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])
164 |
--------------------------------------------------------------------------------
/src/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
class BaseDataset(Dataset):
    """PyTorch dataset that keeps all image paths and labels in memory and
    loads/transforms the pixels lazily on access."""

    def __init__(self, data, transform, class_indices=None):
        """Store paths/labels from *data* (keys 'x'/'y'), the per-sample
        transform, and an optional list of class indices."""
        self.labels = data['y']
        self.images = data['x']
        self.transform = transform
        self.class_indices = class_indices

    def __len__(self):
        """Total number of stored samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the image at *index*, apply the transform, return (image, label)."""
        sample = Image.open(self.images[index]).convert('RGB')
        sample = self.transform(sample)
        return sample, self.labels[index]
28 |
29 |
def get_data(path, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None):
    """Prepare data: dataset splits, task partition, class order

    Reads image paths and labels from ``train.txt``/``test.txt`` under *path*,
    partitions the classes into ``num_tasks`` tasks (the first task optionally
    gets ``nc_first_task`` classes), and carves a validation split out of the
    training data.

    Args:
        path: directory containing ``train.txt`` and ``test.txt`` with
            image-path / label pairs (one per line).
        num_tasks: number of tasks to split the classes into.
        nc_first_task: number of classes for the first task, or None for an
            (almost) even split across all tasks.
        validation: fraction of each class's training samples moved to the
            validation split (0 disables validation).
        shuffle_classes: whether to randomly permute the class order
            (uses the global numpy RNG).
        class_order: explicit class order; when None, classes 0..C-1 in
            natural order are used.

    Returns:
        (data, taskcla, class_order) where ``data[t]`` holds the
        'trn'/'val'/'tst' splits and 'ncla' for task t, ``taskcla`` is a list
        of (task, num_classes) pairs, and ``class_order`` is the final order
        used (useful for reproducing a shuffled run).
    """

    data = {}
    taskcla = []

    # read filenames and labels
    trn_lines = np.loadtxt(os.path.join(path, 'train.txt'), dtype=str)
    tst_lines = np.loadtxt(os.path.join(path, 'test.txt'), dtype=str)
    if class_order is None:
        num_classes = len(np.unique(trn_lines[:, 1]))
        class_order = list(range(num_classes))
    else:
        num_classes = len(class_order)
        # copy so the caller's list is not shuffled in place below
        class_order = class_order.copy()
    if shuffle_classes:
        np.random.shuffle(class_order)

    # compute classes per task and num_tasks
    if nc_first_task is None:
        # even split; the first (num_classes % num_tasks) tasks get one extra class
        cpertask = np.array([num_classes // num_tasks] * num_tasks)
        for i in range(num_classes % num_tasks):
            cpertask[i] += 1
    else:
        assert nc_first_task < num_classes, "first task wants more classes than exist"
        remaining_classes = num_classes - nc_first_task
        assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task"  # better minimum 2
        cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1))
        for i in range(remaining_classes % (num_tasks - 1)):
            cpertask[i + 1] += 1

    assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes"
    # cumulative boundaries used to map a (reordered) class id to its task
    cpertask_cumsum = np.cumsum(cpertask)
    init_class = np.concatenate(([0], cpertask_cumsum[:-1]))

    # initialize data structure
    for tt in range(num_tasks):
        data[tt] = {}
        data[tt]['name'] = 'task-' + str(tt)
        data[tt]['trn'] = {'x': [], 'y': []}
        data[tt]['val'] = {'x': [], 'y': []}
        data[tt]['tst'] = {'x': [], 'y': []}

    # ALL OR TRAIN
    for this_image, this_label in trn_lines:
        if not os.path.isabs(this_image):
            this_image = os.path.join(path, this_image)
        this_label = int(this_label)
        if this_label not in class_order:
            continue
        # If shuffling is false, it won't change the class number
        this_label = class_order.index(this_label)

        # add it to the corresponding split
        this_task = (this_label >= cpertask_cumsum).sum()
        data[this_task]['trn']['x'].append(this_image)
        # labels are made task-local by subtracting the task's first class id
        data[this_task]['trn']['y'].append(this_label - init_class[this_task])

    # ALL OR TEST
    for this_image, this_label in tst_lines:
        if not os.path.isabs(this_image):
            this_image = os.path.join(path, this_image)
        this_label = int(this_label)
        if this_label not in class_order:
            continue
        # If shuffling is false, it won't change the class number
        this_label = class_order.index(this_label)

        # add it to the corresponding split
        this_task = (this_label >= cpertask_cumsum).sum()
        data[this_task]['tst']['x'].append(this_image)
        data[this_task]['tst']['y'].append(this_label - init_class[this_task])

    # check classes
    for tt in range(num_tasks):
        data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y']))
        assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes"

    # validation: move a per-class fraction of training samples to 'val'
    if validation > 0.0:
        for tt in data.keys():
            for cc in range(data[tt]['ncla']):
                cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0])
                rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation)))
                # pop from the end first so earlier indices stay valid
                rnd_img.sort(reverse=True)
                for ii in range(len(rnd_img)):
                    data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]])
                    data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]])
                    data[tt]['trn']['x'].pop(rnd_img[ii])
                    data[tt]['trn']['y'].pop(rnd_img[ii])

    # other: per-task class counts plus the overall total under data['ncla']
    n = 0
    for t in data.keys():
        taskcla.append((t, data[t]['ncla']))
        n += data[t]['ncla']
    data['ncla'] = n

    return data, taskcla, class_order
129 |
--------------------------------------------------------------------------------
/src/datasets/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils import data
4 | import torchvision.transforms as transforms
5 | from torchvision.datasets import MNIST as TorchVisionMNIST
6 | from torchvision.datasets import CIFAR100 as TorchVisionCIFAR100
7 | from torchvision.datasets import SVHN as TorchVisionSVHN
8 |
9 | from . import base_dataset as basedat
10 | from . import memory_dataset as memd
11 | from .dataset_config import dataset_config
12 |
13 |
def get_loaders(datasets, num_tasks, nc_first_task, batch_size, num_workers, pin_memory, validation=.1):
    """Create train/validation/test DataLoaders for every task of every dataset.

    For each dataset name in *datasets* the configured transforms are built,
    the per-task splits are created, and each split is wrapped in a
    DataLoader.  When several datasets are concatenated, labels and task ids
    of later datasets are shifted so they stay globally unique.

    Returns (trn_load, val_load, tst_load, taskcla), each loader list ordered
    dataset-by-dataset, task-by-task.
    """
    trn_load, val_load, tst_load = [], [], []
    taskcla = []
    label_offset = 0
    for ds_idx, ds_name in enumerate(datasets):
        cfg = dataset_config[ds_name]

        # train/eval preprocessing pipelines for this dataset
        trn_transform, tst_transform = get_transforms(resize=cfg['resize'],
                                                      pad=cfg['pad'],
                                                      crop=cfg['crop'],
                                                      flip=cfg['flip'],
                                                      normalize=cfg['normalize'],
                                                      extend_channel=cfg['extend_channel'])

        # per-task Dataset objects for this dataset
        trn_dset, val_dset, tst_dset, curtaskcla = get_datasets(ds_name, cfg['path'], num_tasks, nc_first_task,
                                                                validation=validation,
                                                                trn_transform=trn_transform,
                                                                tst_transform=tst_transform,
                                                                class_order=cfg['class_order'])

        # shift labels of every dataset after the first past the classes
        # already consumed by the preceding datasets
        if ds_idx > 0:
            for task in range(num_tasks):
                trn_dset[task].labels = [lbl + label_offset for lbl in trn_dset[task].labels]
                val_dset[task].labels = [lbl + label_offset for lbl in val_dset[task].labels]
                tst_dset[task].labels = [lbl + label_offset for lbl in tst_dset[task].labels]
        label_offset = label_offset + sum(tc[1] for tc in curtaskcla)

        # task ids are shifted as well so the final taskcla is globally unique
        taskcla.extend((tc[0] + ds_idx * num_tasks, tc[1]) for tc in curtaskcla)

        # wrap each task split in a DataLoader (only training data is shuffled)
        for task in range(num_tasks):
            trn_load.append(data.DataLoader(trn_dset[task], batch_size=batch_size, shuffle=True,
                                            num_workers=num_workers, pin_memory=pin_memory))
            val_load.append(data.DataLoader(val_dset[task], batch_size=batch_size, shuffle=False,
                                            num_workers=num_workers, pin_memory=pin_memory))
            tst_load.append(data.DataLoader(tst_dset[task], batch_size=batch_size, shuffle=False,
                                            num_workers=num_workers, pin_memory=pin_memory))
    return trn_load, val_load, tst_load, taskcla
62 |
63 |
def _load_imagenet32(path):
    """Load the downsampled (32x32) ImageNet batches from *path*.

    Returns ``(trn_data, tst_data)`` dicts with 'x' as HWC uint8 arrays and
    'y' as zero-based label arrays.

    NOTE(review): the batches are unpickled; pickle executes arbitrary code,
    so only use with data files you control.
    """
    import pickle
    x_trn, y_trn = [], []
    for i in range(1, 11):
        with open(os.path.join(path, 'train_data_batch_{}'.format(i)), 'rb') as f:
            d = pickle.load(f)
            x_trn.append(d['data'])
            y_trn.append(np.array(d['labels']) - 1)  # labels from 0 to 999
    with open(os.path.join(path, 'val_data'), 'rb') as f:
        d = pickle.load(f)
        x_trn.append(d['data'])
        y_tst = np.array(d['labels']) - 1  # labels from 0 to 999
    # reshape flat CHW rows into HWC images
    for i, d in enumerate(x_trn, 0):
        x_trn[i] = d.reshape(d.shape[0], 3, 32, 32).transpose(0, 2, 3, 1)
    x_tst = x_trn[-1]
    x_trn = np.vstack(x_trn[:-1])
    y_trn = np.concatenate(y_trn)
    return {'x': x_trn, 'y': y_trn}, {'x': x_tst, 'y': y_tst}


def get_datasets(dataset, path, num_tasks, nc_first_task, validation, trn_transform, tst_transform, class_order=None):
    """Extract datasets and create Dataset class

    Loads the raw data for *dataset*, computes the task splits and wraps each
    split in the appropriate Dataset type: in-memory for the torchvision /
    imagenet_32 datasets, path-based (``train.txt``/``test.txt``) otherwise.

    The four in-memory branches previously duplicated the identical
    ``memd.get_data(...)`` call; they now only load raw data and share one
    split computation.

    Returns (trn_dset, val_dset, tst_dset, taskcla).
    """
    trn_dset, val_dset, tst_dset = [], [], []

    # load raw train/test data for the in-memory datasets; trn_data stays
    # None for path-based datasets handled by basedat below
    trn_data = tst_data = None
    if 'mnist' in dataset:
        tvmnist_trn = TorchVisionMNIST(path, train=True, download=True)
        tvmnist_tst = TorchVisionMNIST(path, train=False, download=True)
        trn_data = {'x': tvmnist_trn.data.numpy(), 'y': tvmnist_trn.targets.tolist()}
        tst_data = {'x': tvmnist_tst.data.numpy(), 'y': tvmnist_tst.targets.tolist()}
    elif 'cifar100' in dataset:
        tvcifar_trn = TorchVisionCIFAR100(path, train=True, download=True)
        tvcifar_tst = TorchVisionCIFAR100(path, train=False, download=True)
        trn_data = {'x': tvcifar_trn.data, 'y': tvcifar_trn.targets}
        tst_data = {'x': tvcifar_tst.data, 'y': tvcifar_tst.targets}
    elif dataset == 'svhn':
        tvsvhn_trn = TorchVisionSVHN(path, split='train', download=True)
        tvsvhn_tst = TorchVisionSVHN(path, split='test', download=True)
        # SVHN stores CHW; transpose to HWC like the other in-memory datasets
        trn_data = {'x': tvsvhn_trn.data.transpose(0, 2, 3, 1), 'y': tvsvhn_trn.labels}
        tst_data = {'x': tvsvhn_tst.data.transpose(0, 2, 3, 1), 'y': tvsvhn_tst.labels}
        # Notice that SVHN in Torchvision has an extra training set in case needed
        # tvsvhn_xtr = TorchVisionSVHN(path, split='extra', download=True)
        # xtr_data = {'x': tvsvhn_xtr.data.transpose(0, 2, 3, 1), 'y': tvsvhn_xtr.labels}
    elif 'imagenet_32' in dataset:
        trn_data, tst_data = _load_imagenet32(path)

    if trn_data is not None:
        # all in-memory datasets share the exact same split computation
        all_data, taskcla, class_indices = memd.get_data(trn_data, tst_data, validation=validation,
                                                         num_tasks=num_tasks, nc_first_task=nc_first_task,
                                                         shuffle_classes=class_order is None, class_order=class_order)
        Dataset = memd.MemoryDataset
    else:
        # read data paths and compute splits -- path needs to have a train.txt and a test.txt with image-label pairs
        all_data, taskcla, class_indices = basedat.get_data(path, num_tasks=num_tasks, nc_first_task=nc_first_task,
                                                            validation=validation, shuffle_classes=class_order is None,
                                                            class_order=class_order)
        Dataset = basedat.BaseDataset

    # get datasets, apply correct label offsets for each task
    offset = 0
    for task in range(num_tasks):
        all_data[task]['trn']['y'] = [label + offset for label in all_data[task]['trn']['y']]
        all_data[task]['val']['y'] = [label + offset for label in all_data[task]['val']['y']]
        all_data[task]['tst']['y'] = [label + offset for label in all_data[task]['tst']['y']]
        trn_dset.append(Dataset(all_data[task]['trn'], trn_transform, class_indices))
        val_dset.append(Dataset(all_data[task]['val'], tst_transform, class_indices))
        tst_dset.append(Dataset(all_data[task]['tst'], tst_transform, class_indices))
        offset += taskcla[task][1]

    return trn_dset, val_dset, tst_dset, taskcla
157 |
158 |
def get_transforms(resize, pad, crop, flip, normalize, extend_channel):
    """Build the train and test preprocessing pipelines from the config values.

    Each entry in *spec* is a (train_op, test_op) pair; a test_op of None
    means the operation is applied to the training pipeline only (random
    horizontal flip).  Crop uses random resized crop for training and center
    crop for evaluation; every other step is identical in both pipelines.
    """
    spec = []

    # resize
    if resize is not None:
        spec.append((transforms.Resize(resize), transforms.Resize(resize)))

    # padding
    if pad is not None:
        spec.append((transforms.Pad(pad), transforms.Pad(pad)))

    # crop: random for training, deterministic center crop for evaluation
    if crop is not None:
        spec.append((transforms.RandomResizedCrop(crop), transforms.CenterCrop(crop)))

    # flips: augmentation only, never applied at test time
    if flip:
        spec.append((transforms.RandomHorizontalFlip(), None))

    # to tensor
    spec.append((transforms.ToTensor(), transforms.ToTensor()))

    # normalization
    if normalize is not None:
        spec.append((transforms.Normalize(mean=normalize[0], std=normalize[1]),
                     transforms.Normalize(mean=normalize[0], std=normalize[1])))

    # gray to rgb (repeat the single channel)
    if extend_channel is not None:
        spec.append((transforms.Lambda(lambda x: x.repeat(extend_channel, 1, 1)),
                     transforms.Lambda(lambda x: x.repeat(extend_channel, 1, 1))))

    trn_transform_list = [trn_op for trn_op, _ in spec]
    tst_transform_list = [tst_op for _, tst_op in spec if tst_op is not None]
    return transforms.Compose(trn_transform_list), \
           transforms.Compose(tst_transform_list)
200 |
--------------------------------------------------------------------------------
/src/datasets/dataset_config.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 |
_BASE_DATA_PATH = "../data"

# Per-dataset configuration: data location, preprocessing parameters and an
# optional fixed class order.  Keys omitted here are filled in with defaults
# by the normalization loop at the bottom of this module.
dataset_config = {
    'mnist': {
        'path': join(_BASE_DATA_PATH, 'mnist'),
        'normalize': ((0.1307,), (0.3081,)),
        # Use the next 3 lines to use MNIST with a 3x32x32 input
        # 'extend_channel': 3,
        # 'pad': 2,
        # 'normalize': ((0.1,), (0.2752,)) # values including padding
    },
    'svhn': {
        'path': join(_BASE_DATA_PATH, 'svhn'),
        'resize': (224, 224),
        'crop': None,
        'flip': False,
        'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    },
    'cifar100': {
        'path': join(_BASE_DATA_PATH, 'cifar100'),
        'resize': None,
        'pad': 4,
        'crop': 32,
        'flip': True,
        'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))
    },
    'cifar100_icarl': {
        'path': join(_BASE_DATA_PATH, 'cifar100'),
        'resize': None,
        'pad': 4,
        'crop': 32,
        'flip': True,
        'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)),
        # fixed class order used by the iCaRL paper for reproducibility
        'class_order': [
            68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50,
            28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
            98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69,
            36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33
        ]
    },
    'vggface2': {
        'path': join(_BASE_DATA_PATH, 'VGGFace2'),
        'resize': 256,
        'crop': 224,
        'flip': True,
        'normalize': ((0.5199, 0.4116, 0.3610), (0.2604, 0.2297, 0.2169))
    },
    'imagenet_256': {
        'path': join(_BASE_DATA_PATH, 'ILSVRC12_256'),
        'resize': None,
        'crop': 224,
        'flip': True,
        'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    },
    'imagenet_subset': {
        'path': join(_BASE_DATA_PATH, 'ILSVRC12_256'),
        'resize': None,
        'crop': 224,
        'flip': True,
        'normalize': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        'class_order': [
            68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50,
            28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
            98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69,
            36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33
        ]
    },
    'imagenet_32_reduced': {
        'path': join(_BASE_DATA_PATH, 'ILSVRC12_32'),
        'resize': None,
        'pad': 4,
        'crop': 32,
        'flip': True,
        'normalize': ((0.481, 0.457, 0.408), (0.260, 0.253, 0.268)),
        'class_order': [
            472, 46, 536, 806, 547, 976, 662, 12, 955, 651, 492, 80, 999, 996, 788, 471, 911, 907, 680, 126, 42, 882,
            327, 719, 716, 224, 918, 647, 808, 261, 140, 908, 833, 925, 57, 388, 407, 215, 45, 479, 525, 641, 915, 923,
            108, 461, 186, 843, 115, 250, 829, 625, 769, 323, 974, 291, 438, 50, 825, 441, 446, 200, 162, 373, 872, 112,
            212, 501, 91, 672, 791, 370, 942, 172, 315, 959, 636, 635, 66, 86, 197, 182, 59, 736, 175, 445, 947, 268,
            238, 298, 926, 851, 494, 760, 61, 293, 696, 659, 69, 819, 912, 486, 706, 343, 390, 484, 282, 729, 575, 731,
            530, 32, 534, 838, 466, 734, 425, 400, 290, 660, 254, 266, 551, 775, 721, 134, 886, 338, 465, 236, 522, 655,
            209, 861, 88, 491, 985, 304, 981, 560, 405, 902, 521, 909, 763, 455, 341, 905, 280, 776, 113, 434, 274, 581,
            158, 738, 671, 702, 147, 718, 148, 35, 13, 585, 591, 371, 745, 281, 956, 935, 346, 352, 284, 604, 447, 415,
            98, 921, 118, 978, 880, 509, 381, 71, 552, 169, 600, 334, 171, 835, 798, 77, 249, 318, 419, 990, 335, 374,
            949, 316, 755, 878, 946, 142, 299, 863, 558, 306, 183, 417, 64, 765, 565, 432, 440, 939, 297, 805, 364, 735,
            251, 270, 493, 94, 773, 610, 278, 16, 363, 92, 15, 593, 96, 468, 252, 699, 377, 95, 799, 868, 820, 328, 756,
            81, 991, 464, 774, 584, 809, 844, 940, 720, 498, 310, 384, 619, 56, 406, 639, 285, 67, 634, 792, 232, 54,
            664, 818, 513, 349, 330, 207, 361, 345, 279, 549, 944, 817, 353, 228, 312, 796, 193, 179, 520, 451, 871,
            692, 60, 481, 480, 929, 499, 673, 331, 506, 70, 645, 759, 744, 459]
    }
}

# Fill in defaults so every entry exposes the full set of keys expected by
# data_loader.get_loaders.
for _conf in dataset_config.values():
    for _key in ('resize', 'pad', 'crop', 'normalize', 'class_order', 'extend_channel'):
        _conf.setdefault(_key, None)
    _conf.setdefault('flip', False)
102 |
--------------------------------------------------------------------------------
/src/datasets/exemplars_dataset.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from argparse import ArgumentParser
3 |
4 | from datasets.memory_dataset import MemoryDataset
5 |
6 |
class ExemplarsDataset(MemoryDataset):
    """Exemplar storage for approaches with an interface of Dataset"""

    def __init__(self, transform, class_indices,
                 num_exemplars=0, num_exemplars_per_class=0, exemplar_selection='random'):
        """Create an empty exemplar store and instantiate the selection strategy.

        Exactly one of the two memory limits may be non-zero: a fixed total
        budget (*num_exemplars*) or a growing per-class budget
        (*num_exemplars_per_class*).
        """
        super().__init__({'x': [], 'y': []}, transform, class_indices=class_indices)
        self.max_num_exemplars_per_class = num_exemplars_per_class
        self.max_num_exemplars = num_exemplars
        assert (num_exemplars_per_class == 0) or (num_exemplars == 0), 'Cannot use both limits at once!'
        # resolve e.g. 'herding' -> datasets.exemplars_selection.HerdingExemplarsSelector
        selector_name = "{}ExemplarsSelector".format(exemplar_selection.capitalize())
        selection_module = importlib.import_module(name='datasets.exemplars_selection')
        self.exemplars_selector = getattr(selection_module, selector_name)(self)

    # Returns a parser containing the approach specific parameters
    @staticmethod
    def extra_parser(args):
        """Parse the exemplar-management command-line arguments."""
        parser = ArgumentParser("Exemplars Management Parameters")
        _group = parser.add_mutually_exclusive_group()
        _group.add_argument('--num-exemplars', default=0, type=int, required=False,
                            help='Fixed memory, total number of exemplars (default=%(default)s)')
        _group.add_argument('--num-exemplars-per-class', default=0, type=int, required=False,
                            help='Growing memory, number of exemplars per class (default=%(default)s)')
        parser.add_argument('--exemplar-selection', default='random', type=str,
                            choices=['herding', 'random', 'entropy', 'distance'],
                            required=False, help='Exemplar selection strategy (default=%(default)s)')
        return parser.parse_known_args(args)

    def _is_active(self):
        """Whether any exemplar memory limit is enabled."""
        return max(self.max_num_exemplars_per_class, self.max_num_exemplars) > 0

    def collect_exemplars(self, model, t, trn_loader, selection_transform):
        """Re-select the stored exemplars after learning task ``t`` (no-op when disabled)."""
        if not self._is_active():
            return
        self.images, self.labels = self.exemplars_selector(model, t, trn_loader, selection_transform)
40 |
--------------------------------------------------------------------------------
/src/datasets/exemplars_selection.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | from contextlib import contextmanager
4 | from typing import Iterable
5 |
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import DataLoader, ConcatDataset
9 | from torchvision.transforms import Lambda
10 |
11 | from datasets.exemplars_dataset import ExemplarsDataset
12 | from networks.network import LLL_Net
13 |
14 |
class ExemplarsSelector:
    """Exemplar selector for approaches with an interface of Dataset"""

    def __init__(self, exemplars_dataset: ExemplarsDataset):
        self.exemplars_dataset = exemplars_dataset

    def __call__(self, model: LLL_Net, t: int, trn_loader: DataLoader, transform):
        """Select exemplars from the training data and return their raw (x, y)."""
        start = time.time()
        per_class = self._exemplars_per_class_num(model)
        with override_dataset_transform(trn_loader.dataset, transform) as ds_for_selection:
            # sequential loader (shuffle=False) with eval transforms keeps the
            # sample order stable for the raw extraction below
            sel_loader = DataLoader(ds_for_selection, batch_size=trn_loader.batch_size, shuffle=False,
                                    num_workers=trn_loader.num_workers, pin_memory=trn_loader.pin_memory)
            selected_indices = self._select_indices(model, t, sel_loader, per_class, transform)
        # re-read the chosen samples without any transform (raw numpy arrays)
        with override_dataset_transform(trn_loader.dataset, Lambda(lambda x: np.array(x))) as ds_for_raw:
            x, y = zip(*(ds_for_raw[idx] for idx in selected_indices))
        print('| Selected {:d} train exemplars, time={:5.1f}s'.format(len(x), time.time() - start))
        return x, y

    def _exemplars_per_class_num(self, model: LLL_Net):
        """Number of exemplars to keep per class under the configured memory limit."""
        if self.exemplars_dataset.max_num_exemplars_per_class:
            return self.exemplars_dataset.max_num_exemplars_per_class

        # fixed total budget: divide it evenly over all classes seen so far
        num_cls = model.task_cls.sum().item()
        budget = self.exemplars_dataset.max_num_exemplars
        per_class = int(np.ceil(budget / num_cls))
        assert per_class > 0, \
            "Not enough exemplars to cover all classes!\n" \
            "Number of classes so far: {}. " \
            "Limit of exemplars: {}".format(num_cls,
                                            budget)
        return per_class

    def _select_indices(self, model: LLL_Net, t: int, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable:
        # strategy hook implemented by subclasses
        pass
51 |
52 |
class RandomExemplarsSelector(ExemplarsSelector):
    """Selection of new samples. This is based on random selection, which produces a random list of samples."""

    def __init__(self, exemplars_dataset):
        super().__init__(exemplars_dataset)

    def _select_indices(self, model: LLL_Net, t: int, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable:
        """Uniformly sample ``exemplars_per_class`` indices from every class seen so far."""
        labels = self._get_labels(sel_loader)
        chosen = []
        for curr_cls in range(sum(model.task_cls)):
            # indices of all samples of this class -- may also contain
            # exemplars of previous tasks present in the loader
            cls_ind = np.where(labels == curr_cls)[0]
            assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls)
            assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store"
            # uniform random pick without replacement (global RNG)
            chosen += random.sample(list(cls_ind), exemplars_per_class)
        return chosen

    def _get_labels(self, sel_loader):
        """Collect the full label array from the loader's dataset(s)."""
        dataset = sel_loader.dataset
        if hasattr(dataset, 'labels'):  # BaseDataset, MemoryDataset
            return np.asarray(dataset.labels)
        if isinstance(dataset, ConcatDataset):
            all_labels = []
            for ds in dataset.datasets:
                all_labels.extend(ds.labels)
            return np.array(all_labels)
        raise RuntimeError("Unsupported dataset: {}".format(dataset.__class__.__name__))
83 |
84 |
class HerdingExemplarsSelector(ExemplarsSelector):
    """Selection of new samples. This is based on herding selection, which produces a sorted list of samples of one
    class based on the distance to the mean sample of that class. From iCaRL algorithm 4 and 5:
    https://openaccess.thecvf.com/content_cvpr_2017/papers/Rebuffi_iCaRL_Incremental_Classifier_CVPR_2017_paper.pdf
    """
    def __init__(self, exemplars_dataset):
        super().__init__(exemplars_dataset)

    def _select_indices(self, model: LLL_Net, t: int, sel_loader: DataLoader, exemplars_per_class: int, transform) -> Iterable:
        """Greedily pick, per class, the samples whose running feature mean best matches the class mean."""
        model_device = next(model.parameters()).device  # we assume here that whole model is on a single device

        # extract outputs from the model for all train samples
        extracted_features = []
        extracted_targets = []
        with torch.no_grad():
            model.eval()
            for images, targets in sel_loader:
                # NOTE(review): 'shift'/'width_mult' look like the ScrollNet
                # sub-network configuration (task t, full width) -- confirm
                # against the network implementation
                model.apply(lambda m: setattr(m, 'shift', 1*t))
                model.apply(lambda m: setattr(m, 'width_mult', 1.0))
                feats = model(images.to(model_device), return_features=True)[1]
                feats = feats / feats.norm(dim=1).view(-1, 1)  # Feature normalization
                extracted_features.append(feats)
                extracted_targets.extend(targets)
        extracted_features = (torch.cat(extracted_features)).cpu()
        extracted_targets = np.array(extracted_targets)
        result = []
        # iterate through all classes
        for curr_cls in np.unique(extracted_targets):
            # get all indices from current class
            cls_ind = np.where(extracted_targets == curr_cls)[0]
            assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls)
            assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store"
            # get all extracted features for current class
            cls_feats = extracted_features[cls_ind]
            # calculate the mean
            cls_mu = cls_feats.mean(0)
            # select the exemplars closer to the mean of each class
            selected = []
            selected_feat = []
            for k in range(exemplars_per_class):
                # running sum of the already-selected features, scaled by
                # 1/(k+1) so that (sum_others + feat/(k+1)) is the mean of
                # the k+1 features when the candidate is added
                # fix this to the dimension of the model features
                sum_others = torch.zeros(cls_feats.shape[1])
                for j in selected_feat:
                    sum_others += j / (k + 1)
                dist_min = np.inf
                # choose the closest to the mean of the current class
                for item in cls_ind:
                    if item not in selected:
                        feat = extracted_features[item]
                        dist = torch.norm(cls_mu - feat / (k + 1) - sum_others)
                        if dist < dist_min:
                            dist_min = dist
                            newone = item
                            newonefeat = feat
                selected_feat.append(newonefeat)
                selected.append(newone)
            result.extend(selected)
        return result
143 |
144 |
class EntropyExemplarsSelector(ExemplarsSelector):
    """Selection of new samples. This is based on entropy selection, which produces a sorted list of samples of one
    class based on entropy of each sample. From RWalk http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112
    """
    def __init__(self, exemplars_dataset):
        super().__init__(exemplars_dataset)

    def _select_indices(self, model: LLL_Net, t: int, sel_loader: DataLoader, exemplars_per_class: int,
                        transform) -> Iterable:
        """Pick, per class, the ``exemplars_per_class`` samples with the highest prediction entropy.

        Fix: the base class calls ``self._select_indices(model, t, sel_loader, ...)``,
        but this override was missing the ``t`` parameter, so entropy selection
        raised a TypeError whenever it was used.  ``t`` is now accepted (and
        unused, since entropy does not depend on the task index).
        """
        model_device = next(model.parameters()).device  # we assume here that whole model is on a single device

        # extract outputs from the model for all train samples
        extracted_logits = []
        extracted_targets = []
        with torch.no_grad():
            model.eval()
            for images, targets in sel_loader:
                extracted_logits.append(torch.cat(model(images.to(model_device)), dim=1))
                extracted_targets.extend(targets)
        extracted_logits = (torch.cat(extracted_logits)).cpu()
        extracted_targets = np.array(extracted_targets)
        result = []
        # iterate through all classes
        for curr_cls in np.unique(extracted_targets):
            # get all indices from current class
            cls_ind = np.where(extracted_targets == curr_cls)[0]
            assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls)
            assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store"
            # get all extracted features for current class
            cls_logits = extracted_logits[cls_ind]
            # select the exemplars with higher entropy (lower: -entropy)
            probs = torch.softmax(cls_logits, dim=1)
            log_probs = torch.log(probs)
            minus_entropy = (probs * log_probs).sum(1)  # change sign of this variable for inverse order
            selected = cls_ind[minus_entropy.sort()[1][:exemplars_per_class]]
            result.extend(selected)
        return result
181 |
182 |
class DistanceExemplarsSelector(ExemplarsSelector):
    """Selection of new samples. This is based on distance-based selection, which produces a sorted list of samples of
    one class based on closeness to decision boundary of each sample. From RWalk
    http://arxiv-export-lb.library.cornell.edu/pdf/1801.10112
    """
    def __init__(self, exemplars_dataset):
        super().__init__(exemplars_dataset)

    def _select_indices(self, model: LLL_Net, t: int, sel_loader: DataLoader, exemplars_per_class: int,
                        transform) -> Iterable:
        """Pick, per class, the samples whose own-class logit is lowest (closest to the boundary).

        Fix: the base class calls ``self._select_indices(model, t, sel_loader, ...)``,
        but this override was missing the ``t`` parameter, so distance selection
        raised a TypeError whenever it was used.  ``t`` is now accepted (and
        unused, since the distance criterion does not depend on the task index).
        """
        model_device = next(model.parameters()).device  # we assume here that whole model is on a single device

        # extract outputs from the model for all train samples
        extracted_logits = []
        extracted_targets = []
        with torch.no_grad():
            model.eval()
            for images, targets in sel_loader:
                extracted_logits.append(torch.cat(model(images.to(model_device)), dim=1))
                extracted_targets.extend(targets)
        extracted_logits = (torch.cat(extracted_logits)).cpu()
        extracted_targets = np.array(extracted_targets)
        result = []
        # iterate through all classes
        for curr_cls in np.unique(extracted_targets):
            # get all indices from current class
            cls_ind = np.where(extracted_targets == curr_cls)[0]
            assert (len(cls_ind) > 0), "No samples to choose from for class {:d}".format(curr_cls)
            assert (exemplars_per_class <= len(cls_ind)), "Not enough samples to store"
            # get all extracted features for current class
            cls_logits = extracted_logits[cls_ind]
            # select the exemplars closer to boundary
            distance = cls_logits[:, curr_cls]  # change sign of this variable for inverse order
            selected = cls_ind[distance.sort()[1][:exemplars_per_class]]
            result.extend(selected)
        return result
219 |
220 |
def dataset_transforms(dataset, transform_to_change):
    """Swap in *transform_to_change* on every leaf dataset.

    Recurses into ConcatDataset children; returns a list of
    ``(dataset, old_transform)`` pairs so the caller can restore the
    original transforms afterwards.
    """
    if not isinstance(dataset, ConcatDataset):
        previous = dataset.transform
        dataset.transform = transform_to_change
        return [(dataset, previous)]
    replaced = []
    for child in dataset.datasets:
        replaced += dataset_transforms(child, transform_to_change)
    return replaced
231 |
232 |
@contextmanager
def override_dataset_transform(dataset, transform):
    """Context manager that temporarily replaces the transform(s) of *dataset*.

    On exit (normal or exceptional) the original transforms are restored.

    Fix: previously ``datasets_with_orig_transform`` was assigned inside the
    ``try``, so if ``dataset_transforms`` raised, the ``finally`` block hit an
    unbound local and masked the real error with a NameError.  The list is now
    initialized before the ``try``, so a failure during replacement restores
    whatever was already swapped and propagates the original exception.
    """
    datasets_with_orig_transform = []
    try:
        datasets_with_orig_transform = dataset_transforms(dataset, transform)
        yield dataset
    finally:
        # get back original transformations
        for ds, orig_transform in datasets_with_orig_transform:
            ds.transform = orig_transform
242 |
--------------------------------------------------------------------------------
/src/datasets/memory_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 |
6 |
class MemoryDataset(Dataset):
    """PyTorch dataset that pre-loads all images in memory as arrays.

    ``__getitem__`` converts the stored array to a PIL image before the
    transform pipeline is applied.
    """

    def __init__(self, data, transform, class_indices=None):
        """Keep the raw image arrays, labels, transform and optional class indices."""
        self.labels = data['y']
        self.images = data['x']
        self.transform = transform
        self.class_indices = class_indices

    def __len__(self):
        """Total number of samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Convert one stored array to a PIL image, transform it, return ``(x, y)``."""
        sample = Image.fromarray(self.images[index])
        return self.transform(sample), self.labels[index]
27 |
28 |
def get_data(trn_data, tst_data, num_tasks, nc_first_task, validation, shuffle_classes, class_order=None):
    """Prepare data: dataset splits, task partition, class order.

    Args:
        trn_data / tst_data: dicts with 'x' (images) and 'y' (labels)
        num_tasks: number of tasks the classes are split into
        nc_first_task: classes in the first task (None -> equal split)
        validation: fraction of training samples moved to the validation split
        shuffle_classes: shuffle the class order before splitting
        class_order: explicit class order (optional; copied, not mutated)

    Returns:
        (data, taskcla, class_order): data[t] holds 'trn'/'val'/'tst' splits
        with task-local labels plus 'name' and 'ncla'; taskcla is a list of
        (task, n_classes) pairs; class_order is the order actually used.
    """

    data = {}
    taskcla = []
    if class_order is None:
        num_classes = len(np.unique(trn_data['y']))
        class_order = list(range(num_classes))
    else:
        num_classes = len(class_order)
        class_order = class_order.copy()
    if shuffle_classes:
        np.random.shuffle(class_order)
    # Map original class id -> position in class_order: O(1) lookup instead of
    # a list.index() scan for every sample.
    class_pos = {c: i for i, c in enumerate(class_order)}

    # compute classes per task and num_tasks
    if nc_first_task is None:
        cpertask = np.array([num_classes // num_tasks] * num_tasks)
        for i in range(num_classes % num_tasks):
            cpertask[i] += 1
    else:
        assert nc_first_task < num_classes, "first task wants more classes than exist"
        remaining_classes = num_classes - nc_first_task
        assert remaining_classes >= (num_tasks - 1), "at least one class is needed per task"  # better minimum 2
        cpertask = np.array([nc_first_task] + [remaining_classes // (num_tasks - 1)] * (num_tasks - 1))
        for i in range(remaining_classes % (num_tasks - 1)):
            cpertask[i + 1] += 1

    assert num_classes == cpertask.sum(), "something went wrong, the split does not match num classes"
    cpertask_cumsum = np.cumsum(cpertask)
    init_class = np.concatenate(([0], cpertask_cumsum[:-1]))

    # initialize data structure
    for tt in range(num_tasks):
        data[tt] = {}
        data[tt]['name'] = 'task-' + str(tt)
        data[tt]['trn'] = {'x': [], 'y': []}
        data[tt]['val'] = {'x': [], 'y': []}
        data[tt]['tst'] = {'x': [], 'y': []}

    # ALL OR TRAIN
    filtering = np.isin(trn_data['y'], class_order)
    if filtering.sum() != len(trn_data['y']):
        trn_data['x'] = trn_data['x'][filtering]
        trn_data['y'] = np.array(trn_data['y'])[filtering]
    for this_image, this_label in zip(trn_data['x'], trn_data['y']):
        # If shuffling is false, it won't change the class number
        this_label = class_pos[this_label]
        # add it to the corresponding split
        this_task = (this_label >= cpertask_cumsum).sum()
        data[this_task]['trn']['x'].append(this_image)
        data[this_task]['trn']['y'].append(this_label - init_class[this_task])

    # ALL OR TEST
    filtering = np.isin(tst_data['y'], class_order)
    if filtering.sum() != len(tst_data['y']):
        tst_data['x'] = tst_data['x'][filtering]
        # BUG FIX: wrap in an array first so boolean-mask indexing also works
        # when the labels come in as a plain list (mirrors the train branch).
        tst_data['y'] = np.array(tst_data['y'])[filtering]
    for this_image, this_label in zip(tst_data['x'], tst_data['y']):
        # If shuffling is false, it won't change the class number
        this_label = class_pos[this_label]
        # add it to the corresponding split
        this_task = (this_label >= cpertask_cumsum).sum()
        data[this_task]['tst']['x'].append(this_image)
        data[this_task]['tst']['y'].append(this_label - init_class[this_task])

    # check classes
    for tt in range(num_tasks):
        data[tt]['ncla'] = len(np.unique(data[tt]['trn']['y']))
        assert data[tt]['ncla'] == cpertask[tt], "something went wrong splitting classes"

    # validation
    if validation > 0.0:
        for tt in data.keys():
            for cc in range(data[tt]['ncla']):
                cls_idx = list(np.where(np.asarray(data[tt]['trn']['y']) == cc)[0])
                rnd_img = random.sample(cls_idx, int(np.round(len(cls_idx) * validation)))
                # pop from the back so earlier indices stay valid
                rnd_img.sort(reverse=True)
                for ii in range(len(rnd_img)):
                    data[tt]['val']['x'].append(data[tt]['trn']['x'][rnd_img[ii]])
                    data[tt]['val']['y'].append(data[tt]['trn']['y'][rnd_img[ii]])
                    data[tt]['trn']['x'].pop(rnd_img[ii])
                    data[tt]['trn']['y'].pop(rnd_img[ii])

    # convert them to numpy arrays
    for tt in data.keys():
        for split in ['trn', 'val', 'tst']:
            data[tt][split]['x'] = np.asarray(data[tt][split]['x'])

    # other
    n = 0
    for t in data.keys():
        taskcla.append((t, data[t]['ncla']))
        n += data[t]['ncla']
    data['ncla'] = n

    return data, taskcla, class_order
125 |
--------------------------------------------------------------------------------
/src/loggers/README.md:
--------------------------------------------------------------------------------
1 | # Loggers
2 |
3 | We include a disk logger, which logs into files and folders on disk. We also provide a tensorboard logger, which
4 | offers a faster way of analysing a training process without the need for further development. They can be specified
5 | with `--log` followed by `disk`, `tensorboard` or both. Custom loggers can be defined by inheriting the `ExperimentLogger`
6 | in [exp_logger.py](exp_logger.py).
7 |
8 | When enabled, both loggers will output everything in the path `[RESULTS_PATH]/[DATASETS]_[APPROACH]_[EXP_NAME]` or
9 | `[RESULTS_PATH]/[DATASETS]_[APPROACH]` if `--exp-name` is not set.
10 |
11 | ## Disk logger
12 | The disk logger outputs the following file and folder structure:
13 | - **figures/**: folder where generated figures are logged.
14 | - **models/**: folder where model weight checkpoints are saved.
15 | - **results/**: folder containing the results.
16 | - **acc_tag**: task-agnostic accuracy table.
17 | - **acc_taw**: task-aware accuracy table.
18 | - **avg_acc_tag**: task-agnostic average accuracies.
19 |   - **avg_acc_taw**: task-aware average accuracies.
20 | - **forg_tag**: task-agnostic forgetting table.
21 | - **forg_taw**: task-aware forgetting table.
22 | - **wavg_acc_tag**: task-agnostic average accuracies weighted according to the number of classes of each task.
23 | - **wavg_acc_taw**: task-aware average accuracies weighted according to the number of classes of each task.
24 | - **raw_log**: json file containing all the logged metrics easily read by many tools (e.g. `pandas`).
25 | - stdout: a copy from the standard output of the terminal.
26 | - stderr: a copy from the error output of the terminal.
27 |
28 | ## TensorBoard logger
29 | The tensorboard logger outputs analogous metrics to the disk logger separated into different tabs according to the task
30 | and different graphs according to the data splits.
31 |
32 | Screenshot for a 10 task experiment, showing the last task plots:
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/src/loggers/__pycache__/disk_logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/src/loggers/__pycache__/disk_logger.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loggers/__pycache__/disk_logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/src/loggers/__pycache__/disk_logger.cpython-38.pyc
--------------------------------------------------------------------------------
/src/loggers/__pycache__/exp_logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/src/loggers/__pycache__/exp_logger.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loggers/__pycache__/exp_logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireFYF/ScrollNet/2b003e328666bcd796f768cc8111bed66ba232b4/src/loggers/__pycache__/exp_logger.cpython-38.pyc
--------------------------------------------------------------------------------
/src/loggers/disk_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import torch
5 | import numpy as np
6 | from datetime import datetime
7 |
8 | from loggers.exp_logger import ExperimentLogger
9 |
10 |
class Logger(ExperimentLogger):
    """Disk logger: duplicates stdout/stderr to files and dumps scalars,
    arguments, result tables, figures and model checkpoints under `exp_path`.
    """

    def __init__(self, log_path, exp_name, begin_time=None):
        super(Logger, self).__init__(log_path, exp_name, begin_time)

        self.begin_time_str = self.begin_time.strftime("%Y-%m-%d-%H-%M")

        # Duplicate standard outputs
        sys.stdout = FileOutputDuplicator(sys.stdout,
                                          os.path.join(self.exp_path, 'stdout-{}.txt'.format(self.begin_time_str)), 'w')
        sys.stderr = FileOutputDuplicator(sys.stderr,
                                          os.path.join(self.exp_path, 'stderr-{}.txt'.format(self.begin_time_str)), 'w')

        # Raw log file (append mode so reruns in the same minute don't clobber it)
        self.raw_log_file = open(os.path.join(self.exp_path, "raw_log-{}.txt".format(self.begin_time_str)), 'a')

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        """Append one scalar entry as a JSON line to the raw log."""
        if curtime is None:
            curtime = datetime.now()

        # Raw dump
        entry = {"task": task, "iter": iter, "name": name, "value": value, "group": group,
                 "time": curtime.strftime("%Y-%m-%d-%H-%M")}
        self.raw_log_file.write(json.dumps(entry, sort_keys=True) + "\n")
        self.raw_log_file.flush()

    def log_args(self, args):
        """Dump the experiment argument namespace to a JSON text file."""
        with open(os.path.join(self.exp_path, 'args-{}.txt'.format(self.begin_time_str)), 'w') as f:
            json.dump(args.__dict__, f, separators=(',\n', ' : '), sort_keys=True)

    def log_result(self, array, name, step):
        """Save a result array (promoted to 2-D if needed) as a TSV table."""
        if array.ndim <= 1:
            array = array[None]
        np.savetxt(os.path.join(self.exp_path, 'results', '{}-{}.txt'.format(name, self.begin_time_str)),
                   array, '%.6f', delimiter='\t')

    def log_figure(self, name, iter, figure, curtime=None):
        """Save a figure as PNG and PDF under figures/.

        BUG FIX: `curtime` used to be unconditionally overwritten, ignoring
        the shared timestamp passed in by MultiLogger (log_scalar honours it).
        """
        if curtime is None:
            curtime = datetime.now()
        figure.savefig(os.path.join(self.exp_path, 'figures',
                                    '{}_{}-{}.png'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S"))))
        figure.savefig(os.path.join(self.exp_path, 'figures',
                                    '{}_{}-{}.pdf'.format(name, iter, curtime.strftime("%Y-%m-%d-%H-%M-%S"))))

    def save_model(self, state_dict, task):
        """Checkpoint the model weights for the given task under models/."""
        torch.save(state_dict, os.path.join(self.exp_path, "models", "task{}.ckpt".format(task)))

    def __del__(self):
        # raw_log_file may be unbound if __init__ failed before opening it
        raw = getattr(self, 'raw_log_file', None)
        if raw is not None:
            raw.close()
60 |
61 |
class FileOutputDuplicator(object):
    """Tee-like stream wrapper: every write goes to a file on disk and is
    forwarded to the wrapped original stream."""

    def __init__(self, duplicate, fname, mode):
        self.file = open(fname, mode)
        self.duplicate = duplicate

    def write(self, data):
        # file copy first, then mirror to the original stream
        self.file.write(data)
        self.duplicate.write(data)

    def flush(self):
        self.file.flush()

    def __del__(self):
        self.file.close()
76 |
--------------------------------------------------------------------------------
/src/loggers/exp_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | from datetime import datetime
4 |
5 |
class ExperimentLogger:
    """Base class for experiment logging.

    Holds the experiment paths and start time, and declares the logging
    interface as no-ops for concrete loggers to override.
    """

    def __init__(self, log_path, exp_name, begin_time=None):
        self.log_path = log_path
        self.exp_name = exp_name
        self.exp_path = os.path.join(log_path, exp_name)
        self.begin_time = datetime.now() if begin_time is None else begin_time

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        """Log one scalar metric (no-op in the base class)."""
        pass

    def log_args(self, args):
        """Log the experiment arguments (no-op in the base class)."""
        pass

    def log_result(self, array, name, step):
        """Log a result array (no-op in the base class)."""
        pass

    def log_figure(self, name, iter, figure, curtime=None):
        """Log a matplotlib figure (no-op in the base class)."""
        pass

    def save_model(self, state_dict, task):
        """Persist model weights (no-op in the base class)."""
        pass
32 |
33 |
class MultiLogger(ExperimentLogger):
    """Fans every logging call out to a list of concrete loggers.

    `loggers` is a list of short names ('disk', 'tensorboard'); each is
    resolved dynamically to the `Logger` class in `loggers.<name>_logger`.
    """

    def __init__(self, log_path, exp_name, loggers=None, save_models=True):
        super(MultiLogger, self).__init__(log_path, exp_name)
        if os.path.exists(self.exp_path):
            print("WARNING: {} already exists!".format(self.exp_path))
        else:
            os.makedirs(os.path.join(self.exp_path, 'models'))
            os.makedirs(os.path.join(self.exp_path, 'results'))
            os.makedirs(os.path.join(self.exp_path, 'figures'))

        self.save_models = save_models
        self.loggers = []
        # BUG FIX: the default loggers=None used to be iterated directly,
        # raising TypeError; treat it as "no loggers".
        for logger_name in (loggers or []):
            logger_cls = getattr(importlib.import_module(name='loggers.' + logger_name + '_logger'), 'Logger')
            self.loggers.append(logger_cls(self.log_path, self.exp_name))

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        """Forward one scalar to every logger with a shared timestamp."""
        if curtime is None:
            curtime = datetime.now()
        for lg in self.loggers:
            lg.log_scalar(task, iter, name, value, group, curtime)

    def log_args(self, args):
        """Forward the argument namespace to every logger."""
        for lg in self.loggers:
            lg.log_args(args)

    def log_result(self, array, name, step):
        """Forward a result array to every logger."""
        for lg in self.loggers:
            lg.log_result(array, name, step)

    def log_figure(self, name, iter, figure, curtime=None):
        """Forward a figure to every logger with a shared timestamp."""
        if curtime is None:
            curtime = datetime.now()
        for lg in self.loggers:
            lg.log_figure(name, iter, figure, curtime)

    def save_model(self, state_dict, task):
        """Forward a checkpoint to every logger, if saving is enabled."""
        if self.save_models:
            for lg in self.loggers:
                lg.save_model(state_dict, task)
76 |
--------------------------------------------------------------------------------
/src/loggers/tensorboard_logger.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 |
3 | from loggers.exp_logger import ExperimentLogger
4 | import json
5 | import numpy as np
6 |
7 |
class Logger(ExperimentLogger):
    """Characterizes a Tensorboard logger"""

    def __init__(self, log_path, exp_name, begin_time=None):
        super(Logger, self).__init__(log_path, exp_name, begin_time)
        # one SummaryWriter rooted at the experiment folder
        self.tbwriter = SummaryWriter(self.exp_path)

    def log_scalar(self, task, iter, name, value, group=None, curtime=None):
        # tag prefix "t{task}/" groups the plots into one tab per task
        self.tbwriter.add_scalar(tag="t{}/{}_{}".format(task, group, name),
                                 scalar_value=value,
                                 global_step=iter)
        self.tbwriter.file_writer.flush()

    def log_figure(self, name, iter, figure, curtime=None):
        # NOTE(review): `curtime` is accepted for interface compatibility but unused here
        self.tbwriter.add_figure(tag=name, figure=figure, global_step=iter)
        self.tbwriter.file_writer.flush()

    def log_args(self, args):
        # dump the full argument namespace as a text entry
        self.tbwriter.add_text(
            'args',
            json.dumps(args.__dict__,
                       separators=(',\n', ' : '),
                       sort_keys=True))
        self.tbwriter.file_writer.flush()

    def log_result(self, array, name, step):
        if array.ndim == 1:
            # log as scalars
            self.tbwriter.add_scalar(f'results/{name}', array[step], step)

        elif array.ndim == 2:
            # render row `step` of the matrix as a text table, with a running average
            s = ""
            i = step
            # for i in range(array.shape[0]):
            for j in range(array.shape[1]):
                s += '{:5.1f}% '.format(100 * array[i, j])
            # NOTE(review): zero trace looks like a heuristic for forgetting-style
            # matrices (zero diagonal), whose average excludes the current task -- confirm
            if np.trace(array) == 0.0:
                if i > 0:
                    s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i].mean())
            else:
                s += '\tAvg.:{:5.1f}% \n'.format(100 * array[i, :i + 1].mean())
            self.tbwriter.add_text(f'results/{name}', s, step)

    def __del__(self):
        self.tbwriter.close()
53 |
--------------------------------------------------------------------------------
/src/main_incremental.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import argparse
5 | import importlib
6 | import numpy as np
7 | from functools import reduce
8 |
9 | import utils
10 | import approach
11 | from loggers.exp_logger import MultiLogger
12 | from datasets.data_loader import get_loaders
13 | from datasets.dataset_config import dataset_config
14 | from networks import allmodels
15 |
def main(argv=None):
    """Run a full incremental-learning experiment.

    Parses arguments, builds data loaders / network / approach, trains one
    task at a time, evaluates on all tasks seen so far, and logs accuracy
    and forgetting metrics. Returns the raw metric matrices and the
    experiment path.
    """
    tstart = time.time()
    # Arguments
    parser = argparse.ArgumentParser(description='ScrollNet: Dynamic Weight Importance for Continual Learning')

    # miscellaneous args
    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU (default=%(default)s)')
    parser.add_argument('--results-path', type=str, default='../results',
                        help='Results path (default=%(default)s)')
    parser.add_argument('--exp-name', default=None, type=str,
                        help='Experiment name (default=%(default)s)')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed (default=%(default)s)')
    parser.add_argument('--log', default=['disk'], type=str, choices=['disk', 'tensorboard'],
                        help='Loggers used (disk, tensorboard) (default=%(default)s)', nargs='*', metavar="LOGGER")
    parser.add_argument('--save-models', action='store_true',
                        help='Save trained models (default=%(default)s)')
    parser.add_argument('--no-cudnn-deterministic', action='store_true',
                        help='Disable CUDNN deterministic (default=%(default)s)')
    # dataset args
    parser.add_argument('--datasets', default=['cifar100'], type=str, choices=list(dataset_config.keys()),
                        help='Dataset or datasets used (default=%(default)s)', nargs='+', metavar="DATASET")
    parser.add_argument('--num-workers', default=4, type=int, required=False,
                        help='Number of subprocesses to use for dataloader (default=%(default)s)')
    parser.add_argument('--pin-memory', default=False, type=bool, required=False,
                        help='Copy Tensors into CUDA pinned memory before returning them (default=%(default)s)')
    parser.add_argument('--batch-size', default=64, type=int, required=False,
                        help='Number of samples per batch to load (default=%(default)s)')
    parser.add_argument('--num-tasks', default=10, type=int, required=False,
                        help='Number of tasks per dataset (default=%(default)s)')
    parser.add_argument('--nc-first-task', default=None, type=int, required=False,
                        help='Number of classes of the first task (default=%(default)s)')
    parser.add_argument('--use-valid-only', action='store_true',
                        help='Use validation split instead of test (default=%(default)s)')
    parser.add_argument('--stop-at-task', default=0, type=int, required=False,
                        help='Stop training after specified task (default=%(default)s)')
    # model args
    parser.add_argument('--network', default='scroll_resnet18', type=str, choices=allmodels,
                        help='Network architecture used (default=%(default)s)', metavar="NETWORK")
    parser.add_argument('--keep-existing-head', action='store_true',
                        help='Disable removing classifier last layer (default=%(default)s)')
    parser.add_argument('--pretrained', action='store_true',
                        help='Use pretrained backbone (default=%(default)s)')
    # training args
    parser.add_argument('--approach', default='finetuning', type=str, choices=approach.__all__,
                        help='Learning approach used (default=%(default)s)', metavar="APPROACH")
    parser.add_argument('--nepochs', default=200, type=int, required=False,
                        help='Number of epochs per training session (default=%(default)s)')
    parser.add_argument('--lr', default=0.1, type=float, required=False,
                        help='Starting learning rate (default=%(default)s)')
    parser.add_argument('--decay-mile-stone', nargs='+', type=int,
                        help='mile stone of learning rate decay')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='ratio of learning rate decay')
    parser.add_argument('--clipping', default=10000, type=float, required=False,
                        help='Clip gradient norm (default=%(default)s)')
    parser.add_argument('--momentum', default=0.0, type=float, required=False,
                        help='Momentum factor (default=%(default)s)')
    parser.add_argument('--weight-decay', default=0.0, type=float, required=False,
                        help='Weight decay (L2 penalty) (default=%(default)s)')
    parser.add_argument('--multi-softmax', action='store_true',
                        help='Apply separate softmax for each task (default=%(default)s)')
    parser.add_argument('--fix-bn', action='store_true',
                        help='Fix batch normalization after first task (default=%(default)s)')
    parser.add_argument('--eval-on-train', action='store_true',
                        help='Show train loss and accuracy (default=%(default)s)')
    # scrolling args
    parser.add_argument('--scroll_step', default=1, type=int,
                        help='Scrolling step size.')

    # Args -- Incremental Learning Framework
    args, extra_args = parser.parse_known_args(argv)
    args.results_path = os.path.expanduser(args.results_path)
    # shared keyword arguments forwarded to every approach constructor
    base_kwargs = dict(nepochs=args.nepochs, lr=args.lr, clipgrad=args.clipping, momentum=args.momentum,
                       wd=args.weight_decay, multi_softmax=args.multi_softmax, scroll_step=args.scroll_step,
                       fix_bn=args.fix_bn, eval_on_train=args.eval_on_train)

    if args.no_cudnn_deterministic:
        print('WARNING: CUDNN Deterministic will be disabled.')
        utils.cudnn_deterministic = False

    utils.seed_everything(seed=args.seed)
    print('=' * 108)
    print('Arguments =')
    for arg in np.sort(list(vars(args).keys())):
        print('\t' + arg + ':', getattr(args, arg))
    print('=' * 108)

    # Args -- CUDA
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
        device = 'cuda'
    else:
        print('WARNING: [CUDA unavailable] Using CPU instead!')
        device = 'cpu'

    # Args -- Network
    from networks.network import LLL_Net
    net = getattr(importlib.import_module(name='networks'), args.network)
    init_model = net(pretrained=False)

    # Args -- Continual Learning Approach
    from approach.incremental_learning import Inc_Learning_Appr
    Appr = getattr(importlib.import_module(name='approach.' + args.approach), 'Appr')
    assert issubclass(Appr, Inc_Learning_Appr)
    # let the approach consume its own extra command-line arguments
    appr_args, extra_args = Appr.extra_parser(extra_args)
    print('Approach arguments =')
    for arg in np.sort(list(vars(appr_args).keys())):
        print('\t' + arg + ':', getattr(appr_args, arg))
    print('=' * 108)

    # Args -- Exemplars Management
    from datasets.exemplars_dataset import ExemplarsDataset
    Appr_ExemplarsDataset = Appr.exemplars_dataset_class()
    if Appr_ExemplarsDataset:
        assert issubclass(Appr_ExemplarsDataset, ExemplarsDataset)
        appr_exemplars_dataset_args, extra_args = Appr_ExemplarsDataset.extra_parser(extra_args)
        print('Exemplars dataset arguments =')
        for arg in np.sort(list(vars(appr_exemplars_dataset_args).keys())):
            print('\t' + arg + ':', getattr(appr_exemplars_dataset_args, arg))
        print('=' * 108)
    else:
        appr_exemplars_dataset_args = argparse.Namespace()

    # Log all arguments
    # experiment name: first letter of each dataset, then the approach name (+ optional suffix)
    full_exp_name = reduce((lambda x, y: x[0] + y[0]), args.datasets) if len(args.datasets) > 0 else args.datasets[0]
    full_exp_name += '_' + args.approach
    if args.exp_name is not None:
        full_exp_name += '_' + args.exp_name
    logger = MultiLogger(args.results_path, full_exp_name, loggers=args.log, save_models=args.save_models)
    logger.log_args(argparse.Namespace(**args.__dict__, **appr_args.__dict__, **appr_exemplars_dataset_args.__dict__))

    # Loaders
    utils.seed_everything(seed=args.seed)
    trn_loader, val_loader, tst_loader, taskcla = get_loaders(args.datasets, args.num_tasks, args.nc_first_task,
                                                              args.batch_size, num_workers=args.num_workers,
                                                              pin_memory=args.pin_memory)
    # Apply arguments for loaders
    if args.use_valid_only:
        tst_loader = val_loader
    max_task = len(taskcla) if args.stop_at_task == 0 else args.stop_at_task

    # Network and Approach instances
    utils.seed_everything(seed=args.seed)
    net = LLL_Net(init_model, remove_existing_head=not args.keep_existing_head)
    utils.seed_everything(seed=args.seed)
    # taking transformations and class indices from first train dataset
    first_train_ds = trn_loader[0].dataset
    transform, class_indices = first_train_ds.transform, first_train_ds.class_indices
    appr_kwargs = {**base_kwargs, **dict(logger=logger, **appr_args.__dict__)}
    if Appr_ExemplarsDataset:
        appr_kwargs['exemplars_dataset'] = Appr_ExemplarsDataset(transform, class_indices,
                                                                 **appr_exemplars_dataset_args.__dict__)
    utils.seed_everything(seed=args.seed)
    appr = Appr(net, device, **appr_kwargs)

    # Loop tasks
    print(taskcla)
    # metric matrices: row = last task trained, column = task evaluated
    acc_taw = np.zeros((max_task, max_task))
    acc_tag = np.zeros((max_task, max_task))
    forg_taw = np.zeros((max_task, max_task))
    forg_tag = np.zeros((max_task, max_task))
    for t, (_, ncla) in enumerate(taskcla):
        # Early stop tasks if flag
        if t >= max_task:
            continue

        print('*' * 108)
        print('Task {:2d}'.format(t))
        print('*' * 108)

        # Add head for current task
        net.add_head(taskcla[t][1])
        net.to(device)

        # Train
        appr.train(t, trn_loader[t], val_loader[t])
        print('-' * 108)

        # Test
        for u in range(t + 1):
            test_loss, acc_taw[t, u], acc_tag[t, u] = appr.eval(u, tst_loader[u], t)

            # forgetting = best accuracy ever reached on task u minus current accuracy
            if u < t:
                forg_taw[t, u] = acc_taw[:t, u].max(0) - acc_taw[t, u]
                forg_tag[t, u] = acc_tag[:t, u].max(0) - acc_tag[t, u]
            print('>>> Test on task {:2d} : loss={:.3f} | TAw acc={:5.1f}%, forg={:5.1f}%'
                  '| TAg acc={:5.1f}%, forg={:5.1f}% <<<'.format(u, test_loss,
                                                                 100 * acc_taw[t, u],
                                                                 100 * forg_taw[t, u],
                                                                 100 * acc_tag[t, u],
                                                                 100 * forg_tag[t, u]))
            logger.log_scalar(task=t, iter=u, name='loss', group='test', value=test_loss)
            logger.log_scalar(task=t, iter=u, name='acc_taw', group='test', value=100 * acc_taw[t, u])
            logger.log_scalar(task=t, iter=u, name='acc_tag', group='test', value=100 * acc_tag[t, u])
            logger.log_scalar(task=t, iter=u, name='forg_taw', group='test', value=100 * forg_taw[t, u])
            logger.log_scalar(task=t, iter=u, name='forg_tag', group='test', value=100 * forg_tag[t, u])

        # Save
        print('Save at ' + os.path.join(args.results_path, full_exp_name))
        logger.log_result(acc_taw, name="acc_taw", step=t)
        logger.log_result(acc_tag, name="acc_tag", step=t)
        logger.log_result(forg_taw, name="forg_taw", step=t)
        logger.log_result(forg_tag, name="forg_tag", step=t)
        logger.save_model(net.state_dict(), task=t)
        # lower-triangular averages: each row only counts the tasks seen so far
        logger.log_result(acc_taw.sum(1) / np.tril(np.ones(acc_taw.shape[0])).sum(1), name="avg_accs_taw", step=t)
        logger.log_result(acc_tag.sum(1) / np.tril(np.ones(acc_tag.shape[0])).sum(1), name="avg_accs_tag", step=t)
        # per-task class counts weight the averages below
        aux = np.tril(np.repeat([[tdata[1] for tdata in taskcla[:max_task]]], max_task, axis=0))
        logger.log_result((acc_taw * aux).sum(1) / aux.sum(1), name="wavg_accs_taw", step=t)
        logger.log_result((acc_tag * aux).sum(1) / aux.sum(1), name="wavg_accs_tag", step=t)

    # Print Summary
    utils.print_summary(acc_taw, acc_tag, forg_taw, forg_tag)
    print('[Elapsed time = {:.1f} h]'.format((time.time() - tstart) / (60 * 60)))
    print('Done!')

    return acc_taw, acc_tag, forg_taw, forg_tag, logger.exp_path
234 |
# Script entry point.
if __name__ == '__main__':
    main()
237 |
--------------------------------------------------------------------------------
/src/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .scroll_resnet18 import scroll_resnet18
2 | allmodels = ['scroll_resnet18']
3 |
--------------------------------------------------------------------------------
/src/networks/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from copy import deepcopy
4 | from .slimmable_ops import SlimmableConv2d, SlimmableLinear
5 | from widths.config import FLAGS
6 |
class LLL_Net(nn.Module):
    """Basic class for implementing networks

    Wraps a backbone `model` and manages a growing list of task-specific
    heads for incremental learning.
    """

    def __init__(self, model, remove_existing_head=False):
        head_var = model.head_var
        assert type(head_var) == str
        assert not remove_existing_head or hasattr(model, head_var), \
            "Given model does not have a variable called {}".format(head_var)
        assert not remove_existing_head or type(getattr(model, head_var)) in [nn.Sequential, nn.Linear, SlimmableLinear], \
            "Given model's head {} does is not an instance of nn.Sequential or nn.Linear".format(head_var)
        super(LLL_Net, self).__init__()

        self.model = model
        last_layer = getattr(self.model, head_var)

        if remove_existing_head:
            # NOTE(review): exact type() comparisons look deliberate here --
            # isinstance() could match subclasses (e.g. if SlimmableLinear derives
            # from nn.Linear) and take the wrong branch; confirm before changing.
            if type(last_layer) == nn.Sequential:
                self.out_size = last_layer[-1].in_features
                # strips off last linear layer of classifier
                del last_layer[-1]
            elif type(last_layer) == nn.Linear:
                self.out_size = last_layer.in_features
                # converts last layer into identity
                # setattr(self.model, head_var, nn.Identity())
                # WARNING: this is for when pytorch version is <1.2
                setattr(self.model, head_var, nn.Sequential())
            elif type(last_layer) == SlimmableLinear:
                self.out_size = last_layer.in_features
                setattr(self.model, head_var, nn.Sequential())
        else:
            self.out_size = last_layer.out_features

        self.heads = nn.ModuleList()
        self.task_cls = []
        self.task_offset = []
        self._initialize_weights()

    def add_head(self, num_outputs):
        """Add a new head with the corresponding number of outputs. Also update the number of classes per task and the
        corresponding offsets
        """
        # one input size per width multiplier -- presumably FLAGS.width_mult_list
        # holds the slimming ratios of the subnetworks (TODO confirm)
        Ch_in = [int(self.out_size * width_mult) for width_mult in FLAGS.width_mult_list]
        Ch_out = [num_outputs for width_mult in FLAGS.width_mult_list]

        self.heads.append(SlimmableLinear(Ch_in, Ch_out))
        # we re-compute instead of append in case an approach makes changes to the heads
        self.task_cls = torch.tensor([head.out_features for head in self.heads])
        self.task_offset = torch.cat([torch.LongTensor(1).zero_(), self.task_cls.cumsum(0)[:-1]])

    def forward(self, x, return_features=False):
        """Applies the forward pass

        Simplification to work on multi-head only -- returns all head outputs in a list
        Args:
            x (tensor): input images
            return_features (bool): return the representations before the heads
        """
        x = self.model(x)
        assert (len(self.heads) > 0), "Cannot access any head"
        y = []
        for head in self.heads:
            y.append(head(x))
        if return_features:
            return y, x
        else:
            return y

    def get_copy(self):
        """Get weights from the model"""
        return deepcopy(self.state_dict())

    def set_state_dict(self, state_dict):
        """Load weights into the model"""
        self.load_state_dict(deepcopy(state_dict))
        return

    def freeze_all(self):
        """Freeze all parameters from the model, including the heads"""
        for param in self.parameters():
            param.requires_grad = False

    def freeze_backbone(self):
        """Freeze all parameters from the main model, but not the heads"""
        for param in self.model.parameters():
            param.requires_grad = False

    def freeze_bn(self):
        """Freeze all Batch Normalization layers from the model and use them in eval() mode"""
        for m in self.model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def _initialize_weights(self):
        """Initialize weights using different strategies"""
        # TODO: add different initialization strategies
        pass
103 |
--------------------------------------------------------------------------------
/src/networks/scroll_resnet18.py:
--------------------------------------------------------------------------------
1 | from distutils.util import change_root
2 | import torch.nn as nn
3 | import math
4 |
5 | from .slimmable_ops import SwitchableBatchNorm2d
6 | from .slimmable_ops import SlimmableConv2d, SlimmableLinear
7 | from widths.config import FLAGS
8 |
def slimconv3x3(in_planes, out_planes, stride=1):
    """Build a 3x3 slimmable convolution with padding 1 and no bias."""
    return SlimmableConv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
13 |
class BasicBlock(nn.Module):
    """Slimmable ResNet basic block: two 3x3 slimmable convolutions with
    switchable batch-norm and an identity (or downsampled) residual path."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        widths = FLAGS.width_mult_list
        channels_in = [int(inplanes * w) for w in widths]
        channels_out = [int(planes * w) for w in widths]

        self.conv1 = slimconv3x3(channels_in, channels_out, stride)
        self.bn1 = SwitchableBatchNorm2d(channels_out)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = slimconv3x3(channels_out, channels_out)
        self.bn2 = SwitchableBatchNorm2d(channels_out)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # main path: conv-bn-relu, conv-bn
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # residual path, downsampled when shapes differ
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
43 |
class Scroll_ResNet(nn.Module):
    """Slimmable ResNet backbone for ScrollNet (3x3 stride-1 stem, so the
    input is kept at full resolution as in CIFAR-style ResNets)."""

    def __init__(self, block, layers, num_classes=10):
        self.inplanes = 64
        super(Scroll_ResNet, self).__init__()
        widths = FLAGS.width_mult_list

        # Stem: every width sees the same 3-channel input.
        stem_in = [3 for _ in widths]
        stem_out = [int(64 * w) for w in widths]
        self.conv1 = SlimmableConv2d(stem_in, stem_out, kernel_size=3,
                                     stride=1, padding=1, bias=False)
        self.bn1 = SwitchableBatchNorm2d(stem_out)
        self.relu = nn.ReLU(inplace=True)

        # Four residual stages; stages 2-4 halve the spatial resolution.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Classifier head; `head_var` names the head attribute so the
        # incremental-learning framework can locate and strip it.
        fc_in = [int(512 * block.expansion * w) for w in widths]
        fc_out = [num_classes for _ in widths]
        self.fc = SlimmableLinear(fc_in, fc_out)
        self.head_var = 'fc'

        # He init for convs (slimmable convs subclass nn.Conv2d), constant
        # init for plain BatchNorm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual blocks; only the first may downsample."""
        widths = FLAGS.width_mult_list
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 slimmable projection so the skip path matches in shape.
            proj_in = [int(self.inplanes * w) for w in widths]
            proj_out = [int(planes * block.expansion * w) for w in widths]
            downsample = nn.Sequential(
                SlimmableConv2d(proj_in, proj_out, kernel_size=1,
                                stride=stride, bias=False),
                SwitchableBatchNorm2d(proj_out),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
113 |
def scroll_resnet18(pretrained=False, **kwargs):
    """Construct a slimmable (ScrollNet) ResNet-18.

    Args:
        pretrained (bool): kept for API compatibility with torchvision-style
            constructors. No pretrained weights exist for the slimmable
            architecture, so requesting them now fails loudly instead of
            being silently ignored (the previous behavior).
        **kwargs: forwarded to Scroll_ResNet (e.g. ``num_classes``).

    Returns:
        Scroll_ResNet: a randomly initialized ResNet-18 built from
        BasicBlock with layer configuration [2, 2, 2, 2].

    Raises:
        NotImplementedError: if ``pretrained`` is True.
    """
    if pretrained:
        raise NotImplementedError(
            'No pretrained weights are available for Scroll_ResNet.')
    return Scroll_ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
122 |
--------------------------------------------------------------------------------
/src/networks/slimmable_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random
4 | from widths.config import FLAGS
5 |
class SwitchableBatchNorm2d(nn.Module):
    """BatchNorm that keeps one independent BatchNorm2d per network width.

    forward() dispatches to the BN instance matching the currently selected
    width multiplier (`self.width_mult`), so every sub-network maintains its
    own running statistics.
    """

    def __init__(self, num_features_list):
        super(SwitchableBatchNorm2d, self).__init__()
        self.num_features_list = num_features_list
        self.num_features = max(num_features_list)
        self.bn = nn.ModuleList(nn.BatchNorm2d(n) for n in num_features_list)
        # Active width multiplier; the widest configuration by default.
        self.width_mult = max(FLAGS.width_mult_list)
        self.ignore_model_profiling = True
        self.scroll = 0.0

    def forward(self, input):
        # Route the input through the BN belonging to the active width.
        active = FLAGS.width_mult_list.index(self.width_mult)
        return self.bn[active](input)
23 |
24 |
class SlimmableConv2d(nn.Conv2d):
    """Conv2d whose active in/out channel counts can be switched at runtime.

    The layer allocates weights for the largest width and slices out the
    sub-tensor for the currently selected width multiplier. Before slicing,
    the weight (and bias) can be cyclically "scrolled" along the channel
    dimensions so that successive tasks reuse different weight regions
    (ScrollNet's dynamic weight importance).
    """

    def __init__(self, in_channels_list, out_channels_list,
                 kernel_size, stride=1, padding=0, dilation=1,
                 groups_list=[1], bias=True):
        # NOTE: the mutable default `groups_list=[1]` is only compared and
        # rebound, never mutated, so it is safe here.
        super(SlimmableConv2d, self).__init__(
            max(in_channels_list), max(out_channels_list),
            kernel_size, stride=stride, padding=padding, dilation=dilation,
            groups=max(groups_list), bias=bias)
        self.in_channels_list = in_channels_list
        self.out_channels_list = out_channels_list
        self.groups_list = groups_list
        if self.groups_list == [1]:
            self.groups_list = [1 for _ in range(len(in_channels_list))]
        self.width_mult = max(FLAGS.width_mult_list)  # active width multiplier
        self.scroll = 0.0   # scroll step; wrapped modulo the number of widths in forward()
        self.inverse = 1    # scroll direction: 0 is up, 1 is down

    def cyc_scroll(self, scroll_num1, scroll_num2):
        """Scroll the weight "up": cyclic shift by `scroll_num1` along output
        channels (dim 0) and `scroll_num2` along input channels (dim 1)."""
        h, w, _, _ = self.weight.shape
        matrix = torch.cat((self.weight[(h - scroll_num1):, :],
                            self.weight[:(h - scroll_num1), :]), dim=0)
        # FIX: the split point along dim 1 must use the input-channel count
        # `w`; the original used `h`, which is only correct for layers with
        # as many input as output channels.
        weight = torch.cat((matrix[:, (w - scroll_num2):],
                            matrix[:, :(w - scroll_num2)]), dim=1)
        return weight

    def cyc_scroll_bias(self, scroll_num):
        """Scroll the bias "up" by `scroll_num` positions."""
        # FIX: the original did `L = self.bias` (a tensor) and then used `L`
        # inside a slice bound, which raises a TypeError; use the channel
        # count instead.
        length = self.bias.shape[0]
        return torch.cat((self.bias[(length - scroll_num):],
                          self.bias[:(length - scroll_num)]), dim=0)

    def cyc_scroll_inverse(self, scroll_num1, scroll_num2):
        """Scroll the weight "down": cyclic shift by `scroll_num1` output
        channels and `scroll_num2` input channels."""
        matrix = torch.cat((self.weight[scroll_num1:, :],
                            self.weight[:scroll_num1, :]), dim=0)
        weight = torch.cat((matrix[:, scroll_num2:],
                            matrix[:, :scroll_num2]), dim=1)
        return weight

    def cyc_scroll_bias_inverse(self, scroll_num):
        """Scroll the bias "down" by `scroll_num` positions."""
        return torch.cat((self.bias[scroll_num:], self.bias[:scroll_num]), dim=0)

    def forward(self, input):
        # Wrap the scroll step so it cycles over the number of widths.
        self.scroll = self.scroll % len(FLAGS.width_mult_list)
        idx = FLAGS.width_mult_list.index(self.width_mult)
        self.in_channels = self.in_channels_list[idx]
        self.out_channels = self.out_channels_list[idx]
        self.groups = self.groups_list[idx]
        # Scroll amount = step * width of one sub-network slice.
        # NOTE(review): assumes equally sized splits (list[1] - list[0]);
        # confirm if unequal splits are ever configured.
        scroll_num1 = int(self.scroll * (self.out_channels_list[1] - self.out_channels_list[0]))
        scroll_num2 = int(self.scroll * (self.in_channels_list[1] - self.in_channels_list[0]))

        if self.inverse == 0:
            weight = self.cyc_scroll(scroll_num1, scroll_num2)
        elif self.inverse == 1:
            weight = self.cyc_scroll_inverse(scroll_num1, scroll_num2)
        else:
            # Previously an invalid value left `weight` unbound (NameError).
            raise ValueError('inverse must be 0 (up) or 1 (down), got {}'.format(self.inverse))
        weight = weight[:self.out_channels, :self.in_channels, :, :]

        if self.bias is not None:
            if self.inverse == 0:
                bias = self.cyc_scroll_bias(scroll_num1)
            else:
                bias = self.cyc_scroll_bias_inverse(scroll_num1)
            bias = bias[:self.out_channels]
        else:
            bias = self.bias

        return nn.functional.conv2d(
            input, weight, bias, self.stride, self.padding,
            self.dilation, self.groups)
95 |
96 |
class SlimmableLinear(nn.Linear):
    """Linear layer that can switch its active in/out feature counts.

    Stores weights for the widest configuration; at forward time the weight
    matrix is cyclically scrolled and then cropped to the currently active
    width.
    """

    def __init__(self, in_features_list, out_features_list, bias=True):
        super(SlimmableLinear, self).__init__(
            max(in_features_list), max(out_features_list), bias=bias)
        self.in_features_list = in_features_list
        self.out_features_list = out_features_list
        self.width_mult = max(FLAGS.width_mult_list)  # active width multiplier
        self.scroll = 0.0

    def cyc_scroll_inverse(self, scroll_num1, scroll_num2):
        """Cyclic shift: `scroll_num1` rows (out features) and `scroll_num2`
        columns (in features)."""
        rolled_rows = torch.cat(
            (self.weight[scroll_num1:, :], self.weight[:scroll_num1, :]), dim=0)
        return torch.cat(
            (rolled_rows[:, scroll_num2:], rolled_rows[:, :scroll_num2]), dim=1)

    def cyc_scroll_bias_inverse(self, scroll_num):
        """Cyclic shift of the bias by `scroll_num` positions."""
        return torch.cat((self.bias[scroll_num:], self.bias[:scroll_num]), dim=0)

    def forward(self, input):
        # Keep the scroll step within one full cycle over the widths.
        self.scroll = self.scroll % len(FLAGS.width_mult_list)

        idx = FLAGS.width_mult_list.index(self.width_mult)
        self.in_features = self.in_features_list[idx]
        self.out_features = self.out_features_list[idx]
        # Scroll amount = step * size of one sub-network slice.
        step_out = self.out_features_list[1] - self.out_features_list[0]
        step_in = self.in_features_list[1] - self.in_features_list[0]
        scroll_num1 = int(self.scroll * step_out)
        scroll_num2 = int(self.scroll * step_in)

        weight = self.cyc_scroll_inverse(scroll_num1, scroll_num2)
        weight = weight[:self.out_features, :self.in_features]
        if self.bias is not None:
            bias = self.cyc_scroll_bias_inverse(scroll_num1)[:self.out_features]
        else:
            bias = self.bias
        return nn.functional.linear(input, weight, bias)
133 |
def make_divisible(v, divisor=8, min_value=1):
    """Round `v` to the nearest multiple of `divisor`, never below `min_value`.

    forked from slim:
    https://github.com/tensorflow/models/blob/\
    0344c5503ee55e24f0de7f37336a6e08f10976fd/\
    research/slim/nets/mobilenet/mobilenet.py#L62-L69
    """
    if min_value is None:
        min_value = divisor
    rounded = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    return rounded + divisor if rounded < 0.9 * v else rounded
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 |
6 | cudnn_deterministic = True
7 |
8 |
def seed_everything(seed=0):
    """Seed every RNG in use (hash, python, numpy, torch CPU and CUDA) so
    runs are reproducible."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = cudnn_deterministic
17 |
18 |
def print_summary(acc_taw, acc_tag, forg_taw, forg_tag):
    """Pretty-print the TAw/TAg accuracy and forgetting matrices, with a
    per-row average appended to each line."""
    named_metrics = [('TAw Acc', acc_taw), ('TAg Acc', acc_tag),
                     ('TAw Forg', forg_taw), ('TAg Forg', forg_tag)]
    for name, metric in named_metrics:
        print('*' * 108)
        print(name)
        # Zero trace marks a forgetting matrix: its first row has no
        # previous tasks, so the average is skipped there and otherwise
        # taken over the previous tasks only.
        zero_trace = np.trace(metric) == 0.0
        for i in range(metric.shape[0]):
            print('\t', end='')
            for j in range(metric.shape[1]):
                print('{:5.1f}% '.format(100 * metric[i, j]), end='')
            if zero_trace:
                if i > 0:
                    print('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i].mean()), end='')
            else:
                print('\tAvg.:{:5.1f}% '.format(100 * metric[i, :i + 1].mean()), end='')
            print()
        print('*' * 108)
35 |
--------------------------------------------------------------------------------
/src/widths/config.py:
--------------------------------------------------------------------------------
1 | """config utilities for yml file."""
2 | import os
3 | import sys
4 | import yaml
5 |
# Singleton: the process-wide Config instance, populated by app() at import time.
FLAGS = None
8 |
9 |
class LoaderMeta(type):
    """Metaclass that supports `!include`: every class created with it gets
    the `!include` constructor registered automatically."""

    def __new__(mcs, name, bases, namespace):
        """Create the class, then hook its `construct_include` method up as
        the handler for the `!include` YAML tag."""
        cls = super().__new__(mcs, name, bases, namespace)
        cls.add_constructor('!include', cls.construct_include)
        return cls
19 |
20 |
class Loader(yaml.Loader, metaclass=LoaderMeta):
    """YAML Loader with an `!include` constructor for pulling in other files."""

    def __init__(self, stream):
        # Resolve `!include` paths relative to the including file when the
        # stream carries a filename; otherwise fall back to the current dir.
        try:
            self._root = os.path.split(stream.name)[0]
        except AttributeError:
            self._root = os.path.curdir
        super().__init__(stream)

    def construct_include(self, node):
        """Load and return the file referenced by an `!include` node.

        YAML files are parsed recursively with this Loader; any other file
        type is returned as raw text.
        """
        filename = os.path.abspath(
            os.path.join(self._root, self.construct_scalar(node)))
        extension = os.path.splitext(filename)[1].lstrip('.')
        with open(filename, 'r') as f:
            if extension in ('yaml', 'yml'):
                return yaml.load(f, Loader)
            else:
                return f.read()
41 |
class AttrDict(dict):
    """Dict whose keys are also accessible as attributes (`d.key`).

    Nested dicts — including dicts inside list values — are converted to
    AttrDict recursively on construction.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        for key in self.__dict__:
            value = self.__dict__[key]
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, list):
                # Guard against empty lists: the original `value[0]` check
                # raised IndexError for them.
                if value and isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]
                else:
                    self.__dict__[key] = value

    def yaml(self):
        """Return a plain-dict representation suitable for yaml.dump."""
        yaml_dict = {}
        for key in self.__dict__:
            value = self.__dict__[key]
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif isinstance(value, list) and value and isinstance(value[0], AttrDict):
                # List of nested configs: convert each element.
                yaml_dict[key] = [item.yaml() for item in value]
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Render all entries one per line, indenting nested children."""
        lines = []
        for key in self.__dict__:
            value = self.__dict__[key]
            if isinstance(value, AttrDict):
                lines.append('{}:'.format(key))
                lines.extend(' ' + line for line in repr(value).split('\n'))
            elif isinstance(value, list) and value and isinstance(value[0], AttrDict):
                lines.append('{}:'.format(key))
                # Render each nested config like the AttrDict branch above
                # (the original shadowed its outer loop variable here).
                for item in value:
                    lines.extend(' ' + line for line in repr(item).split('\n'))
            else:
                lines.append('{}: {}'.format(key, value))
        return '\n'.join(lines)
105 |
106 |
class Config(AttrDict):
    """Config with yaml file.

    This class is used to config model hyper-parameters, global constants, and
    other settings with yaml file. All settings in yaml file will be
    automatically logged into file.

    Args:
        filename(str): File name.
        verbose(bool): If True, print the parsed config on construction.

    Examples:

        yaml file ``model.yml``::

            NAME: 'neuralgym'
            ALPHA: 1.0
            DATASET: '/mnt/data/imagenet'

        Usage in .py:

            >>> from neuralgym import Config
            >>> config = Config('model.yml')
            >>> print(config.NAME)
            neuralgym
            >>> print(config.ALPHA)
            1.0
            >>> print(config.DATASET)
            /mnt/data/imagenet

    """

    def __init__(self, filename=None, verbose=False):
        assert os.path.exists(filename), 'File {} not exist.'.format(filename)
        try:
            with open(filename, 'r') as f:
                cfg_dict = yaml.load(f, Loader)
        except EnvironmentError:
            # The original only printed here (with a malformed %-style
            # message) and then crashed with NameError on the undefined
            # cfg_dict; re-raise so the real I/O error reaches the caller.
            print('Please check the file with name of "{}"'.format(filename))
            raise
        super(Config, self).__init__(cfg_dict)
        if verbose:
            print(' pi.cfg '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))
150 |
def app():
    """Return the global FLAGS Config, loading `SizeOfSubnetworks.yml` the
    first time it is called."""
    global FLAGS
    if FLAGS is None:
        FLAGS = Config('SizeOfSubnetworks.yml')
    return FLAGS

# Build FLAGS eagerly at import time so `from widths.config import FLAGS`
# hands out a populated config.
app()
--------------------------------------------------------------------------------