├── LICENSE
├── README.md
├── README.pdf
├── imgs
│   ├── CIL-diagram.png
│   ├── catastrophicforgetting.png
│   ├── cifar.png
│   ├── classifier.png
│   ├── convnet.png
│   ├── cub.png
│   ├── decouple.png
│   ├── herding.png
│   ├── icarl.png
│   ├── imagenet.png
│   ├── kd.png
│   ├── logo.png
│   ├── maze.jpg
│   ├── norm.png
│   ├── outputprobabilities.png
│   ├── replay.png
│   └── vgg16.png
├── requirements.txt
├── resnet.py
├── template.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Fu-Yun Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is a **[PyTorch](https://pytorch.org) Tutorial to Class-Incremental Learning**.
2 |
3 | Basic knowledge of PyTorch and convolutional neural networks is assumed.
4 |
5 | If you're new to PyTorch, first read [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html) and [Learning PyTorch with Examples](https://pytorch.org/tutorials/beginner/pytorch_with_examples.html).
6 |
7 | Questions, suggestions, or corrections can be posted as issues.
8 |
9 | I'm using `PyTorch 1.11.0+cu113` in `Python 3.9`.
10 |
11 | Note: We recommend installing [mathjax-plugin-for-github](https://chrome.google.com/webstore/search/mathjax) to read the math formulas below, or cloning this repository to read it locally. A PDF version is also available: [README.pdf](README.pdf).
12 |
13 | **key words:** `Class-Incremental Learning`, `PyTorch Distributed Training`
14 |
15 | ---
16 |
17 | # Contents
18 |
19 | [***Objective***](#Objective)
20 |
21 | [***Toolbox***](#Toolbox)
22 |
23 | [***Concepts***](#Concepts)
24 |
25 | [***Overview***](#Overview)
26 |
27 | [***Implementation***](#Implementation)
28 |
29 | [***Contact***](#Contact)
30 |
31 | [***Acknowledgments***](#Acknowledgments)
32 |
33 | # Objective
34 |
35 | **To build a model that can learn novel classes while maintaining discrimination ability for old categories.**
36 |
37 |
38 |
39 | We will be implementing [Maintaining Discrimination and Fairness in Class Incremental Learning (WA)](https://arxiv.org/abs/1911.07053), a strong fundamental baseline among class-incremental learning methods.
40 |
41 | Our implementation is very efficient and straightforward to understand. Utilizing relevant open source tools such as [torch.distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), [timm](https://github.com/rwightman/pytorch-image-models), [continuum](https://github.com/Continvvm/continuum), etc., we keep the core code within 100 lines. We believe that both beginners and researchers in related fields can gain some inspiration from this tutorial.
42 |
43 | We also provide an introduction to CIL (in Chinese), available [here](https://zhuanlan.zhihu.com/p/490308909).
44 |
45 | Here are some sample images from traditional benchmarks: ImageNet, CUB, and CIFAR.
63 |
64 | # Toolbox
65 |
66 | Now there are many excellent implementations of incremental learning methods.
67 |
68 | [PyCIL: A Python Toolbox for Class-Incremental Learning](https://github.com/G-U-N/PyCIL)
69 |
70 |
71 |

72 |
73 | PyCIL mainly focuses on Class-incremental learning. It contains implementations of a number of founding works of CIL, such as EWC and iCaRL. It also provides current state-of-the-art algorithms that can be used for conducting novel fundamental research.
74 |
75 | We **strongly recommend** you refer to methods reproduced in [PyCIL](https://github.com/G-U-N/PyCIL) as the basis of your Class-Incremental Learning research.
76 |
77 | [FACIL](https://github.com/mmasana/FACIL)
78 |
79 | Framework for Analysis of Class-Incremental Learning with 12 state-of-the-art methods and three baselines.
80 |
81 | [Avalanche](https://github.com/ContinualAI/avalanche)
82 |
83 | Avalanche contains various paradigms of incremental learning and is one of the earliest incremental learning toolkits.
84 |
85 | # Concepts
86 |
87 | * **Online Machine Learning**. Online machine learning is a machine learning method in which data becomes available in a sequential order and is used to update the best predictor for future data at each step, as opposed to batch learning techniques which generate the best predictor by learning on the entire training data set at once.
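88 | * **Class-Incremental Learning (CIL)**. A learning setting in which the model learns new classes phase by phase, usually without access to most of the data of previously learned classes, and is evaluated on all classes seen so far.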
89 | * **Catastrophic Interference/Forgetting**. Catastrophic interference, also known as catastrophic forgetting, is the tendency of an [artificial neural network](https://en.wikipedia.org/wiki/Artificial_neural_network) to completely and abruptly forget previously learned information upon learning new information.
90 | * **Exemplars/Replay Buffer**. `Replay Buffer` is commonly used in Deep Reinforcement Learning (DRL). DRL algorithms, especially off-policy algorithms, use *replay buffers* to store trajectories of experience when executing a policy in an environment. In the CIL setting, we use a size-limited replay buffer to store a few representative instances of old classes for future training.
91 |
92 |
93 | - **Rehearsal**. Rehearsal is the process of training with exemplars.
94 | - **Knowledge Distillation**. Knowledge distillation is the process of transferring knowledge from a teacher [model](https://en.wikipedia.org/wiki/Statistical_model) to a student model. It was proposed in [Distilling the knowledge in a neural network](https://arxiv.org/abs/1503.02531). The original form of knowledge distillation minimizes the KL divergence between the output probability distributions of the teacher and student models.
95 | - **Calibration**. After training on imbalanced datasets, models typically have a strong tendency to misclassify minority classes into majority classes. This is unacceptable when minority classes are more important, e.g., Cancer Detection. Calibration aims to achieve a balance between minority classes and majority ones.
96 |
97 | If you are still confused about the above concepts, don't worry, we will refer to them again and again later, and you may gradually understand them as you read.
98 |
99 | # Overview
100 |
101 | In this section, I will present an overview of Class-Incremental Learning and the above-mentioned methods. If you're already familiar with it, you can skip straight to the [Implementation](#Implementation) section or the commented code.
102 |
103 | ## Class-Incremental Learning
104 |
105 |
106 | CIL-Diagram
107 |

108 |
109 |
110 | The above figure illustrates the process of CIL. In the first stage, the model learns classes 1, 2 and tests its performance on classes 1, 2. In the second stage, the model continues to learn categories 3 and 4, but the data of categories 1 and 2 is not available. After training, the accuracy of the model on categories 1, 2, 3, and 4 is calculated. In the third stage, the model continues to learn categories 5 and 6, but the data of categories 1, 2, 3 and 4 is not available. After training, the accuracy of the model on categories 1, 2, 3, 4, 5, and 6 is calculated.
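
The phase structure described above can be sketched in a few lines of plain Python. This is only an illustration: the helper names `class_schedule` and `evaluation_classes` are hypothetical and not part of this repository, and class ids are 0-indexed here, so classes 1 and 2 in the figure correspond to ids 0 and 1.

```python
def class_schedule(num_classes, initial_increment, increment):
    """Split class ids 0..num_classes-1 into consecutive training phases."""
    phases = [list(range(initial_increment))]
    start = initial_increment
    while start < num_classes:
        end = min(start + increment, num_classes)
        phases.append(list(range(start, end)))
        start = end
    return phases


def evaluation_classes(phases, phase_idx):
    """After phase `phase_idx`, the model is tested on every class seen so far."""
    return [c for phase in phases[: phase_idx + 1] for c in phase]


# Six classes, learned two at a time, as in the figure.
phases = class_schedule(num_classes=6, initial_increment=2, increment=2)
for t, new_classes in enumerate(phases):
    print(f"phase {t}: train on {new_classes}, test on {evaluation_classes(phases, t)}")
```

Note that the training set of each phase contains only the new classes, while the test set keeps growing.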
111 |
112 | ## Catastrophic Forgetting
113 |
114 | However, if we simply fine-tune the model on the new training data, we may observe the following phenomenon. Although the model learned to recognize dogs in the first task, after the second task it seems to have completely forgotten them and misclassifies a dog as a fish. So although deep learning models trained with vanilla SGD match or even exceed human recognition ability in a closed world, they lose their competitiveness in a dynamically changing world.
115 |
116 |
117 |

118 |
119 |
120 | The above phenomenon can be attributed to the following two main reasons:
121 |
122 | **1. Lack of supervision in old classes**. Due to the loss of access to the data of the old class, when the model performs SGD optimization on new classes, it is likely to violate the optimization direction of old ones, thus destroying the feature representation of old classes.
123 |
124 | **2. Imbalance of old and new classes**. The above setting can be seen as an extremely imbalanced training process in which the number of training samples from old classes is zero. The imbalance causes a strong bias toward the majority (new) classes, so the model almost never predicts an input image as a minority (old) class.
125 |
126 | ## Replay Buffer
127 |
128 | Replay Buffer is naive but effective. Since the model lacks supervision for old classes, why not just store some instances for future training? As shown in the following figure, we store some representative instances (exemplars) of the old classes and train on them again in Stage 2.
129 |
130 |
131 |

132 |
133 |
134 | The next question is: how do we choose exemplars? Given a limited budget of $m$ exemplars, how do we choose the $m$ instances from all the training data that best represent the class?
135 |
136 | Take a Gaussian distribution as an example: what we care about most is its mean, i.e., the center of the class cluster. We therefore want the center of the $m$ exemplars to deviate as little as possible from the center of all samples. Finding the optimal subset, however, requires evaluating ${n \choose m}$ candidate subsets, which is intractable when $n$ and $m$ are large. Instead, we use a greedy algorithm: we **greedily** add exemplars to the replay buffer so that, at every step, the center of the selected exemplars stays as close as possible to the center of the $n$ instances. That is, we first choose the instance $\mathbf x_1$ closest to the center, then the instance $\mathbf x_2$ that makes the mean of $\mathbf x_1$ and $\mathbf x_2$ closest to the center, and so on. The number of candidate evaluations is thus reduced to $\sum_{k=1}^{m}(n-k+1)=\frac{(2n-m+1)m}{2}=O(nm)$.
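
The greedy selection described above can be sketched as follows. This is a minimal, dependency-free illustration rather than the code used in this repository: the function name `herding_select` is hypothetical, each instance is assumed to be represented by a feature vector (a plain list of floats), and `m` is assumed not to exceed the number of instances.

```python
def herding_select(features, m):
    """Greedily pick m indices whose running mean stays closest to the class mean."""
    n, d = len(features), len(features[0])
    center = [sum(f[j] for f in features) / n for j in range(d)]
    selected, running_sum = [], [0.0] * d
    for k in range(1, m + 1):
        best_idx, best_dist = None, float("inf")
        for i, f in enumerate(features):
            if i in selected:
                continue
            # Squared distance between the class mean and the mean of the
            # k exemplars we would have after adding candidate i.
            dist = sum(((running_sum[j] + f[j]) / k - center[j]) ** 2 for j in range(d))
            if dist < best_dist:
                best_idx, best_dist = i, dist
        selected.append(best_idx)
        running_sum = [running_sum[j] + features[best_idx][j] for j in range(d)]
    return selected


# Toy class with one outlier; the class mean is (3.25, 3.25).
features = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [10.0, 10.0]]
print(herding_select(features, m=2))  # prints [2, 1]
```

Step $k$ scans the $n-k+1$ remaining candidates, which matches the count in the complexity formula above.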
137 |
138 |
139 |

140 |
141 |
142 | ## Knowledge Distillation
143 |
144 | Although we have selected some exemplars to increase the supervision information for the old class, when the number of optional exemplars, i.e., $m$, is very small, we still can only get very little supervision information for the old classes. Hence, we need additional supervision information to help the model better maintain the representational ability of the old classes.
145 |
146 | **Knowledge Distillation** is a good idea. Here is a real-life analogy: in East Asia, many parents are willing to spend a great deal on education to hire excellent teachers for their children, because they believe that with a good teacher, children can get good grades with less effort. The same holds for deep neural networks: a student model can achieve better performance by using the soft supervision provided by a teacher model.
147 |
148 | $$
149 | \mathcal{L}_{KD}(\mathbf{x})=\sum_{c=1}^{C}-\hat{q}_{c}(\mathbf{x}) \log \left(q_{c}(\mathbf{x})\right),
150 | $$
151 |
152 | where $\hat{q}_{c}(\mathbf{x})=\frac{e^{\hat{o}_{c}(\mathbf{x}) / T}}{\sum_{j=1}^{C} e^{\hat{o}_{j}(\mathbf{x}) / T}}$ is the softmax of the teacher model's logits divided by the temperature $T$, and $q_c(\mathbf x)$ is the same quantity for the student model. Beginners are sometimes confused about why this cross-entropy is also referred to as a KL divergence. The derivation below explains it: the two objectives differ only by the entropy of the teacher's output, which does not depend on the student parameters $\theta$.
153 |
154 | $$
155 | \begin{align}
156 | \min_{\theta} &\quad\mathrm{KL}\left(\hat{q}(\mathbf{x}\mid \hat\theta)\,\|\, q(\mathbf{x}\mid \theta)\right)\\
157 | =\min_{\theta}&\quad \sum_{c=1}^{C}\left\{\hat{q}_{c}(\mathbf{x}\mid \hat{\theta}) \log \left(\hat{q}_{c}(\mathbf{x} \mid \hat{\theta})\right)-\hat{q}_{c}(\mathbf{x}\mid \hat{\theta}) \log \left(q_{c}(\mathbf{x}\mid \theta)\right)\right\}\\
158 | =\min_\theta &\quad \sum_{c=1}^{C}-\hat{q}_{c}(\mathbf{x}\mid \hat{\theta}) \log \left(q_{c}(\mathbf{x}\mid \theta)\right).
159 | \end{align}
160 | $$
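
The relation between the distillation loss and the KL divergence can also be checked numerically. The sketch below uses only the standard library. The function names are hypothetical, and the temperature `T=2.0` and the example logits are arbitrary assumptions chosen for illustration.

```python
import math


def softmax(logits, T=1.0):
    """Temperature-scaled softmax; subtracting the max keeps exp() stable."""
    scaled = [o / T for o in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(teacher_logits, student_logits, T=2.0):
    """Cross-entropy between the teacher's and the student's soft outputs."""
    q_hat = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    return -sum(t * math.log(s) for t, s in zip(q_hat, q))


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with positive entries."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


teacher, student = [2.0, 1.0, 0.1], [1.5, 1.2, 0.3]
q_hat, q = softmax(teacher, 2.0), softmax(student, 2.0)
teacher_entropy = -sum(t * math.log(t) for t in q_hat)

# Cross-entropy = KL + teacher entropy; the entropy is constant w.r.t. theta.
gap = kd_loss(teacher, student) - kl_divergence(q_hat, q)
print(abs(gap - teacher_entropy) < 1e-9)
```

Because the gap between the two quantities is independent of the student, minimizing either one over $\theta$ gives the same solution.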
161 |
162 | So, where do we find the teacher?
163 |
164 | The model itself in the previous phase is a good teacher! We restore the model in the previous phase to provide more supervision of old classes.
165 |
166 |
167 |

168 |
169 |
170 | Therefore, the overall loss combines $\mathcal L_{CE}$ and $\mathcal L_{KD}$:
171 | $$
172 | \mathcal{L}(\mathbf{x}, y)=(1-\lambda) \mathcal{L}_{C E}(\mathbf{x}, y)+\lambda \mathcal{L}_{K D}(\mathbf{x}),
173 | $$
174 |
175 | where $\lambda$ is set to a default value or dynamically set to $\frac{n}{n+m}$, with $n$ and $m$ denoting the number of old and new classes, respectively.
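
To get a feeling for the dynamic weighting, the sketch below computes $\lambda$ for a CIFAR-100 run with 50 base classes and 10 new classes per phase (the setting mentioned in the Training section). The helper names `kd_weight` and `total_loss` are hypothetical, and the loss values passed in are made-up numbers.

```python
def kd_weight(num_old, num_new):
    """lambda = n / (n + m): the more old classes, the more weight on distillation."""
    return num_old / (num_old + num_new)


def total_loss(ce, kd, lam):
    """(1 - lambda) * L_CE + lambda * L_KD."""
    return (1.0 - lam) * ce + lam * kd


for phase in range(1, 6):
    num_old = 50 + 10 * (phase - 1)
    lam = kd_weight(num_old, num_new=10)
    print(f"phase {phase}: n={num_old}, m=10, lambda={lam:.3f}")

# Made-up loss values, only to show how the two terms are mixed.
print(total_loss(ce=1.2, kd=0.6, lam=kd_weight(50, 10)))
```

As more classes accumulate, $\lambda$ approaches 1, so preserving old knowledge dominates the objective.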
177 |
178 | ## Weight Alignment
179 |
180 | To explain how weight alignment is achieved and why it works, we decouple the classification model into a feature extractor $\phi(\mathbf x)$ and a linear classifier $W$ (ignoring the bias $\boldsymbol b$).
181 |
182 | We take VGG16 as an example.
183 |
184 |
185 |

186 |
187 |
188 | Although deep neural networks can be designed as various architectures, most of them can be decoupled as follows. The whole network excluding the final FC layer can be viewed as a feature extractor. It transforms the input image into a hidden feature space where most classes are linearly separable.
189 |
190 |
191 |

192 |
193 |
194 | The classifier $W$ transforms these features into logits, namely
195 | $$
196 | \mathbf o= W \phi(\mathbf x),
197 | $$
198 | where $\mathbf o$ is the logits (output of the model).
199 |
200 | Ignoring the bias, the classifier $W$ is essentially a matrix that represents a linear transform from feature space $\mathbb R^{d}$ to logits space $\mathbb R^{n+m}$, where $d$ is the dimension of the feature space; $n$ is the number of old categories; and $m$ is the number of new categories. The matrix can be further decomposed into a list of vectors $w_1,w_2,\dots,w_n,w_{n+1},\dots,w_{n+m}$. We call them prototypes.
201 |
202 |
203 |

204 |
205 | Thus the output logits $\mathbf o$
206 | $$
207 | \mathbf o = W\phi(\mathbf x) =\left(\begin{array}{c} W_{old}^T\phi(\mathbf x) \\ W_{new}^T \phi(\mathbf x)\end{array}\right) = \left(\begin{array}{c} w_1^T \phi(\mathbf x) \\ w_2^T \phi(\mathbf x)\\\vdots \\ w_n^T \phi(\mathbf x) \\ w_{n+1}^T \phi(\mathbf x) \\ \vdots \\w_{n+m}^T \phi(\mathbf x) \end{array}\right)
208 | $$
209 | From the above decomposition, it is apparent that the magnitude of the logit of the $i$-th class is proportional to the norm of the vector $w_i$ (for a fixed angle between $w_i$ and $\phi(\mathbf x)$). Most modern architectures use ReLU ($\mathrm{ReLU}(x)=\max(0, x)$) as the non-linear activation, so their hidden features tend to be non-negative. As a result, **it is typical to find that classes with larger norm values of $w$ tend to have larger logits.**
210 |
211 |
212 |
213 | Extensive experiments show that the norms of new class prototypes are usually much larger than those of old ones. This difference in norms between old and new prototypes causes a strong classification bias toward new classes, which harms both the calibration across classes and the overall performance.
214 |
215 |
216 |

217 |
218 |
219 | Based on the above analysis, a simple but effective calibration method is to make the new and old prototypes have the same average norm.
220 |
221 | We first calculate the ratio factor of the norms of the old and new classes
222 | $$
223 | \gamma=\frac{\operatorname{Mean}\left(\text { Norm }_{\text {old }}\right)}{\operatorname{Mean}\left(\text { Norm }_{n e w}\right)},
224 | $$
225 | and then we multiply the ratio factor with the new prototypes
226 | $$
227 | \hat{W}_{new}=\gamma W_{new}.
228 | $$
229 | After that, the corrected logits are
230 | $$
231 | \mathbf o_{correct} =\left(\begin{array}{c} W_{old}^T\phi(\mathbf x) \\ \gamma W_{new}^T \phi(\mathbf x)\end{array}\right) = \left(\begin{array}{c} w_1^T \phi(\mathbf x) \\ w_2^T \phi(\mathbf x)\\\vdots \\ w_n^T \phi(\mathbf x) \\ \gamma w_{n+1}^T \phi(\mathbf x) \\ \vdots \\\gamma w_{n+m}^T \phi(\mathbf x) \end{array}\right).
232 | $$
233 | This completes the training phase.
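
The rescaling step can be sketched on a toy classifier as follows. This is a dependency-free illustration, not the code in `template.py`: the function name `weight_align` is hypothetical, and prototypes are stored as plain lists, one row per class, with the old classes first.

```python
import math


def l2_norm(vector):
    return math.sqrt(sum(v * v for v in vector))


def weight_align(prototypes, num_old):
    """Rescale new-class prototypes so both groups share the same mean norm."""
    old, new = prototypes[:num_old], prototypes[num_old:]
    mean_old = sum(l2_norm(w) for w in old) / len(old)
    mean_new = sum(l2_norm(w) for w in new) / len(new)
    gamma = mean_old / mean_new
    aligned_new = [[gamma * v for v in w] for w in new]
    return old + aligned_new, gamma


# Two old prototypes with norm 5 and two new prototypes with norm 10.
W = [[3.0, 4.0], [0.0, 5.0], [6.0, 8.0], [0.0, 10.0]]
aligned, gamma = weight_align(W, num_old=2)
print(gamma)        # prints 0.5
print(aligned[2:])  # the new prototypes now have norm 5 as well
```

Only the new prototypes are rescaled. The old prototypes and the feature extractor are left unchanged.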
234 |
235 | # Implementation
236 |
237 | The sections below briefly describe the implementation.
238 |
239 | They are meant to provide some context, but **details are best understood directly from the code**, which is quite heavily commented.
240 |
241 | ### Dataset
242 |
243 | We will use CIFAR100 and ImageNet-100/1000, which are common benchmarks in Class-Incremental Learning.
244 |
245 | #### Download
246 |
247 | For CIFAR-100, the python scripts will automatically download it. We recommend you test the code with CIFAR-100 since it takes much less computation overhead compared to ImageNet.
248 |
249 | For ImageNet-100/1000, you should download the dataset from [Image-net.org](https://image-net.org/) and specify the [DATA_PATH] in the code.
250 |
251 | ### Data pipeline
252 |
253 | #### Continuum Class-Incremental Scenario
254 |
255 | See `build_dataset` in [`utils.py`](utils.py).
256 |
257 | ```python
258 | from continuum import ClassIncremental
259 | scenario = ClassIncremental(
260 |     dataset,
261 |     initial_increment=args.num_bases,
262 |     increment=args.increment,
263 |     transformations=transform.transforms,
264 |     class_order=args.class_order
265 | )
266 | ```
267 |
268 | With the `ClassIncremental` class provided by continuum, we can easily generate the required dataset at any stage. It takes five arguments:
269 |
270 | - **dataset**. A PyTorch-style dataset (CIFAR100) or an ImageFolder-style dataset (ImageNet).
271 | - **initial_increment**. The number of classes in the $0$-th phase.
272 | - **increment**. The number of classes in each of the following phases.
273 | - **transformations**. The list of transforms applied to each image (here, the ones built in the Data Transforms section below).
274 | - **class_order**. The order of all the classes. For CIFAR100, it is a permutation of the class ids 0, 1, 2,..., 99.
275 |
276 |
277 | #### Data Transforms
278 |
279 | See `build_transform()` in [`utils.py`](utils.py).
280 |
281 | ```python
282 | from timm.data import create_transform
283 | transform = create_transform(
284 |     input_size=args.input_size,
285 |     is_training=True,
286 |     color_jitter=args.color_jitter,
287 |     auto_augment=args.aa,
288 |     interpolation='bicubic',
289 |     re_prob=args.reprob,
290 |     re_mode=args.remode,
291 |     re_count=args.recount,
292 | )
293 | ```
294 |
295 | `create_transform` in timm provides a strong baseline augmentation method.
296 |
297 | #### PyTorch Sampler and DataLoader
298 |
299 | Using the distributed training framework that PyTorch has officially implemented, we only need to pass in a DistributedSampler for the DataLoader. The Sampler automatically allocates data resources to each process, speeding up training.
300 |
301 | ```python
302 | from torch.utils.data import DataLoader
303 | from torch.utils.data.distributed import DistributedSampler
304 |
305 | # Each process only iterates over its own shard of the data.
306 | train_sampler = DistributedSampler(
307 |     dataset_train, num_replicas=args.world_size, rank=args.rank, shuffle=True)
308 | val_sampler = DistributedSampler(
309 |     dataset_val, num_replicas=args.world_size, rank=args.rank, shuffle=False)
310 | train_loader = DataLoader(
311 |     dataset_train, batch_size=args.batch_size, sampler=train_sampler, num_workers=10, pin_memory=True)
312 | val_loader = DataLoader(
313 |     dataset_val, batch_size=args.batch_size, sampler=val_sampler, num_workers=10)
314 |
315 | for epoch in range(args.num_epochs):
316 |     train_sampler.set_epoch(epoch)  # reshuffle differently in every epoch
317 | ```
318 |
319 | For each epoch, we pass in a new epoch for the sampler to shuffle the order in which the data appears.
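
To make the role of the sampler and of `set_epoch` more concrete, here is a dependency-free sketch of the sharding idea. It imitates the behaviour of `DistributedSampler` with its default `drop_last=False`, but it is not PyTorch's code: the function `shard_indices` is hypothetical, and Python's `random` module stands in for the `torch.Generator` that PyTorch uses.

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, epoch=0, shuffle=True, seed=0):
    """Return the sample indices that process `rank` would iterate over."""
    indices = list(range(dataset_len))
    if shuffle:
        # Seeding with seed + epoch gives every process the same permutation,
        # and a new permutation whenever the epoch changes.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating leading indices so that every rank gets the same count.
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    indices += indices[: total_size - dataset_len]
    # Rank r takes every num_replicas-th index, starting at position r.
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank, epoch=0))
```

If the epoch is never updated, the permutation stays the same, so every epoch would present the data in the same order. That is why the training loop above calls `set_epoch`.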
320 |
321 | For more complex data collection methods, we need to **pass a collating function to the `collate_fn` argument**, which instructs the `DataLoader` how to combine samples of varying sizes into a batch. The simplest option is to return them as Python lists.
322 |
323 | ### Feature Extractor
324 |
325 | See `CifarResNet` in [`resnet.py`](resnet.py).
326 |
327 | We use the python code in [PyCIL](https://github.com/G-U-N/PyCIL/blob/master/convs/cifar_resnet.py).
328 |
329 | ### CilClassifier
330 |
331 | See `CilClassifier` in [`template.py`](template.py).
332 |
333 | CilClassifier is a dynamically scalable linear classifier specially designed for CIL. It is cleverly designed and easy to understand.
334 |
335 | At the beginning of a new task, instead of dropping the old classifier and generating a new one, it adds a new small classifier and concatenates it with the original classifier.
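
The expansion idea can be illustrated with a small, dependency-free sketch. It is not the `CilClassifier` from `template.py`: the class `GrowingLinearClassifier` and its methods are hypothetical names, the weights are plain Python lists, and the bias is ignored.

```python
import random


class GrowingLinearClassifier:
    """Hypothetical stand-in: one weight block (head) per task, concatenated at inference."""

    def __init__(self, feature_dim, num_classes, seed=0):
        self.feature_dim = feature_dim
        self.rng = random.Random(seed)
        self.heads = [self._new_head(num_classes)]

    def _new_head(self, num_classes):
        # One prototype (row vector) per class, with a small random initialisation.
        return [[self.rng.gauss(0.0, 0.01) for _ in range(self.feature_dim)]
                for _ in range(num_classes)]

    def expand(self, num_new_classes):
        """Add a head for the new task; existing prototypes are left untouched."""
        self.heads.append(self._new_head(num_new_classes))

    def forward(self, feature):
        # Logits of all heads, concatenated in the order the classes were learned.
        return [sum(w * f for w, f in zip(proto, feature))
                for head in self.heads for proto in head]


clf = GrowingLinearClassifier(feature_dim=64, num_classes=50)
old_prototype = list(clf.heads[0][0])
clf.expand(10)
print(len(clf.forward([1.0] * 64)))  # prints 60
```

Keeping the old heads intact means the prototypes learned in earlier phases are available as the starting point for the next phase.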
336 |
337 | ### CilModel
338 |
339 | See `CilModel` in [`template.py`](template.py).
340 |
341 | As we discussed above, the CilModel consists of two parts
342 |
343 | - **feature extractor** (we call it backbone in the code)
344 | - **Classifier** (we use CilClassifier to support the dynamic expansion)
345 |
346 | And in order to achieve weight alignment and save the teacher model, we also implemented some functions such as parameter freezing and copying, and model expansion.
347 |
348 | # Training
349 |
350 | To **train your model from scratch**, run:
351 |
352 | ```bash
353 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 template.py
354 | ```
355 |
356 | ### Remarks
357 |
358 | In my implementation, I use **Stochastic Gradient Descent** with batches of `128` images, an initial learning rate of `1e-1`, momentum of `0.9`, and weight decay of `5e-4`. We train for 170 epochs and use `CosineAnnealingLR` with `T_max=170`. Some of the hyperparameters differ from those of the original paper, but this has little effect on performance.
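
For reference, the learning-rate curve produced by cosine annealing can be written in closed form. The sketch below evaluates it with the values quoted above. The function name `cosine_annealing_lr` is hypothetical, and `eta_min=0.0` is assumed because it is PyTorch's default for `CosineAnnealingLR`.

```python
import math


def cosine_annealing_lr(epoch, base_lr=0.1, t_max=170, eta_min=0.0):
    """Closed-form cosine schedule: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))


for epoch in (0, 40, 85, 130, 170):
    print(f"epoch {epoch:3d}: lr = {cosine_annealing_lr(epoch):.5f}")
```

The learning rate starts at `0.1`, passes through half of that value at the midpoint, and decays smoothly to zero at the last epoch.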
359 |
360 | With 4 RTX 3090 distributed training, the experiments on CIFAR-100 with base 50 and increment 10 end in 30 minutes.
361 |
362 | # Contact
363 |
364 | If you have any questions or would like to propose new features, please feel free to open an issue or contact the author: **Fu-Yun Wang** ([wangfuyun@smail.nju.edu.cn](mailto:wangfuyun@smail.nju.edu.cn)). Enjoy the code.
365 |
366 | > Note: This repository is still under development. Interested researchers are welcome to contribute improvements to the code and tutorials.
367 |
368 | # Acknowledgments
369 |
370 | We thank the following repos for providing helpful components/functions in our work.
371 |
372 | [PyCIL](https://github.com/G-U-N/PyCIL)
373 |
374 | [DyTox](https://github.com/arthurdouillard/dytox)
375 |
376 |
377 |
378 | **End.**
379 |
380 |
381 |

382 |
383 |
--------------------------------------------------------------------------------
/README.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/README.pdf
--------------------------------------------------------------------------------
/imgs/CIL-diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/CIL-diagram.png
--------------------------------------------------------------------------------
/imgs/catastrophicforgetting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/catastrophicforgetting.png
--------------------------------------------------------------------------------
/imgs/cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/cifar.png
--------------------------------------------------------------------------------
/imgs/classifier.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/classifier.png
--------------------------------------------------------------------------------
/imgs/convnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/convnet.png
--------------------------------------------------------------------------------
/imgs/cub.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/cub.png
--------------------------------------------------------------------------------
/imgs/decouple.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/decouple.png
--------------------------------------------------------------------------------
/imgs/herding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/herding.png
--------------------------------------------------------------------------------
/imgs/icarl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/icarl.png
--------------------------------------------------------------------------------
/imgs/imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/imagenet.png
--------------------------------------------------------------------------------
/imgs/kd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/kd.png
--------------------------------------------------------------------------------
/imgs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/logo.png
--------------------------------------------------------------------------------
/imgs/maze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/maze.jpg
--------------------------------------------------------------------------------
/imgs/norm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/norm.png
--------------------------------------------------------------------------------
/imgs/outputprobabilities.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/outputprobabilities.png
--------------------------------------------------------------------------------
/imgs/replay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/replay.png
--------------------------------------------------------------------------------
/imgs/vgg16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/a-PyTorch-Tutorial-to-Class-Incremental-Learning/762069142e231b54598462b49514f709d9027a07/imgs/vgg16.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | continuum==1.2.2
2 | numpy==1.22.3
3 | timm==0.5.4
4 | torch==1.11.0+cu113
5 | torchvision==0.12.0+cu113
6 |
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import init
7 |
8 |
9 | class DownsampleA(nn.Module):
10 |     def __init__(self, nIn, nOut, stride):
11 |         super(DownsampleA, self).__init__()
12 |         assert stride == 2
13 |         self.avg = nn.AvgPool2d(kernel_size=1, stride=stride)
14 |
15 |     def forward(self, x):
16 |         x = self.avg(x)
17 |         return torch.cat((x, x.mul(0)), 1)  # zero-pad to double the channel count
18 |
19 |
20 | class ResNetBasicblock(nn.Module):
21 |     """
22 |     ResNet basic block (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua)
23 |     """
24 |     expansion = 1
25 |
26 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
27 |         super(ResNetBasicblock, self).__init__()
28 |
29 |         self.conv_a = nn.Conv2d(
30 |             inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
31 |         self.bn_a = nn.BatchNorm2d(planes)
32 |
33 |         self.conv_b = nn.Conv2d(
34 |             planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
35 |         self.bn_b = nn.BatchNorm2d(planes)
36 |
37 |         self.downsample = downsample
38 |         self.featureSize = 64
39 |
40 |     def forward(self, x):
41 |         residual = x
42 |
43 |         basicblock = self.conv_a(x)
44 |         basicblock = self.bn_a(basicblock)
45 |         basicblock = F.relu(basicblock, inplace=True)
46 |
47 |         basicblock = self.conv_b(basicblock)
48 |         basicblock = self.bn_b(basicblock)
49 |
50 |         if self.downsample is not None:
51 |             residual = self.downsample(x)
52 |
53 |         return F.relu(residual + basicblock, inplace=True)
54 |
55 |
56 | class CifarResNet(nn.Module):
57 | """
58 | ResNet optimized for the Cifar Dataset, as specified in
59 | https://arxiv.org/abs/1512.03385.pdf
60 | """
61 |
62 | def __init__(self, block, depth, num_classes, channels=3):
63 | super(CifarResNet, self).__init__()
64 |
65 | self.featureSize = 64
66 | assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
67 | layer_blocks = (depth - 2) // 6
68 |
69 | self.num_classes = num_classes
70 |
71 | self.conv_1_3x3 = nn.Conv2d(
72 | channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
73 | self.bn_1 = nn.BatchNorm2d(16)
74 |
75 | self.inplanes = 16
76 | self.stage_1 = self._make_layer(block, 16, layer_blocks, 1)
77 | self.stage_2 = self._make_layer(block, 32, layer_blocks, 2)
78 | self.stage_3 = self._make_layer(block, 64, layer_blocks, 2)
79 | self.avgpool = nn.AvgPool2d(8)
80 | self.out_dim = 64 * block.expansion
81 |
82 | for m in self.modules():
83 | if isinstance(m, nn.Conv2d):
84 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
85 | m.weight.data.normal_(0, math.sqrt(2. / n))
86 | elif isinstance(m, nn.BatchNorm2d):
87 | m.weight.data.fill_(1)
88 | m.bias.data.zero_()
89 | elif isinstance(m, nn.Linear):
90 | nn.init.kaiming_normal_(m.weight)
91 | m.bias.data.zero_()
92 |
93 | def _make_layer(self, block, planes, blocks, stride=1):
94 | downsample = None
95 | if stride != 1 or self.inplanes != planes * block.expansion:
96 | downsample = DownsampleA(
97 | self.inplanes, planes * block.expansion, stride)
98 |
99 | layers = []
100 | layers.append(block(self.inplanes, planes, stride, downsample))
101 | self.inplanes = planes * block.expansion
102 | for i in range(1, blocks):
103 | layers.append(block(self.inplanes, planes))
104 |
105 | return nn.Sequential(*layers)
106 |
107 | def forward(self, x, feature=False, T=1, labels=False, scale=None, keep=None):
108 |
109 | x = self.conv_1_3x3(x)
110 | x = F.relu(self.bn_1(x), inplace=True)
111 | x = self.stage_1(x)
112 | x = self.stage_2(x)
113 | x = self.stage_3(x)
114 | x = self.avgpool(x)
115 | x = x.view(x.size(0), -1)
116 | return x
117 |
118 | def forwardFeature(self, x):
119 | pass
120 |
121 |
122 | def resnet20(num_classes=10):
123 | model = CifarResNet(ResNetBasicblock, 20, num_classes)
124 | return model
125 |
126 |
127 | def resnet10mnist(num_classes=10):
128 | model = CifarResNet(ResNetBasicblock, 10, num_classes, 1)
129 | return model
130 |
131 |
132 | def resnet20mnist(num_classes=10):
133 | model = CifarResNet(ResNetBasicblock, 20, num_classes, 1)
134 | return model
135 |
136 |
137 | def resnet32mnist(num_classes=10, channels=1):
138 | model = CifarResNet(ResNetBasicblock, 32, num_classes, channels)
139 | return model
140 |
141 |
142 | def resnet32(num_classes=10):
143 | model = CifarResNet(ResNetBasicblock, 32, num_classes)
144 | return model
145 |
146 |
147 | def resnet44(num_classes=10):
148 | model = CifarResNet(ResNetBasicblock, 44, num_classes)
149 | return model
150 |
151 |
152 | def resnet56(num_classes=10):
153 | model = CifarResNet(ResNetBasicblock, 56, num_classes)
154 | return model
155 |
156 |
157 | def resnet110(num_classes=10):
158 | model = CifarResNet(ResNetBasicblock, 110, num_classes)
159 | return model
160 |
--------------------------------------------------------------------------------
/template.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import torch
4 | from torch.utils.data import DataLoader
5 | import numpy as np
6 | from resnet import resnet20, resnet32, resnet44, resnet56
7 | import torch.nn as nn
8 | import timm
9 | from continuum import rehearsal
10 | from utils import MetricLogger, SoftTarget, init_distributed_mode, build_dataset
11 |
12 |
13 | def get_args_parser():
14 | parser = argparse.ArgumentParser(
15 | 'Class-Incremental Learning training and evaluation script', add_help=False)
16 | parser.add_argument('--seed', default=0, type=int)
17 | parser.add_argument('--num_bases', default=50, type=int)
18 | parser.add_argument('--increment', default=10, type=int)
19 | parser.add_argument('--backbone', default="resnet32", type=str)
20 | parser.add_argument('--batch_size', default=128, type=int)
21 | parser.add_argument('--input_size', default=32, type=int)
22 | parser.add_argument('--color_jitter', default=0.4, type=float)
23 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
24 | help='Use AutoAugment policy. "v0" or "original" '
25 | '(default: rand-m9-mstd0.5-inc1)')
26 | parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT',
27 | help='Random erase prob (default: 0.0)')
28 | parser.add_argument('--remode', type=str, default='pixel',
29 | help='Random erase mode (default: "pixel")')
30 | parser.add_argument('--recount', type=int, default=1,
31 | help='Random erase count (default: 1)')
32 | parser.add_argument('--resplit', action='store_true', default=False,
33 | help='Do not random erase first (clean) augmentation split')
34 | parser.add_argument('--herding_method', default="barycenter", type=str)
35 | parser.add_argument('--memory_size', default=2000, type=int)
36 | parser.add_argument('--fixed_memory', default=False, action="store_true")
37 | parser.add_argument('--lr', default=0.1, type=float)
38 | parser.add_argument('--momentum', default=0.9, type=float)
39 | parser.add_argument('--weight_decay', default=5e-4, type=float)
40 | parser.add_argument('--num_epochs', default=140, type=int)
41 | parser.add_argument('--smooth', default=0.0, type=float)
42 | parser.add_argument('--eval_every_epoch', default=5, type=int)
43 | parser.add_argument('--dist_url', default='env://',
44 | help='url used to set up distributed training')
45 | parser.add_argument('--data_set', default='cifar')
46 | parser.add_argument('--data_path', default='/data/data/data/cifar100')
47 | parser.add_argument('--lambda_kd', default=0.5, type=float)
48 | parser.add_argument('--dynamic_lambda_kd', action="store_true")
49 | return parser
50 |
51 |
52 | def init_seed(args):
53 | np.random.seed(args.seed)
54 | torch.manual_seed(args.seed)
55 | torch.cuda.manual_seed(args.seed)
56 | torch.cuda.manual_seed_all(args.seed)
57 | torch.backends.cudnn.deterministic = True
58 | torch.backends.cudnn.benchmark = False
59 |
60 |
61 | def freeze_parameters(m, requires_grad=False):
62 | if m is None:
63 | return
64 |
65 | if isinstance(m, nn.Parameter):
66 | m.requires_grad = requires_grad
67 | else:
68 | for p in m.parameters():
69 | p.requires_grad = requires_grad
70 |
71 |
72 | def get_backbone(args):
73 | if args.backbone == "resnet32":
74 | backbone = resnet32()
75 | elif args.backbone == "resnet20":
76 | backbone = resnet20()
77 | elif args.backbone == "resnet44":
78 | backbone = resnet44()
79 | elif args.backbone == "resnet56":
80 | backbone = resnet56()
81 | else:
82 | raise NotImplementedError(f'Unknown backbone {args.backbone}')
83 |
84 | return backbone
85 |
86 |
87 | class CilClassifier(nn.Module):
88 | def __init__(self, embed_dim, nb_classes):
89 | super().__init__()
90 | self.embed_dim = embed_dim
91 | self.heads = nn.ModuleList([nn.Linear(embed_dim, nb_classes).cuda()])
92 |
93 | def __getitem__(self, index):
94 | return self.heads[index]
95 |
96 | def __len__(self):
97 | return len(self.heads)
98 |
99 | def forward(self, x):
100 | logits = torch.cat([head(x) for head in self.heads], dim=1)
101 | return logits
102 |
103 | def adaption(self, nb_classes):
104 | self.heads.append(nn.Linear(self.embed_dim, nb_classes).cuda())
105 |
106 |
107 | class CilModel(nn.Module):
108 | def __init__(self, backbone):
109 | super(CilModel, self).__init__()
110 | self.backbone = get_backbone(backbone)
111 | self.fc = None
112 |
113 | @property
114 | def feature_dim(self):
115 | return self.backbone.out_dim
116 |
117 | def extract_vector(self, x):
118 | return self.backbone(x)
119 |
120 | def forward(self, x):
121 | x = self.backbone(x)
122 | out = self.fc(x)
123 | return out, x
124 |
125 | def copy(self):
126 | return copy.deepcopy(self)
127 |
128 | def freeze(self, names=("all",)):
129 | freeze_parameters(self, requires_grad=True)
130 | self.train()
131 | for name in names:
132 | if name == 'fc':
133 | freeze_parameters(self.fc)
134 | self.fc.eval()
135 | elif name == 'backbone':
136 | freeze_parameters(self.backbone)
137 | self.backbone.eval()
138 | elif name == 'all':
139 | freeze_parameters(self)
140 | self.eval()
141 | else:
142 | raise NotImplementedError(
143 | f'Unknown module name to freeze {name}')
144 | return self
145 |
146 | def prev_model_adaption(self, nb_classes):
147 | if self.fc is None:
148 | self.fc = CilClassifier(self.feature_dim, nb_classes).cuda()
149 | else:
150 | self.fc.adaption(nb_classes)
151 |
152 | def after_model_adaption(self, nb_classes, args):
153 | if args.task_id > 0:
154 | self.weight_align(nb_classes)
155 |
156 | @torch.no_grad()
157 | def weight_align(self, nb_new_classes):
158 | w = torch.cat([head.weight.data for head in self.fc], dim=0)
159 | norms = torch.norm(w, dim=1)
160 |
161 | norm_old = norms[:-nb_new_classes]
162 | norm_new = norms[-nb_new_classes:]
163 |
164 | gamma = torch.mean(norm_old) / torch.mean(norm_new)
165 | print(f"old norm / new norm ={gamma}")
166 | self.fc[-1].weight.data = gamma * w[-nb_new_classes:]
167 |
168 |
169 | @torch.no_grad()
170 | def eval(model, val_loader):
171 | metric_logger = MetricLogger(delimiter=" ")
172 | criterion = nn.CrossEntropyLoss()
173 | model.eval()
174 | for images, target, task_ids in val_loader:
175 | images = images.cuda(non_blocking=True)
176 | target = target.cuda(non_blocking=True)
177 | logits, _ = model(images)
178 | loss = criterion(logits, target)
179 | acc1, acc5 = timm.utils.accuracy(
180 | logits, target, topk=(1, min(5, logits.shape[1])))
181 | batch_size = images.shape[0]
182 | metric_logger.update(loss=loss)
183 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
184 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
185 | metric_logger.synchronize_between_processes()
186 | print(' Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
187 | .format(top1=metric_logger.acc1, losses=metric_logger.loss))
188 | return metric_logger.acc1.global_avg
189 |
190 |
191 | if __name__ == "__main__":
192 |
193 | parser = argparse.ArgumentParser(
194 | 'Class-Incremental Learning training and evaluation script', parents=[get_args_parser()])
195 | args = parser.parse_args()
196 |
197 | init_distributed_mode(args)
198 |
199 | init_seed(args)
200 |
201 | args.class_order = [68, 56, 78, 8,
202 | 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50, 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96, 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69, 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33]
203 | scenario_train, args.nb_classes = build_dataset(is_train=True, args=args)
204 | scenario_val, _ = build_dataset(is_train=False, args=args)
205 |
206 | model = CilModel(args)
207 | model = model.cuda()
208 | model_without_ddp = model
209 |
210 | torch.distributed.barrier()
211 |
212 | memory = rehearsal.RehearsalMemory(
213 | memory_size=args.memory_size,
214 | herding_method=args.herding_method,
215 | fixed_memory=args.fixed_memory
216 | )
217 | teacher_model = None
218 |
219 | criterion = nn.CrossEntropyLoss(label_smoothing=args.smooth)
220 |
221 | kd_criterion = SoftTarget(T=2)
222 | args.increment_per_task = [args.num_bases] + \
223 | [args.increment for _ in range(len(scenario_train) - 1)]
224 | args.known_classes = 0
225 | acc1s = []
226 | for task_id, dataset_train in enumerate(scenario_train):
227 | args.task_id = task_id
228 |
229 | dataset_val = scenario_val[:task_id + 1]
230 | if task_id > 0:
231 | dataset_train.add_samples(*memory.get())
232 | train_sampler = torch.utils.data.DistributedSampler(
233 | dataset_train, num_replicas=args.world_size, rank=args.rank, shuffle=True)
234 | val_sampler = torch.utils.data.DistributedSampler(
235 | dataset_val, num_replicas=args.world_size, rank=args.rank, shuffle=False)
236 | train_loader = DataLoader(dataset_train, batch_size=args.batch_size,
237 | sampler=train_sampler, num_workers=10, pin_memory=True)
238 | val_loader = DataLoader(
239 | dataset_val, batch_size=args.batch_size, sampler=val_sampler, num_workers=10)
240 |
241 | model_without_ddp.prev_model_adaption(args.increment_per_task[task_id])
242 |
243 | model = torch.nn.parallel.DistributedDataParallel(
244 | model_without_ddp, device_ids=[args.gpu])
245 |
246 | optimizer = torch.optim.SGD(model_without_ddp.parameters(
247 | ), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
248 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
249 | optimizer, T_max=args.num_epochs)
250 |
251 | for epoch in range(args.num_epochs):
252 | model.train()
253 | train_sampler.set_epoch(epoch)
254 | metric_logger = MetricLogger(delimiter=" ")
255 | for idx, (inputs, targets, task_ids) in enumerate(train_loader):
256 | inputs = inputs.cuda(non_blocking=True)
257 | targets = targets.cuda(non_blocking=True)
258 | logits, _ = model(inputs)
259 | loss_ce = criterion(logits, targets)
260 | if teacher_model is not None:
261 | t_logits, _ = teacher_model(inputs)
262 | loss_kd = args.lambda_kd * \
263 | kd_criterion(logits[:, :args.known_classes], t_logits)
264 | else:
265 | loss_kd = torch.tensor(0.).cuda(non_blocking=True)
266 | loss = loss_ce + loss_kd
267 | acc1, acc5 = timm.utils.accuracy(
268 | logits, targets, topk=(1, min(5, logits.shape[1])))
269 | optimizer.zero_grad()
270 | loss.backward()
271 | optimizer.step()
272 | torch.distributed.barrier()
273 | metric_logger.update(ce=loss_ce)
274 | metric_logger.update(kd=loss_kd)
275 | metric_logger.update(loss=loss)
276 | metric_logger.update(acc1=acc1)
277 | metric_logger.synchronize_between_processes()
278 | lr_scheduler.step()
279 | print(
280 | f"train states: epoch :[{epoch+1}/{args.num_epochs}] {metric_logger}")
281 |
282 | if (epoch+1) % args.eval_every_epoch == 0:
283 | eval(model, val_loader)
284 |
285 | model_without_ddp.after_model_adaption(
286 | args.increment_per_task[task_id], args)
287 | acc1 = eval(model, val_loader)
288 | acc1s.append(acc1)
289 | print(f"task id = {task_id} @Acc1 = {acc1:.5f}, acc1s = {acc1s}")
290 | teacher_model = model_without_ddp.copy().freeze()
291 |
292 | unshuffle_train_loader = DataLoader(
293 | dataset_train, batch_size=args.batch_size, shuffle=False)
294 | features = []
295 | for i, (inputs, labels, task_ids) in enumerate(unshuffle_train_loader):
296 | inputs = inputs.cuda(non_blocking=True)
297 | features.append(model_without_ddp.extract_vector(
298 | inputs).detach().cpu().numpy())
299 | features = np.concatenate(features, axis=0)
300 | memory.add(
301 | *dataset_train.get_raw_samples(), features
302 | )
303 | args.known_classes += args.increment_per_task[task_id]
304 |
305 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | import os
3 | from collections import defaultdict, deque
4 | import warnings
5 | from xml.dom import NotSupportedErr
6 | import numpy as np
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch
10 | import torch.distributed as dist
11 | from continuum.datasets import CIFAR100, ImageFolderDataset
12 | from continuum import ClassIncremental
13 | from timm.data import create_transform
14 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
15 | from torchvision import transforms
16 | try:
17 | interpolation = transforms.InterpolationMode.BICUBIC
18 | except AttributeError:
19 | interpolation = 3
20 |
21 |
22 | class SmoothedValue(object):
23 | def __init__(self, window_size=20, fmt=None):
24 | if fmt is None:
25 | fmt = "{median:.4f} ({global_avg:.4f})"
26 | self.deque = deque(maxlen=window_size)
27 | self.total = 0.0
28 | self.count = 0
29 | self.fmt = fmt
30 |
31 | def update(self, value, n=1):
32 | self.deque.append(value)
33 | self.count += n
34 | self.total += value * n
35 |
36 | def synchronize_between_processes(self):
37 | t = torch.tensor([self.count, self.total],
38 | dtype=torch.float64, device='cuda')
39 | dist.barrier()
40 | dist.all_reduce(t)
41 | t = t.tolist()
42 | self.count = int(t[0])
43 | self.total = t[1]
44 |
45 | @property
46 | def median(self):
47 | d = torch.tensor(list(self.deque))
48 | return d.median().item()
49 |
50 | @property
51 | def avg(self):
52 | d = torch.tensor(list(self.deque), dtype=torch.float32)
53 | return d.mean().item()
54 |
55 | @property
56 | def global_avg(self):
57 | return self.total / self.count
58 |
59 | @property
60 | def max(self):
61 | return max(self.deque)
62 |
63 | @property
64 | def value(self):
65 | return self.deque[-1]
66 |
67 | def __str__(self):
68 | return self.fmt.format(
69 | median=self.median,
70 | avg=self.avg,
71 | global_avg=self.global_avg,
72 | max=self.max,
73 | value=self.value)
74 |
75 |
76 | class MetricLogger(object):
77 | def __init__(self, delimiter="\t"):
78 | self.meters = defaultdict(SmoothedValue)
79 | self.delimiter = delimiter
80 |
81 | def update(self, **kwargs):
82 | for k, v in kwargs.items():
83 | if v is None:
84 | continue
85 | if isinstance(v, torch.Tensor):
86 | v = v.item()
87 | assert isinstance(v, (float, int))
88 | self.meters[k].update(v)
89 |
90 | def update_dict(self, d):
91 | for k, v in d.items():
92 | if isinstance(v, torch.Tensor):
93 | v = v.item()
94 | assert isinstance(v, (float, int))
95 | self.meters[k].update(v)
96 |
97 | def __getattr__(self, attr):
98 | if attr in self.meters:
99 | return self.meters[attr]
100 | if attr in self.__dict__:
101 | return self.__dict__[attr]
102 | raise AttributeError("'{}' object has no attribute '{}'".format(
103 | type(self).__name__, attr))
104 |
105 | def __str__(self):
106 | loss_str = []
107 | for name, meter in self.meters.items():
108 | loss_str.append(
109 | "{}: {}".format(name, str(meter))
110 | )
111 | return self.delimiter.join(loss_str)
112 |
113 | def synchronize_between_processes(self):
114 | for meter in self.meters.values():
115 | meter.synchronize_between_processes()
116 |
117 | def add_meter(self, name, meter):
118 | self.meters[name] = meter
119 |
120 |
121 | class SoftTarget(nn.Module):
122 |
123 | def __init__(self, T=2):
124 | super(SoftTarget, self).__init__()
125 | self.T = T
126 |
127 | def forward(self, out_s, out_t):
128 | loss = F.kl_div(F.log_softmax(out_s/self.T, dim=1),
129 | F.softmax(out_t/self.T, dim=1),
130 | reduction='batchmean') * self.T * self.T
131 |
132 | return loss
133 |
134 |
135 | def init_distributed_mode(args):
136 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
137 | args.rank = int(os.environ["RANK"])
138 | args.world_size = int(os.environ['WORLD_SIZE'])
139 | args.gpu = int(os.environ['LOCAL_RANK'])
140 | else:
141 | print('Not using distributed mode')
142 | args.distributed = False
143 | raise NotSupportedErr("not supported yet!")
144 | return
145 | args.distributed = True
146 | torch.cuda.set_device(args.gpu)
147 | args.dist_backend = 'nccl'
148 | print('| distributed init (rank {}): {}'.format(
149 | args.rank, args.dist_url), flush=True)
150 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
151 | world_size=args.world_size, rank=args.rank)
152 | torch.distributed.barrier()
153 | setup_for_distributed(args.rank == 0)
154 |
155 |
156 | def is_main_process():
157 | return dist.get_rank() == 0
158 |
159 |
160 | def setup_for_distributed(is_master):
161 | import builtins as __builtin__
162 | builtin_print = __builtin__.print
163 |
164 | def print(*args, **kwargs):
165 | force = kwargs.pop('force', False)
166 | if is_master or force:
167 | builtin_print(*args, **kwargs)
168 | __builtin__.print = print
169 |
170 |
171 | class ImageNet1000(ImageFolderDataset):
172 | def __init__(
173 | self,
174 | data_path: str,
175 | train: bool = True,
176 | download: bool = False,
177 | ):
178 | super().__init__(data_path=data_path, train=train, download=download)
179 |
180 | def get_data(self):
181 | if self.train:
182 | self.data_path = os.path.join(self.data_path, "train")
183 | else:
184 | self.data_path = os.path.join(self.data_path, "val")
185 | return super().get_data()
186 |
187 |
188 | def build_dataset(is_train, args):
189 | transform = build_transform(is_train, args)
190 |
191 | if args.data_set.lower() == 'cifar':
192 | dataset = CIFAR100(args.data_path, train=is_train, download=True)
193 | elif args.data_set.lower() == 'imagenet1000':
194 | dataset = ImageNet1000(args.data_path, train=is_train)
195 | else:
196 | raise ValueError(f'Unknown dataset {args.data_set}.')
197 |
198 | scenario = ClassIncremental(
199 | dataset,
200 | initial_increment=args.num_bases,
201 | increment=args.increment,
202 | transformations=transform.transforms,
203 | class_order=args.class_order
204 | )
205 | nb_classes = scenario.nb_classes
206 |
207 | return scenario, nb_classes
208 |
209 |
210 | def build_transform(is_train, args):
211 | if args.aa == 'none':
212 | args.aa = None
213 |
214 | with warnings.catch_warnings():
215 | resize_im = args.input_size > 32
216 | if is_train:
217 | transform = create_transform(
218 | input_size=args.input_size,
219 | is_training=True,
220 | color_jitter=args.color_jitter,
221 | auto_augment=args.aa,
222 | interpolation='bicubic',
223 | re_prob=args.reprob,
224 | re_mode=args.remode,
225 | re_count=args.recount,
226 | )
227 | if not resize_im:
228 | transform.transforms[0] = transforms.RandomCrop(
229 | args.input_size, padding=4)
230 |
231 | if args.input_size == 32 and args.data_set.lower() == 'cifar':
232 | transform.transforms[-1] = transforms.Normalize(
233 | (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
234 | return transform
235 |
236 | t = []
237 | if resize_im:
238 | size = int((256 / 224) * args.input_size)
239 | t.append(
240 | transforms.Resize(size, interpolation=interpolation),
241 | )
242 | t.append(transforms.CenterCrop(args.input_size))
243 |
244 | t.append(transforms.ToTensor())
245 | if args.input_size == 32 and args.data_set.lower() == 'cifar':
246 | t.append(transforms.Normalize(
247 | (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)))
248 | else:
249 | t.append(transforms.Normalize(
250 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
251 | return transforms.Compose(t)
252 |
--------------------------------------------------------------------------------
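The weight alignment step in `CilModel.weight_align` (template.py) rescales the classifier rows of the newest classes so that their mean L2 norm matches that of the older classes. The sketch below reproduces only that arithmetic on a toy weight matrix. It is written in plain Python so that it runs without PyTorch or a GPU; the helper `row_norms` and the list-of-lists layout are assumptions made for illustration and are not part of the repository.

```python
import math


def row_norms(weights):
    """Return the L2 norm of each row (one row per class)."""
    return [math.sqrt(sum(v * v for v in row)) for row in weights]


def weight_align(weights, nb_new_classes):
    """Rescale the last `nb_new_classes` rows so their mean norm matches the old rows.

    Mirrors the arithmetic of CilModel.weight_align:
    gamma = mean(old norms) / mean(new norms), then new rows *= gamma.
    """
    norms = row_norms(weights)
    norm_old = norms[:-nb_new_classes]
    norm_new = norms[-nb_new_classes:]
    gamma = (sum(norm_old) / len(norm_old)) / (sum(norm_new) / len(norm_new))
    old_rows = weights[:-nb_new_classes]
    new_rows = [[gamma * v for v in row] for row in weights[-nb_new_classes:]]
    return old_rows + new_rows, gamma


# Two old classes with norms 5 and 1 (mean 3); two new classes with norms 10 and 2 (mean 6).
weights = [[3.0, 4.0], [0.0, 1.0], [6.0, 8.0], [0.0, 2.0]]
aligned, gamma = weight_align(weights, nb_new_classes=2)
print(gamma)               # 0.5
print(row_norms(aligned))  # [5.0, 1.0, 5.0, 1.0]
```

Only the new rows are touched, and every new row is scaled by the same factor, so the relative ordering among new-class logits is preserved while their overall magnitude is brought in line with the old classes.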