├── .gitignore
├── LICENSE
├── README.md
├── arguments.py
├── assets
│   └── concept_figure.png
├── augmentations
│   ├── __init__.py
│   ├── eval_aug.py
│   ├── gaussian_blur.py
│   └── simsiam_aug.py
├── configs
│   ├── __init__.py
│   ├── cifar10
│   │   ├── distil.yaml
│   │   ├── distilbuf.yaml
│   │   └── qdi.yaml
│   ├── cifar100
│   │   ├── distil.yaml
│   │   ├── distilbuf.yaml
│   │   └── qdi.yaml
│   └── tinyimg
│       ├── distil.yaml
│       ├── distilbuf.yaml
│       └── qdi.yaml
├── datasets
│   ├── __init__.py
│   ├── datasets_utils.py
│   ├── random_dataset.py
│   ├── seq_cifar10.py
│   ├── seq_cifar100.py
│   ├── seq_tinyimagenet.py
│   ├── test
│   │   ├── seq-cifar10.pt
│   │   ├── seq-cifar100.pt
│   │   ├── seq-domainnet.pt
│   │   └── seq-tinyimg.pt
│   ├── transforms
│   │   ├── __init__.py
│   │   ├── denormalization.py
│   │   ├── permutation.py
│   │   └── rotation.py
│   └── utils
│       ├── __init__.py
│       ├── continual_dataset.py
│       └── validation.py
├── linear_eval_alltasks.py
├── main.py
├── models
│   ├── __init__.py
│   ├── backbones
│   │   ├── Alexnet.py
│   │   ├── Densenet.py
│   │   ├── Inception.py
│   │   ├── Lenet.py
│   │   ├── Regnet.py
│   │   ├── ResNet18.py
│   │   ├── ResNext.py
│   │   ├── Senet.py
│   │   ├── Swin.py
│   │   ├── Vgg.py
│   │   ├── __init__.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── modules.py
│   ├── distil.py
│   ├── distilbuf.py
│   ├── optimizers
│   │   ├── __init__.py
│   │   ├── lars.py
│   │   └── lr_scheduler.py
│   ├── qdi.py
│   ├── simsiam.py
│   └── utils
│       ├── __init__.py
│       └── continual_model.py
├── requirements.txt
├── tools
│   ├── __init__.py
│   ├── accuracy.py
│   ├── average_meter.py
│   ├── file_exist_fn.py
│   ├── knn_monitor.py
│   ├── logger.py
│   └── plotter.py
└── utils
    ├── __init__.py
    ├── args.py
    ├── batch_norm.py
    ├── buffer.py
    ├── conf.py
    ├── continual_training.py
    ├── deep_inversion.py
    ├── loggers.py
    ├── losses.py
    ├── metrics.py
    ├── status.py
    └── tb_logger.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.npy
2 | __pycache__/
3 | checkpoints/
4 | data/
5 | logs/
6 | wandb/
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2023, NVIDIA Corporation. All rights reserved.
2 |
3 | Nvidia Source Code License-NC
4 |
5 | 1. Definitions
6 |
7 | “Licensor” means any person or entity that distributes its Work.
8 |
9 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation,
10 | or other files, and (b) any additions to or derivative works thereof that are made available under this license.
11 |
12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S.
13 | copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that
14 | remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
15 |
16 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing
17 | the applicability of this license to the Work, or (b) a copy of this license.
18 |
19 | 2. License Grant
20 |
21 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual,
22 | worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly
23 | display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
24 |
25 | 3. Limitations
26 |
27 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a
28 | complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent,
29 | trademark, or attribution notices that are present in the Work.
30 |
31 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution
32 | of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3
33 | applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms.
34 | Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply
35 | to the Work itself.
36 |
37 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially.
38 | Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially.
39 | As used herein, “non-commercially” means for research or evaluation purposes only.
40 |
41 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim
42 | or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under
43 | this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
44 |
45 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks,
46 | except as necessary to reproduce the notices described in this license.
47 |
48 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1)
49 | will terminate immediately.
50 |
51 | 4. Disclaimer of Warranty.
52 |
53 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES
54 | OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING
55 | ANY ACTIVITIES UNDER THIS LICENSE.
56 |
57 | 5. Limitation of Liability.
58 |
59 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT,
60 | OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL
61 | DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL,
62 | BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR
63 | HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
64 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Heterogeneous Continual Learning
4 | [CVPR 2023](https://cvpr.thecvf.com/)
5 | [arXiv:2306.08593](https://arxiv.org/abs/2306.08593)
6 |
7 |
8 | Official PyTorch implementation of CVPR 2023 Highlight (Top 10%) paper [**Heterogeneous Continual Learning**](https://arxiv.org/abs/2306.08593).
9 |
10 | **Authors**: [Divyam Madaan](https://dmadaan.com/), [Hongxu Yin](https://hongxu-yin.github.io/), [Wonmin Byeon](https://wonmin-byeon.github.io/), [Jan Kautz](https://jankautz.com/), [Pavlo Molchanov](https://research.nvidia.com/person/pavlo-molchanov)
11 |
12 | For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)
13 |
14 | **TL;DR: First continual learning approach in which the architecture continuously evolves with the data.**
15 |
16 | ![concept figure](assets/concept_figure.png)
17 |
18 |
19 | ## Abstract
20 | We propose a novel framework and a solution to tackle
21 | the continual learning (CL) problem with changing network
22 | architectures. Most CL methods focus on adapting a single
23 | architecture to a new task/class by modifying its weights.
24 | However, with rapid progress in architecture design, the
25 | problem of adapting existing solutions to novel architectures
26 | becomes relevant. To address this limitation, we propose
27 | Heterogeneous Continual Learning (HCL), where a wide
28 | range of evolving network architectures emerge continually
29 | together with novel data/tasks. As a solution, we build on
30 | top of the distillation family of techniques and modify it
31 | to a new setting where a weaker model takes the role of a
32 | teacher; meanwhile, a new stronger architecture acts as a
33 | student. Furthermore, we consider a setup of limited access
34 | to previous data and propose Quick Deep Inversion (QDI) to
35 | recover prior task visual features to support knowledge
36 | transfer. QDI significantly reduces computational costs compared
37 | to previous solutions and improves overall performance. In
38 | summary, we propose a new setup for CL with a modified
39 | knowledge distillation paradigm and design a quick data
40 | inversion method to enhance distillation. Our evaluation
41 | on various benchmarks shows a significant improvement in
42 | accuracy in comparison to state-of-the-art methods across
43 | various network architectures.
44 |
45 | __Contribution of this work__
46 |
47 | - We propose a novel CL framework called Heterogeneous
48 | Continual Learning (HCL) to learn a stream of
49 | different architectures on a sequence of tasks while
50 | transferring the knowledge from past representations.
51 | - We revisit knowledge distillation and propose Quick
52 | Deep Inversion (QDI), which inverts the previous task
53 | parameters while interpolating the current task examples
54 | with minimal additional cost (see the sketch below).
55 | - We benchmark existing state-of-the-art solutions in the
56 | new setting and outperform them with our proposed
57 | method across a diverse stream of architectures for both
58 | task-incremental and class-incremental CL.
59 |
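For intuition, here is a minimal, hypothetical sketch of the QDI idea, not the repository implementation (that lives in `utils/deep_inversion.py` and `models/qdi.py`). The function and the `alpha` interpolation weight are illustrative, `itrs`/`lr` loosely mirror the `di_itrs`/`di_lr` config entries, and the image-prior terms from the configs (`di_var`, `di_l2`, `di_feature`) are omitted for brevity:

```python
# Hypothetical sketch of Quick Deep Inversion (QDI); not the repo's code.
import torch
import torch.nn.functional as F

def qdi_sketch(prev_model, cur_images, prev_labels, itrs=500, lr=0.03, alpha=0.5):
    """Synthesize previous-task-like inputs by optimizing pixels so that the
    frozen previous-task model classifies them as `prev_labels`, starting
    from current-task images instead of random noise (the "quick" part)."""
    prev_model.eval()
    x = cur_images.clone().detach().requires_grad_(True)
    opt = torch.optim.Adam([x], lr=lr)
    for _ in range(itrs):
        opt.zero_grad()
        loss = F.cross_entropy(prev_model(x), prev_labels)  # recover old-task features
        loss = loss + alpha * F.mse_loss(x, cur_images)     # stay close to current examples
        loss.backward()
        opt.step()
    return x.detach()  # inverted images used to distill the old model into the new one
```

Starting from current-task images rather than random noise is what makes the inversion "quick": far fewer optimization steps are needed to recover usable prior-task features.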
60 | ## Prerequisites
61 |
62 | ```
63 | $ pip install -r requirements.txt
64 | ```
65 |
66 | ## 🚀 Quick start
67 |
68 | ### Training
69 |
70 | ```bash
71 | python main.py --data_dir ../data/ --log_dir ./logs/scl/ -c configs/cifar10/distil.yaml --ckpt_dir ./checkpoints/c10/scl/distil/ --hide_progress --cl_default --validation --hcl
72 |
73 | ```
74 |
75 | ### Evaluation
76 |
77 | ```bash
78 | python linear_eval_alltasks.py --data_dir ../data/ --log_dir ./logs/scl/ -c configs/cifar10/distil.yaml --ckpt_dir ./checkpoints/c10/scl/distil/ --hide_progress --cl_default --hcl
79 |
80 | ```
81 |
82 |
83 | To change the dataset or method, pass a different configuration file from `./configs`, e.g. `-c configs/cifar100/qdi.yaml` to train on CIFAR-100 with QDI.
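Under the hood, `arguments.py` merges the chosen YAML file into the command-line namespace, so every top-level key (`dataset`, `model`, `train`, ...) becomes an attribute on `args`. Below is a simplified, illustrative sketch of that merge; the real `get_args` additionally wraps nested dicts in a small `Namespace` class so that e.g. `args.dataset.image_size` works:

```python
# Simplified sketch of the YAML-to-args merge in arguments.py (illustrative).
import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config-file', required=True, type=str)
args = parser.parse_args(['-c', 'configs/cifar10/distil.yaml'])

with open(args.config_file) as f:
    config = yaml.safe_load(f)   # top-level keys: name, dataset, model, train, ...
for key, value in config.items():
    vars(args)[key] = value      # each key becomes an attribute on args

print(args.dataset['name'])      # -> 'seq-cifar10'
```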
84 |
85 | ## Contributing
86 |
87 | We'd love to accept your contributions to this project. Please feel free to open an issue or submit a pull request as necessary. If you have implementations of this repository in other ML frameworks, please reach out so we may highlight them here.
88 |
89 | ## 🎗️ Acknowledgment
90 |
91 | The code is built upon [aimagelab/mammoth](https://github.com/aimagelab/mammoth), [divyam3897/UCL](https://github.com/divyam3897/UCL), [kuangliu/pytorch-cifar](https://github.com/kuangliu/pytorch-cifar/tree/master), [sutd-visual-computing-group/LS-KD-compatibility](https://github.com/sutd-visual-computing-group/LS-KD-compatibility), and [berniwal/swin-transformer-pytorch](https://github.com/berniwal/swin-transformer-pytorch).
92 |
93 | We thank the authors for their amazing work and for releasing their code bases.
94 |
95 |
96 | ## Licenses
97 |
98 | Copyright © 2023, NVIDIA Corporation. All rights reserved.
99 |
100 | This work is made available under the NVIDIA Source Code License-NC. Click [here](LICENSE) to view a copy of this license.
101 |
102 | For license information regarding the mammoth repository, please refer to its [repository](https://github.com/aimagelab/mammoth/blob/master/LICENSE).
103 | For license information regarding the UCL repository, please refer to its [repository](https://github.com/divyam3897/UCL/blob/main/LICENSE).
104 | For license information regarding the pytorch-cifar repository, please refer to its [repository](https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE).
105 | For license information regarding the LS-KD repository, please refer to its [repository](https://github.com/sutd-visual-computing-group/LS-KD-compatibility/blob/master/LICENSE).
106 | For license information regarding the swin-transformer repository, please refer to its [repository](https://github.com/berniwal/swin-transformer-pytorch/blob/master/LICENSE).
107 |
108 |
109 | ## 📌 Citation
110 |
111 | If you find this paper useful, please consider starring 🌟 this repo and citing 📑 our paper:
112 |
113 | ```bibtex
114 | @inproceedings{madaan2023heterogeneous,
115 | title={Heterogeneous Continual Learning},
116 | author={Madaan, Divyam and Yin, Hongxu and Byeon, Wonmin and Kautz, Jan and Molchanov, Pavlo},
117 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
118 | year={2023}
119 | }
120 | ```
121 |
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 |
5 | import numpy as np
6 | import random
7 |
8 |
9 | import re
10 | import yaml
11 |
12 | import shutil
13 | import warnings
14 |
15 | from datetime import datetime
16 |
17 |
18 | class Namespace(object):
19 | def __init__(self, somedict):
20 | for key, value in somedict.items():
21 |             assert isinstance(key, str) and re.match("[A-Za-z_-]", key)  # keys must start with a letter, '_' or '-'
22 | if isinstance(value, dict):
23 | self.__dict__[key] = Namespace(value)
24 | else:
25 | self.__dict__[key] = value
26 |
27 | def __getattr__(self, attribute):
28 |
29 |         raise AttributeError(f"Cannot find {attribute} in namespace. Please add {attribute} to your config file (xxx.yaml)!")
30 |
31 |
32 | def set_deterministic(seed):
33 | # seed by default is None
34 | if seed is not None:
35 | print(f"Deterministic with seed = {seed}")
36 | random.seed(seed)
37 | np.random.seed(seed)
38 | torch.manual_seed(seed)
39 | torch.cuda.manual_seed(seed)
40 | torch.backends.cudnn.deterministic = True
41 | torch.backends.cudnn.benchmark = False
42 |
43 | def get_args():
44 | parser = argparse.ArgumentParser()
45 |     parser.add_argument('-c', '--config-file', required=True, type=str, help="path to the config file, e.g. configs/cifar10/distil.yaml")
46 | parser.add_argument('--debug', action='store_true')
47 | parser.add_argument('--debug_subset_size', type=int, default=8)
48 | parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web")
49 | parser.add_argument('--data_dir', type=str, default=os.getenv('DATA'))
50 | parser.add_argument('--log_dir', type=str, default=os.getenv('LOG'))
51 | parser.add_argument('--ckpt_dir', type=str, default=os.getenv('CHECKPOINT'))
52 | parser.add_argument('--ckpt_dir_1', type=str, default=os.getenv('CHECKPOINT'))
53 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
54 | parser.add_argument('--eval_from', type=str, default=None)
55 | parser.add_argument('--hide_progress', action='store_true')
56 | parser.add_argument('--cl_default', action='store_true')
57 | parser.add_argument('--server', action='store_true')
58 | parser.add_argument('--hcl', action='store_true')
59 | parser.add_argument('--buffer_qdi', action='store_true')
60 | parser.add_argument('--validation', action='store_true',
61 | help='Test on the validation set')
62 | parser.add_argument('--ood_eval', action='store_true',
63 | help='Test on the OOD set')
64 | parser.add_argument('--alpha', type=float, default=0.3)
65 | args = parser.parse_args()
66 |
67 |
68 | with open(args.config_file, 'r') as f:
69 | for key, value in Namespace(yaml.load(f, Loader=yaml.FullLoader)).__dict__.items():
70 | vars(args)[key] = value
71 |
72 | if args.debug:
73 | if args.train:
74 | args.train.batch_size = 2
75 | args.train.num_epochs = 1
76 | args.train.stop_at_epoch = 1
77 | if args.eval:
78 | args.eval.batch_size = 2
79 | args.eval.num_epochs = 1 # train only one epoch
80 | args.dataset.num_workers = 0
81 |
82 |
83 |     assert None not in [args.log_dir, args.data_dir, args.ckpt_dir, args.name]
84 |
85 | args.log_dir = os.path.join(args.log_dir, 'in-progress_'+datetime.now().strftime('%m%d%H%M%S_')+args.name)
86 |
87 | os.makedirs(args.log_dir, exist_ok=False)
88 |     print(f'creating log directory {args.log_dir}')
89 | os.makedirs(args.ckpt_dir, exist_ok=True)
90 |
91 | shutil.copy2(args.config_file, args.log_dir)
92 | set_deterministic(args.seed)
93 |
94 |
95 | vars(args)['aug_kwargs'] = {
96 | 'name':args.model.name,
97 | 'image_size': args.dataset.image_size,
98 | 'cl_default': args.cl_default
99 | }
100 | vars(args)['dataset_kwargs'] = {
101 | # 'name':args.model.name,
102 | # 'image_size': args.dataset.image_size,
103 | 'dataset':args.dataset.name,
104 | 'data_dir': args.data_dir,
105 | 'download':args.download,
106 | 'debug_subset_size': args.debug_subset_size if args.debug else None,
107 | # 'drop_last': True,
108 | # 'pin_memory': True,
109 | # 'num_workers': args.dataset.num_workers,
110 | }
111 | vars(args)['dataloader_kwargs'] = {
112 | 'drop_last': True,
113 | 'pin_memory': True,
114 | 'num_workers': args.dataset.num_workers,
115 | }
116 |
117 | return args
118 |
--------------------------------------------------------------------------------
/assets/concept_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/assets/concept_figure.png
--------------------------------------------------------------------------------
/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | from .simsiam_aug import SimSiamTransform
2 | from .eval_aug import Transform_single
3 |
4 |
5 | def get_aug(name='simsiam', image_size=224, train=True, train_classifier=None, mean_std=None, **aug_kwargs):
6 |     if train is True:
7 |         augmentation = SimSiamTransform(image_size, mean_std=mean_std, **aug_kwargs)
8 |     elif train is False:
9 |         if train_classifier is None:
10 |             raise ValueError('train_classifier must be set when train is False')
11 |         augmentation = Transform_single(image_size, train=train_classifier, mean_std=mean_std)
12 |     else:
13 |         raise ValueError(f'train must be a bool, got {train!r}')
14 |
15 | return augmentation
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/augmentations/eval_aug.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from PIL import Image
3 |
4 |
5 | class Transform_single():
6 | def __init__(self, image_size, train, mean_std):
7 |         if train:
8 | self.transform = transforms.Compose([
9 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC),
10 | # transforms.RandomCrop(image_size, padding=4),
11 | transforms.RandomHorizontalFlip(),
12 | transforms.ToTensor(),
13 | transforms.Normalize(*mean_std)
14 | ])
15 | else:
16 | self.transform = transforms.Compose([
17 | # transforms.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256
18 | # transforms.CenterCrop(image_size),
19 | transforms.ToTensor(),
20 | transforms.Normalize(*mean_std)
21 | ])
22 |
23 | def __call__(self, x):
24 | return self.transform(x)
25 |
--------------------------------------------------------------------------------
/augmentations/gaussian_blur.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | from typing import Any, List, Optional, Sequence, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from torch import Tensor
8 | from torch.nn.functional import conv2d, pad as torch_pad
9 | from torchvision.transforms.functional import to_pil_image, to_tensor
10 |
11 |
12 | class GaussianBlur(torch.nn.Module):
13 | """Blurs image with randomly chosen Gaussian blur.
14 | The image can be a PIL Image or a Tensor, in which case it is expected
15 | to have [..., C, H, W] shape, where ... means an arbitrary number of leading
16 | dimensions
17 |
18 | Args:
19 | kernel_size (int or sequence): Size of the Gaussian kernel.
20 | sigma (float or tuple of float (min, max)): Standard deviation to be used for
21 | creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
22 | of float (min, max), sigma is chosen uniformly at random to lie in the
23 | given range.
24 |
25 | Returns:
26 | PIL Image or Tensor: Gaussian blurred version of the input image.
27 |
28 | """
29 |
30 | def __init__(self, kernel_size, sigma=(0.1, 2.0)):
31 | super().__init__()
32 | self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
33 | for ks in self.kernel_size:
34 | if ks <= 0 or ks % 2 == 0:
35 | raise ValueError("Kernel size value should be an odd and positive number.")
36 |
37 | if isinstance(sigma, numbers.Number):
38 | if sigma <= 0:
39 | raise ValueError("If sigma is a single number, it must be positive.")
40 | sigma = (sigma, sigma)
41 | elif isinstance(sigma, Sequence) and len(sigma) == 2:
42 | if not 0. < sigma[0] <= sigma[1]:
43 | raise ValueError("sigma values should be positive and of the form (min, max).")
44 | else:
45 | raise ValueError("sigma should be a single number or a list/tuple with length 2.")
46 |
47 | self.sigma = sigma
48 |
49 | @staticmethod
50 | def get_params(sigma_min: float, sigma_max: float) -> float:
51 | """Choose sigma for random gaussian blurring.
52 |
53 | Args:
54 | sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
55 | sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.
56 |
57 | Returns:
58 | float: Standard deviation to be passed to calculate kernel for gaussian blurring.
59 | """
60 | return torch.empty(1).uniform_(sigma_min, sigma_max).item()
61 |
62 | def forward(self, img: Tensor) -> Tensor:
63 | """
64 | Args:
65 | img (PIL Image or Tensor): image to be blurred.
66 |
67 | Returns:
68 | PIL Image or Tensor: Gaussian blurred image
69 | """
70 | sigma = self.get_params(self.sigma[0], self.sigma[1])
71 | return gaussian_blur(img, self.kernel_size, [sigma, sigma])
72 |
73 | def __repr__(self):
74 | s = '(kernel_size={}, '.format(self.kernel_size)
75 | s += 'sigma={})'.format(self.sigma)
76 | return self.__class__.__name__ + s
77 |
78 | @torch.jit.unused
79 | def _is_pil_image(img: Any) -> bool:
80 | return isinstance(img, Image.Image)
81 | def _setup_size(size, error_msg):
82 | if isinstance(size, numbers.Number):
83 | return int(size), int(size)
84 |
85 | if isinstance(size, Sequence) and len(size) == 1:
86 | return size[0], size[0]
87 |
88 | if len(size) != 2:
89 | raise ValueError(error_msg)
90 |
91 | return size
92 | def _is_tensor_a_torch_image(x: Tensor) -> bool:
93 | return x.ndim >= 2
94 | def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
95 | ksize_half = (kernel_size - 1) * 0.5
96 |
97 | x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
98 | pdf = torch.exp(-0.5 * (x / sigma).pow(2))
99 | kernel1d = pdf / pdf.sum()
100 |
101 | return kernel1d
102 |
103 | def _cast_squeeze_in(img: Tensor, req_dtype: torch.dtype) -> Tuple[Tensor, bool, bool, torch.dtype]:
104 | need_squeeze = False
105 | # make image NCHW
106 | if img.ndim < 4:
107 | img = img.unsqueeze(dim=0)
108 | need_squeeze = True
109 |
110 | out_dtype = img.dtype
111 | need_cast = False
112 | if out_dtype != req_dtype:
113 | need_cast = True
114 | img = img.to(req_dtype)
115 | return img, need_cast, need_squeeze, out_dtype
116 | def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype):
117 | if need_squeeze:
118 | img = img.squeeze(dim=0)
119 |
120 | if need_cast:
121 | # it is better to round before cast
122 | img = torch.round(img).to(out_dtype)
123 |
124 | return img
125 | def _get_gaussian_kernel2d(
126 | kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
127 | ) -> Tensor:
128 | kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
129 | kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
130 | kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
131 | return kernel2d
132 | def _gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
133 | """PRIVATE METHOD. Performs Gaussian blurring on the img by given kernel.
134 |
135 | .. warning::
136 |
137 | Module ``transforms.functional_tensor`` is private and should not be used in user application.
138 | Please, consider instead using methods from `transforms.functional` module.
139 |
140 | Args:
141 | img (Tensor): Image to be blurred
142 | kernel_size (sequence of int or int): Kernel size of the Gaussian kernel ``(kx, ky)``.
143 | sigma (sequence of float or float, optional): Standard deviation of the Gaussian kernel ``(sx, sy)``.
144 |
145 | Returns:
146 | Tensor: An image that is blurred using gaussian kernel of given parameters
147 | """
148 | if not (isinstance(img, torch.Tensor) or _is_tensor_a_torch_image(img)):
149 | raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
150 |
151 | dtype = img.dtype if torch.is_floating_point(img) else torch.float32
152 | kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
153 | kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
154 |
155 | img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, kernel.dtype)
156 |
157 | # padding = (left, right, top, bottom)
158 | padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
159 | img = torch_pad(img, padding, mode="reflect")
160 | img = conv2d(img, kernel, groups=img.shape[-3])
161 |
162 | img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
163 | return img
164 |
165 | def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
166 | """Performs Gaussian blurring on the img by given kernel.
167 | The image can be a PIL Image or a Tensor, in which case it is expected
168 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
169 |
170 | Args:
171 | img (PIL Image or Tensor): Image to be blurred
172 | kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
173 | like ``(kx, ky)`` or a single integer for square kernels.
174 | In torchscript mode kernel_size as single int is not supported, use a tuple or
175 | list of length 1: ``[ksize, ]``.
176 | sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
177 | sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
178 | same sigma in both X/Y directions. If None, then it is computed using
179 | ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
180 | Default, None. In torchscript mode sigma as single float is
181 | not supported, use a tuple or list of length 1: ``[sigma, ]``.
182 |
183 | Returns:
184 | PIL Image or Tensor: Gaussian Blurred version of the image.
185 | """
186 | if not isinstance(kernel_size, (int, list, tuple)):
187 | raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size)))
188 | if isinstance(kernel_size, int):
189 | kernel_size = [kernel_size, kernel_size]
190 | if len(kernel_size) != 2:
191 | raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size)))
192 | for ksize in kernel_size:
193 | if ksize % 2 == 0 or ksize < 0:
194 | raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size))
195 |
196 | if sigma is None:
197 | sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
198 |
199 | if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
200 | raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma)))
201 | if isinstance(sigma, (int, float)):
202 | sigma = [float(sigma), float(sigma)]
203 | if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
204 | sigma = [sigma[0], sigma[0]]
205 | if len(sigma) != 2:
206 | raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma)))
207 | for s in sigma:
208 | if s <= 0.:
209 | raise ValueError('sigma should have positive values. Got {}'.format(sigma))
210 |
211 | t_img = img
212 | if not isinstance(img, torch.Tensor):
213 | if not _is_pil_image(img):
214 | raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img)))
215 |
216 | t_img = to_tensor(img)
217 |
218 | output = _gaussian_blur(t_img, kernel_size, sigma)
219 |
220 | if not isinstance(img, torch.Tensor):
221 | output = to_pil_image(output)
222 | return output
223 |
224 |
225 |
226 |
227 | # if __name__ == "__main__":
228 | # gaussian_blur = GaussianBlur(kernel_size=23)
229 |
--------------------------------------------------------------------------------
/augmentations/simsiam_aug.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as T
2 | from PIL import Image
3 | try:
4 | from torchvision.transforms import GaussianBlur
5 | except ImportError:
6 | from .gaussian_blur import GaussianBlur
7 | T.GaussianBlur = GaussianBlur
8 | import numpy as np   # used by the to_pil_image helper below
9 | import torch         # used by the to_pil_image helper below
10 | class SimSiamTransform():
11 | def __init__(self, image_size, mean_std, **aug_kwargs):
12 | p_blur = 0.5 if image_size > 32 else 0 # exclude cifar
13 | # self.not_aug_transform = T.Compose([T.ToTensor(), T.Normalize(*mean_std)])
14 | self.not_aug_transform = T.Compose([T.ToTensor()])
15 |
16 | random_crop = T.RandomCrop(image_size, padding=4) if aug_kwargs['cl_default'] else T.RandomResizedCrop(image_size, scale=(0.2, 1.0))
17 | self.transform = T.Compose([
18 | random_crop,
19 | T.RandomHorizontalFlip(),
20 | T.RandomApply([T.ColorJitter(0.4,0.4,0.4,0.1)], p=0.8),
21 | T.RandomGrayscale(p=0.2),
22 | T.RandomApply([T.GaussianBlur(kernel_size=image_size//20*2+1, sigma=(0.1, 2.0))], p=p_blur),
23 | T.ToTensor(),
24 | T.Normalize(*mean_std)
25 | ])
26 | def __call__(self, x):
27 | x1 = self.transform(x)
28 | x2 = self.transform(x)
29 | not_aug_x = self.not_aug_transform(x)
30 |         return x1, x2, not_aug_x  # two augmented views (for SimSiam) + the un-augmented image (for the buffer)
31 |
32 |
33 | def to_pil_image(pic, mode=None):
34 | """Convert a tensor or an ndarray to PIL Image.
35 |
36 | See :class:`~torchvision.transforms.ToPILImage` for more details.
37 |
38 | Args:
39 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
40 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
41 |
42 | .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
43 |
44 | Returns:
45 | PIL Image: Image converted to PIL Image.
46 | """
47 | if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
48 | raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
49 |
50 | elif isinstance(pic, torch.Tensor):
51 | if pic.ndimension() not in {2, 3}:
52 | raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))
53 |
54 | elif pic.ndimension() == 2:
55 | # if 2D image, add channel dimension (CHW)
56 | pic = pic.unsqueeze(0)
57 |
58 | elif isinstance(pic, np.ndarray):
59 | if pic.ndim not in {2, 3}:
60 | raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
61 |
62 | elif pic.ndim == 2:
63 | # if 2D image, add channel dimension (HWC)
64 | pic = np.expand_dims(pic, 2)
65 |
66 | npimg = pic
67 | if isinstance(pic, torch.Tensor):
68 | if pic.is_floating_point() and mode != 'F':
69 | pic = pic.mul(255).byte()
70 | npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))
71 |
72 | if not isinstance(npimg, np.ndarray):
73 | raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
74 | 'not {}'.format(type(npimg)))
75 |
76 | if npimg.shape[2] == 1:
77 | expected_mode = None
78 | npimg = npimg[:, :, 0]
79 | if npimg.dtype == np.uint8:
80 | expected_mode = 'L'
81 | elif npimg.dtype == np.int16:
82 | expected_mode = 'I;16'
83 | elif npimg.dtype == np.int32:
84 | expected_mode = 'I'
85 | elif npimg.dtype == np.float32:
86 | expected_mode = 'F'
87 | if mode is not None and mode != expected_mode:
88 | raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
89 | .format(mode, np.dtype, expected_mode))
90 | mode = expected_mode
91 |
92 | elif npimg.shape[2] == 2:
93 | permitted_2_channel_modes = ['LA']
94 | if mode is not None and mode not in permitted_2_channel_modes:
95 | raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes))
96 |
97 | if mode is None and npimg.dtype == np.uint8:
98 | mode = 'LA'
99 |
100 | elif npimg.shape[2] == 4:
101 | permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
102 | if mode is not None and mode not in permitted_4_channel_modes:
103 | raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
104 |
105 | if mode is None and npimg.dtype == np.uint8:
106 | mode = 'RGBA'
107 | else:
108 | permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
109 | if mode is not None and mode not in permitted_3_channel_modes:
110 | raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
111 | if mode is None and npimg.dtype == np.uint8:
112 | mode = 'RGB'
113 |
114 | if mode is None:
115 | raise TypeError('Input type {} is not supported'.format(npimg.dtype))
116 |
117 | return Image.fromarray(npimg, mode=mode)
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/configs/__init__.py
--------------------------------------------------------------------------------
/configs/cifar10/distil.yaml:
--------------------------------------------------------------------------------
1 | name: c10-experiment-hcl
2 | dataset:
3 | name: seq-cifar10
4 | image_size: 32
5 | num_workers: 4
6 |
7 | model:
8 | name: simsiam
9 | backbone: resnet18
10 | cl_model: distil
11 | proj_layers: 2
12 | buffer_size: 200
13 |
14 | train:
15 | optimizer:
16 | name: sgd
17 | weight_decay: 0.0005
18 | momentum: 0.9
19 | warmup_epochs: 10
20 | warmup_lr: 0
21 | base_lr: 0.03
22 | final_lr: 0
23 |   num_epochs: 200 # this parameter influences the lr decay
24 |   stop_at_epoch: 200 # must not exceed num_epochs
25 |   batch_size: 32
26 |   alpha: 3.0
27 | eval: # linear evaluation; set to False to turn off automatic evaluation after training
28 | optimizer:
29 | name: sgd
30 | weight_decay: 0
31 | momentum: 0.9
32 | warmup_lr: 0
33 | warmup_epochs: 0
34 | base_lr: 30
35 | final_lr: 0
36 | batch_size: 256
37 | num_epochs: 100
38 |
39 | logger:
40 | csv_log: True
41 | tensorboard: True
42 | matplotlib: True
43 |
44 | seed: null # None type for yaml file
45 | # two things might lead to stochastic behavior other than seed:
46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
47 | # (keep this in mind if you want to achieve 100% determinism)
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/configs/cifar10/distilbuf.yaml:
--------------------------------------------------------------------------------
1 | name: c10-experiment-hcl
2 | dataset:
3 | name: seq-cifar10
4 | image_size: 32
5 | num_workers: 4
6 |
7 | model:
8 | name: simsiam
9 | backbone: resnet18
10 | cl_model: distilbuf
11 | proj_layers: 2
12 | buffer_size: 200
13 |
14 | train:
15 | optimizer:
16 | name: sgd
17 | weight_decay: 0.0005
18 | momentum: 0.9
19 | warmup_epochs: 10
20 | warmup_lr: 0
21 | base_lr: 0.03
22 | final_lr: 0
23 |   num_epochs: 200 # this parameter influences the lr decay
24 |   stop_at_epoch: 200 # must not exceed num_epochs
25 |   batch_size: 32
26 |   alpha: 1.0
27 | eval: # linear evaluation; set to False to turn off automatic evaluation after training
28 | optimizer:
29 | name: sgd
30 | weight_decay: 0
31 | momentum: 0.9
32 | warmup_lr: 0
33 | warmup_epochs: 0
34 | base_lr: 30
35 | final_lr: 0
36 | batch_size: 256
37 | num_epochs: 100
38 |
39 | logger:
40 | csv_log: True
41 | tensorboard: True
42 | matplotlib: True
43 |
44 | seed: null # None type for yaml file
45 | # two things might lead to stochastic behavior other than seed:
46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
47 | # (keep this in mind if you want to achieve 100% determinism)
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/configs/cifar10/qdi.yaml:
--------------------------------------------------------------------------------
1 | name: simsiam-c10-experiment-resnet18
2 | dataset:
3 | name: seq-cifar10
4 | image_size: 32
5 | num_workers: 4
6 |
7 | model:
8 | name: simsiam
9 | backbone: resnet18
10 | cl_model: qdi
11 | proj_layers: 2
12 | buffer_size: 200
13 |
14 | train:
15 | optimizer:
16 | name: sgd
17 | weight_decay: 0.0005
18 | momentum: 0.9
19 | warmup_epochs: 10
20 | warmup_lr: 0
21 | base_lr: 0.03
22 | final_lr: 0
23 |   num_epochs: 200 # this parameter influences the lr decay
24 |   stop_at_epoch: 200 # must not exceed num_epochs
25 |   batch_size: 32
26 |   alpha: 1.0
27 |   di_lr: 0.005 # deep-inversion optimizer learning rate
28 |   di_var: 0.001 # image-variance (smoothness) prior weight
29 |   di_l2: 0. # l2 prior weight on the inverted images
30 |   di_feature: 0.1 # feature-statistics (batch-norm) loss weight
31 |   di_itrs: 500 # number of inversion iterations
32 | eval: # linear evaluation; set to False to turn off automatic evaluation after training
33 | optimizer:
34 | name: sgd
35 | weight_decay: 0
36 | momentum: 0.9
37 | warmup_lr: 0
38 | warmup_epochs: 0
39 | base_lr: 30
40 | final_lr: 0
41 | batch_size: 256
42 | num_epochs: 100
43 |
44 | logger:
45 | csv_log: True
46 | tensorboard: True
47 | matplotlib: True
48 |
49 | seed: null # None type for yaml file
50 | # two things might lead to stochastic behavior other than seed:
51 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
52 | # (keep this in mind if you want to achieve 100% determinism)
53 |
54 |
55 |
56 |
57 |
--------------------------------------------------------------------------------
/configs/cifar100/distil.yaml:
--------------------------------------------------------------------------------
1 | name: c100-experiment-hcl
2 | dataset:
3 | name: seq-cifar100
4 | image_size: 32
5 | num_workers: 4
6 |
7 | model:
8 | name: simsiam
9 | backbone: resnet18
10 | cl_model: distil
11 | proj_layers: 2
12 | buffer_size: 200
13 |
14 | train:
15 | optimizer:
16 | name: sgd
17 | weight_decay: 0.0005
18 | momentum: 0.9
19 | warmup_epochs: 10
20 | warmup_lr: 0
21 | base_lr: 0.03
22 | final_lr: 0
23 |   num_epochs: 200 # this parameter influences the lr decay
24 |   stop_at_epoch: 200 # must not exceed num_epochs
25 |   batch_size: 32
26 |   alpha: 3.0
27 | eval: # linear evaluation; set to False to turn off automatic evaluation after training
28 | optimizer:
29 | name: sgd
30 | weight_decay: 0
31 | momentum: 0.9
32 | warmup_lr: 0
33 | warmup_epochs: 0
34 | base_lr: 30
35 | final_lr: 0
36 | batch_size: 256
37 | num_epochs: 100
38 |
39 | logger:
40 | csv_log: True
41 | tensorboard: True
42 | matplotlib: True
43 |
44 | seed: null # None type for yaml file
45 | # two things might lead to stochastic behavior other than seed:
46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
47 | # (keep this in mind if you want to achieve 100% determinism)
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/configs/cifar100/distilbuf.yaml:
--------------------------------------------------------------------------------
1 | name: c100-experiment-hcl
2 | dataset:
3 | name: seq-cifar100
4 | image_size: 32
5 | num_workers: 4
6 |
7 | model:
8 | name: simsiam
9 | backbone: resnet18
10 | cl_model: distilbuf
11 | proj_layers: 2
12 | buffer_size: 200
13 |
14 | train:
15 | optimizer:
16 | name: sgd
17 | weight_decay: 0.0005
18 | momentum: 0.9
19 | warmup_epochs: 10
20 | warmup_lr: 0
21 | base_lr: 0.03
22 | final_lr: 0
23 |   num_epochs: 200 # this parameter influences the lr decay
24 |   stop_at_epoch: 200 # must not exceed num_epochs
25 |   batch_size: 32
26 |   alpha: 3.0
27 | eval: # linear evaluation; set to False to turn off automatic evaluation after training
28 | optimizer:
29 | name: sgd
30 | weight_decay: 0
31 | momentum: 0.9
32 | warmup_lr: 0
33 | warmup_epochs: 0
34 | base_lr: 30
35 | final_lr: 0
36 | batch_size: 256
37 | num_epochs: 100
38 |
39 | logger:
40 | csv_log: True
41 | tensorboard: True
42 | matplotlib: True
43 |
44 | seed: null # None type for yaml file
45 | # two things might lead to stochastic behavior other than seed:
46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
47 | # (keep this in mind if you want to achieve 100% determinism)
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/configs/cifar100/qdi.yaml:
--------------------------------------------------------------------------------
1 | name: simsiam-c100-experiment-resnet18
2 | dataset:
3 | name: seq-cifar100
4 | image_size: 32
5 | num_workers: 4
6 |
7 | model:
8 | name: simsiam
9 | backbone: resnet18
10 | cl_model: qdi
11 | proj_layers: 2
12 | buffer_size: 200
13 |
14 | train:
15 | optimizer:
16 | name: sgd
17 | weight_decay: 0.0005
18 | momentum: 0.9
19 | warmup_epochs: 10
20 | warmup_lr: 0
21 | base_lr: 0.03
22 | final_lr: 0
23 |   num_epochs: 200 # this parameter influences the lr decay
24 |   stop_at_epoch: 200 # must not exceed num_epochs
25 |   batch_size: 32
26 |   alpha: 3.0
27 |   di_var: 0.003 # image-variance (smoothness) prior weight
28 |   di_l2: 0.003 # l2 prior weight on the inverted images
29 |   di_feature: 0.2 # feature-statistics (batch-norm) loss weight
30 |   di_itrs: 500 # number of inversion iterations
31 |   di_lr: 0.03 # deep-inversion optimizer learning rate
32 | eval: # linear evaluation; set to False to turn off automatic evaluation after training
33 | optimizer:
34 | name: sgd
35 | weight_decay: 0
36 | momentum: 0.9
37 | warmup_lr: 0
38 | warmup_epochs: 0
39 | base_lr: 30
40 | final_lr: 0
41 | batch_size: 256
42 | num_epochs: 100
43 |
44 | logger:
45 | csv_log: True
46 | tensorboard: True
47 | matplotlib: True
48 |
49 | seed: null # None type for yaml file
50 | # two things might lead to stochastic behavior other than seed:
51 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
52 | # (keep this in mind if you want to achieve 100% determinism)
53 |
54 |
55 |
56 |
57 |
--------------------------------------------------------------------------------
/configs/tinyimg/distil.yaml:
--------------------------------------------------------------------------------
1 | name: tinyimagenet-experiment-hcl
2 | dataset:
3 | name: seq-tinyimg
4 | image_size: 64
5 | num_workers: 4
6 |
7 | model:
8 | name: simsiam
9 | backbone: resnet18
10 | cl_model: distil
11 | proj_layers: 2
12 | buffer_size: 200
13 |
14 | train:
15 | optimizer:
16 | name: sgd
17 | weight_decay: 0.0005
18 | momentum: 0.9
19 | warmup_epochs: 10
20 | warmup_lr: 0
21 | base_lr: 0.03
22 | final_lr: 0
23 |   num_epochs: 200 # this parameter influences the lr decay
24 |   stop_at_epoch: 200 # must not exceed num_epochs
25 |   batch_size: 32
26 |   alpha: 3.0
27 | eval: # linear evaluation; set to False to turn off automatic evaluation after training
28 | optimizer:
29 | name: sgd
30 | weight_decay: 0
31 | momentum: 0.9
32 | warmup_lr: 0
33 | warmup_epochs: 0
34 | base_lr: 30
35 | final_lr: 0
36 | batch_size: 256
37 | num_epochs: 100
38 |
39 | logger:
40 | csv_log: True
41 | tensorboard: True
42 | matplotlib: True
43 |
44 | seed: null # None type for yaml file
45 | # two things might lead to stochastic behavior other than seed:
46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
47 | # (keep this in mind if you want to achieve 100% determinism)
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/configs/tinyimg/distilbuf.yaml:
--------------------------------------------------------------------------------
1 | name: tinyimagenet-experiment-hcl
2 | dataset:
3 | name: seq-tinyimg
4 | image_size: 64
5 | num_workers: 4
6 |
7 | model:
8 | name: simsiam
9 | backbone: resnet18
10 | cl_model: distilbuf
11 | proj_layers: 2
12 | buffer_size: 200
13 |
14 | train:
15 | optimizer:
16 | name: sgd
17 | weight_decay: 0.0005
18 | momentum: 0.9
19 | warmup_epochs: 10
20 | warmup_lr: 0
21 | base_lr: 0.03
22 | final_lr: 0
23 |   num_epochs: 200 # this parameter influences the lr decay
24 |   stop_at_epoch: 200 # must not exceed num_epochs
25 |   batch_size: 32
26 |   alpha: 1.0
27 | eval: # linear evaluation; set to False to turn off automatic evaluation after training
28 | optimizer:
29 | name: sgd
30 | weight_decay: 0
31 | momentum: 0.9
32 | warmup_lr: 0
33 | warmup_epochs: 0
34 | base_lr: 30
35 | final_lr: 0
36 | batch_size: 256
37 | num_epochs: 100
38 |
39 | logger:
40 | csv_log: True
41 | tensorboard: True
42 | matplotlib: True
43 |
44 | seed: null # None type for yaml file
45 | # two things might lead to stochastic behavior other than seed:
46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
47 | # (keep this in mind if you want to achieve 100% determinism)
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/configs/tinyimg/qdi.yaml:
--------------------------------------------------------------------------------
1 | name: tinyimg-experiment-resnet18
2 | dataset:
3 | name: seq-tinyimg
4 | image_size: 64
5 | num_workers: 4
6 |
7 | model:
8 | name: simsiam
9 | backbone: resnet18
10 | cl_model: qdi
11 | proj_layers: 2
12 | buffer_size: 200
13 |
14 | train:
15 | optimizer:
16 | name: sgd
17 | weight_decay: 0.0005
18 | momentum: 0.9
19 | warmup_epochs: 10
20 | warmup_lr: 0
21 | base_lr: 0.03
22 | final_lr: 0
23 |   num_epochs: 200 # this parameter influences the lr decay
24 |   stop_at_epoch: 200 # must not exceed num_epochs
25 |   batch_size: 32
26 |   alpha: 1.0
27 |   di_var: 0.003 # image-variance (smoothness) prior weight
28 |   di_l2: 0.003 # l2 prior weight on the inverted images
29 |   di_feature: 0.2 # feature-statistics (batch-norm) loss weight
30 |   di_itrs: 500 # number of inversion iterations
31 |   di_lr: 0.03 # deep-inversion optimizer learning rate
32 | eval: # linear evaluation; set to False to turn off automatic evaluation after training
33 | optimizer:
34 | name: sgd
35 | weight_decay: 0
36 | momentum: 0.9
37 | warmup_lr: 0
38 | warmup_epochs: 0
39 | base_lr: 30
40 | final_lr: 0
41 | batch_size: 256
42 | num_epochs: 100
43 |
44 | logger:
45 | csv_log: True
46 | tensorboard: True
47 | matplotlib: True
48 |
49 | seed: null # None type for yaml file
50 | # two things might lead to stochastic behavior other than seed:
51 | # worker_init_fn from dataloader and torch.nn.functional.interpolate
52 | # (keep this in mind if you want to achieve 100% determinism)
53 |
54 |
55 |
56 |
57 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from datasets.seq_cifar10 import SequentialCIFAR10
7 | from datasets.seq_cifar100 import SequentialCIFAR100
8 | from datasets.seq_tinyimagenet import SequentialTinyImagenet
9 | from datasets.utils.continual_dataset import ContinualDataset
10 | from argparse import Namespace
11 | import torchvision
12 |
13 | NAMES = {
14 | SequentialCIFAR10.NAME: SequentialCIFAR10,
15 | SequentialCIFAR100.NAME: SequentialCIFAR100,
16 | SequentialTinyImagenet.NAME: SequentialTinyImagenet,
17 | }
18 |
19 | N_CLASSES = {'seq-cifar10': 10, 'seq-cifar100': 100, 'seq-tinyimg': 200}
20 | BACKBONES = {'seq-cifar10': ["lenet", "resnet18", "densenet", "senet", "regnet"],  # one backbone per task: the evolving architecture stream
21 | 'seq-cifar100': ["lenet","lenet", "alexnet", "alexnet", "vgg16", "vgg16", "inception", "inception", "resnet18", "resnet18", "resnext", "resnext", "densenet", "densenet", "senet", "senet", "regnet", "regnet", "regnet", "regnet"],
22 | 'seq-tinyimg': ["lenet", "lenet", "resnet18", "resnet18", "resnext", "resnext", "senet", "senet", "regnet", "regnet"],
23 | }
24 |
25 |
26 | def get_dataset(args: Namespace) -> ContinualDataset:
27 | """
28 | Creates and returns a continual dataset.
29 |     :param args: the arguments which contain the hyperparameters
30 | :return: the continual dataset
31 | """
32 | assert args.dataset_kwargs['dataset'] in NAMES.keys()
33 | return NAMES[args.dataset_kwargs['dataset']](args)
34 |
35 |
36 | def get_gcl_dataset(args: Namespace):
37 | """
38 | Creates and returns a GCL dataset.
39 |     :param args: the arguments which contain the hyperparameters
40 | :return: the continual dataset
41 | """
42 |     assert args.dataset in GCL_NAMES.keys()  # note: GCL_NAMES is not defined/imported in this file
43 |     return GCL_NAMES[args.dataset](args)
44 |
--------------------------------------------------------------------------------
/datasets/datasets_utils.py:
--------------------------------------------------------------------------------
1 | ##########################################
2 | # Code from https://github.com/joansj/hat
3 | ##########################################
4 |
5 | import os,sys
6 | import os.path
7 | import numpy as np
8 | import torch
9 | import torch.utils.data
10 | from torchvision import datasets,transforms
11 | from sklearn.utils import shuffle
12 | import urllib.request
13 | from PIL import Image
14 | import pickle
15 |
16 |
17 | class FashionMNIST(datasets.MNIST):
18 | """`Fashion MNIST `_ Dataset.
19 | """
20 | urls = [
21 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
22 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
23 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
24 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
25 | ]
26 |
27 |
--------------------------------------------------------------------------------
/datasets/random_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class RandomDataset(torch.utils.data.Dataset):
4 | def __init__(self, root=None, train=True, transform=None, target_transform=None):
5 | self.transform = transform
6 | self.target_transform = target_transform
7 |
8 | self.size = 1000
9 | def __getitem__(self, idx):
10 | if idx < self.size:
11 | return [torch.randn((3, 224, 224)), torch.randn((3, 224, 224))], [0,0,0]
12 |         else:
13 |             raise IndexError(f'index {idx} is out of range for a dataset of size {self.size}')
14 |
15 | def __len__(self):
16 | return self.size
17 |
--------------------------------------------------------------------------------
/datasets/seq_cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torchvision.datasets import CIFAR10
7 | import torchvision.transforms as transforms
8 | import torch.nn.functional as F
9 | from datasets.seq_tinyimagenet import base_path
10 | from PIL import Image
11 | from datasets.utils.validation import get_train_val
12 | from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
13 | from datasets.utils.continual_dataset import get_previous_train_loader
14 | from typing import Tuple
15 | from datasets.transforms.denormalization import DeNormalize
16 | import torch
17 | from augmentations import get_aug
18 | from PIL import Image
19 |
20 | class MyCIFAR10(CIFAR10):
21 | """
22 | Overrides the CIFAR10 dataset to change the getitem function.
23 | """
24 | def __init__(self, root, train=True, transform=None,
25 | target_transform=None, download=False) -> None:
26 | super(MyCIFAR10, self).__init__(root, train, transform, target_transform, download)
27 |
28 | def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]:
29 | """
30 | Gets the requested element from the dataset.
31 | :param index: index of the element to be returned
32 |         :returns: tuple: ((img, img1, not_aug_img), target) where target is the index of the target class.
33 | """
34 | img, target = self.data[index], self.targets[index]
35 | img = Image.fromarray(img, mode='RGB')
36 | original_img = img.copy()
37 |
38 |
39 | img, img1, not_aug_img = self.transform(original_img)
40 |
41 | if hasattr(self, 'logits'):
42 | return (img, img1, not_aug_img), target, self.logits[index]
43 |
44 | return (img, img1, not_aug_img), target
45 |
46 |
47 | class SequentialCIFAR10(ContinualDataset):
48 |
49 | NAME = 'seq-cifar10'
50 | SETTING = 'class-il'
51 | N_CLASSES_PER_TASK = 2
52 | N_TASKS = 5
53 |
54 | def get_data_loaders(self, args):
55 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]]
56 | transform = get_aug(train=True, mean_std=cifar_norm, **args.aug_kwargs)
57 | test_transform = get_aug(train=False, train_classifier=False, mean_std=cifar_norm, **args.aug_kwargs)
58 |
59 | if args.server:
60 | train_dataset = MyCIFAR10('/cifar10-pytorch', train=True,
61 | download=False, transform=transform)
62 | memory_dataset = MyCIFAR10('/cifar10-pytorch', train=True,
63 | download=False, transform=test_transform)
64 | else:
65 | train_dataset = MyCIFAR10(base_path() + 'CIFAR10', train=True,
66 | download=True, transform=transform)
67 | memory_dataset = MyCIFAR10(base_path() + 'CIFAR10', train=True,
68 | download=True, transform=test_transform)
69 |
70 | if self.args.validation:
71 | train_dataset, test_dataset = get_train_val(train_dataset, test_transform, self.NAME)
72 | memory_dataset, _ = get_train_val(memory_dataset, test_transform, self.NAME)
73 | else:
74 | if args.server:
75 | test_dataset = CIFAR10('/cifar10-pytorch',train=False,
76 | download=False, transform=test_transform)
77 | else:
78 | test_dataset = CIFAR10(base_path() + 'CIFAR10',train=False,
79 | download=True, transform=test_transform)
80 |
81 | train, memory, test = store_masked_loaders(train_dataset, test_dataset, memory_dataset, self)
82 | return train, memory, test
83 |
84 |
85 | def get_transform(self, args):
86 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]]
87 | if args.cl_default:
88 | transform = transforms.Compose(
89 | [transforms.ToPILImage(),
90 | transforms.RandomCrop(32, padding=4),
91 | transforms.RandomHorizontalFlip(),
92 | transforms.ToTensor(),
93 | transforms.Normalize(*cifar_norm)
94 | ])
95 | else:
96 | transform = transforms.Compose(
97 | [transforms.ToPILImage(),
98 | transforms.RandomResizedCrop(32, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC),
99 | transforms.RandomHorizontalFlip(),
100 | transforms.ToTensor(),
101 | transforms.Normalize(*cifar_norm)
102 | ])
103 |
104 | return transform
105 |
106 | def not_aug_dataloader(self, batch_size):
107 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]]
108 | transform = transforms.Compose([transforms.ToTensor(),
109 | transforms.Normalize(*cifar_norm)])
110 |
111 | train_dataset = CIFAR10(base_path() + 'CIFAR10', train=True,
112 | download=True, transform=transform)
113 | train_loader = get_previous_train_loader(train_dataset, batch_size, self)
114 |
115 | return train_loader
116 |
--------------------------------------------------------------------------------
/datasets/seq_cifar100.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from torchvision.datasets import CIFAR100
7 | import torchvision.transforms as transforms
8 | import torch.nn.functional as F
9 | from datasets.seq_tinyimagenet import base_path
10 | from PIL import Image
11 | from datasets.utils.validation import get_train_val
12 | from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
13 | from datasets.utils.continual_dataset import get_previous_train_loader
14 | from typing import Tuple
15 | from datasets.transforms.denormalization import DeNormalize
16 | import torch
17 | from augmentations import get_aug
18 | from PIL import Image
19 |
20 |
21 | class MyCIFAR100(CIFAR100):
22 | """
23 |     Overrides the CIFAR100 dataset to change the getitem function.
24 | """
25 | def __init__(self, root, train=True, transform=None,
26 | target_transform=None, download=False) -> None:
27 | super(MyCIFAR100, self).__init__(root, train, transform, target_transform, download)
28 |
29 | def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]:
30 | """
31 | Gets the requested element from the dataset.
32 | :param index: index of the element to be returned
33 |         :returns: tuple: ((img, img1, not_aug_img), target) where target is the index of the target class.
34 | """
35 | img, target = self.data[index], self.targets[index]
36 | img = Image.fromarray(img, mode='RGB')
37 | original_img = img.copy()
38 |
39 | img, img1, not_aug_img = self.transform(original_img)
40 |
41 | if hasattr(self, 'logits'):
42 | return (img, img1, not_aug_img), target, self.logits[index]
43 |
44 | return (img, img1, not_aug_img), target
45 |
46 |
47 | class SequentialCIFAR100(ContinualDataset):
48 |
49 | NAME = 'seq-cifar100'
50 | SETTING = 'class-il'
51 | N_CLASSES_PER_TASK = 5
52 | N_TASKS = 20
53 |
54 | def get_data_loaders(self, args):
55 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]]
56 | transform = get_aug(train=True, mean_std=cifar_norm, **args.aug_kwargs)
57 | test_transform = get_aug(train=False, train_classifier=False, mean_std=cifar_norm, **args.aug_kwargs)
58 |
59 | if args.server:
60 | train_dataset = MyCIFAR100('/cifar100_data', train=True,
61 | download=False, transform=transform)
62 | memory_dataset = CIFAR100('/cifar100_data', train=True,
63 | download=False, transform=test_transform)
64 | else:
65 | train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True,
66 | download=True, transform=transform)
67 | memory_dataset = CIFAR100(base_path() + 'CIFAR100', train=True,
68 | download=True, transform=test_transform)
69 |
70 | if self.args.validation:
71 | train_dataset, test_dataset = get_train_val(train_dataset, test_transform, self.NAME)
72 | memory_dataset, _ = get_train_val(memory_dataset, test_transform, self.NAME)
73 | else:
74 | if args.server:
75 | test_dataset = CIFAR100('/cifar100_data', train=False,
76 | download=False, transform=test_transform)
77 | else:
78 | test_dataset = CIFAR100(base_path() + 'CIFAR100', train=False,
79 | download=True, transform=test_transform)
80 |
81 | train, memory, test = store_masked_loaders(train_dataset, test_dataset, memory_dataset, self)
82 | return train, memory, test
83 |
84 | def get_transform(self, args):
85 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]]
86 | if args.cl_default:
87 | transform = transforms.Compose(
88 | [transforms.ToPILImage(),
89 | transforms.RandomCrop(32, padding=4),
90 | transforms.RandomHorizontalFlip(),
91 | transforms.ToTensor(),
92 | transforms.Normalize(*cifar_norm)
93 | ])
94 | else:
95 | transform = transforms.Compose(
96 | [transforms.ToPILImage(),
97 | transforms.RandomResizedCrop(32, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC),
98 | transforms.RandomHorizontalFlip(),
99 | transforms.ToTensor(),
100 | transforms.Normalize(*cifar_norm)
101 | ])
102 |
103 | return transform
104 |
105 |
106 |
107 | def not_aug_dataloader(self, batch_size):
108 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]]
109 | transform = transforms.Compose([transforms.ToTensor(),
110 | transforms.Normalize(*cifar_norm)])
111 |
112 | train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True,
113 | download=True, transform=transform)
114 | train_loader = get_previous_train_loader(train_dataset, batch_size, self)
115 |
116 | return train_loader
117 |
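118 | # Usage sketch (illustrative; `args` is assumed to carry the fields consumed
119 | # above, e.g. server, validation, aug_kwargs, train.batch_size):
120 | #
121 | #     dataset = SequentialCIFAR100(args)
122 | #     for t in range(SequentialCIFAR100.N_TASKS):
123 | #         train, memory, test = dataset.get_data_loaders(args)
124 | #         # each call advances dataset.i by N_CLASSES_PER_TASK = 5 classes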
--------------------------------------------------------------------------------
/datasets/seq_tinyimagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torchvision.transforms as transforms
8 | from torch.utils.data import Dataset
9 | import torch.nn.functional as F
10 | from utils.conf import base_path
11 | from PIL import Image
12 | import os
13 | from datasets.utils.validation import get_train_val
14 | from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
15 | from datasets.utils.continual_dataset import get_previous_train_loader
16 | from datasets.transforms.denormalization import DeNormalize
17 | from augmentations import get_aug
18 |
19 |
20 | class TinyImagenet(Dataset):
21 | """
22 | Defines the Tiny Imagenet dataset, mirroring the other PyTorch datasets.
23 | """
24 | def __init__(self, root: str, train: bool=True, transform: transforms=None,
25 | target_transform: transforms=None, download: bool=False) -> None:
26 | self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
27 | self.root = root
28 | self.train = train
29 | self.transform = transform
30 | self.target_transform = target_transform
31 | self.download = download
32 |
33 | if download:
34 | if os.path.isdir(root) and len(os.listdir(root)) > 0:
35 | print('Download not needed, files already on disk.')
36 | else:
37 | import gdown
38 | import zipfile
39 | # https://drive.google.com/file/d/1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj/view
40 | url = 'https://drive.google.com/uc?id=1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj'
41 | if not os.path.exists(root): os.makedirs(root)
42 | # download the archive to an explicit file path (filename here is arbitrary), then extract it
43 | zip_path = gdown.download(url, os.path.join(root, 'tiny-imagenet-processed.zip'), quiet=False, fuzzy=True)
44 | with zipfile.ZipFile(zip_path, 'r') as f:
45 | f.extractall(path=root)
46 |
47 | self.data = []
48 | for num in range(20):
49 | self.data.append(np.load(os.path.join(
50 | root, 'processed/x_%s_%02d.npy' %
51 | ('train' if self.train else 'val', num+1))))
52 | self.data = np.concatenate(self.data)
53 |
54 | self.targets = []
55 | for num in range(20):
56 | self.targets.append(np.load(os.path.join(
57 | root, 'processed/y_%s_%02d.npy' %
58 | ('train' if self.train else 'val', num+1))))
59 | self.targets = np.concatenate(self.targets)
60 |
61 | def __len__(self):
62 | return len(self.data)
63 |
64 | def __getitem__(self, index):
65 | img, target = self.data[index], self.targets[index]
66 |
67 | # doing this so that it is consistent with all other datasets
68 | # to return a PIL Image
69 | img = Image.fromarray(np.uint8(255 * img))
70 | original_img = img.copy()
71 |
72 | if self.transform is not None:
73 | img = self.transform(img)
74 |
75 | if self.target_transform is not None:
76 | target = self.target_transform(target)
77 |
78 | if hasattr(self, 'logits'):
79 | return img, target, original_img, self.logits[index]
80 |
81 | return img, target
82 |
83 |
84 | class SequentialTinyImagenet(ContinualDataset):
85 |
86 | NAME = 'seq-tinyimg'
87 | SETTING = 'class-il'
88 | N_CLASSES_PER_TASK = 20
89 | N_TASKS = 10
90 | TRANSFORM = transforms.Compose(
91 | [transforms.RandomCrop(64, padding=4),
92 | transforms.RandomHorizontalFlip(),
93 | transforms.ToTensor(),
94 | transforms.Normalize((0.4802, 0.4480, 0.3975),
95 | (0.2770, 0.2691, 0.2821))])
96 |
97 | def get_data_loaders(self, args):
98 | imagenet_norm = [[0.4802, 0.4480, 0.3975], [0.2770, 0.2691, 0.2821]]
99 | transform = get_aug(train=True, mean_std=imagenet_norm, **args.aug_kwargs)
100 | test_transform = get_aug(train=False, train_classifier=False, mean_std=imagenet_norm, **args.aug_kwargs)
101 |
102 | if args.server:
103 | train_dataset = TinyImagenet('/tinyimg_data', train=True,
104 | download=False, transform=transform)
105 | memory_dataset = TinyImagenet('/tinyimg_data', train=True,
106 | download=False, transform=test_transform)
107 | else:
108 | train_dataset = TinyImagenet(base_path() + 'TINYIMG',
109 | train=True, download=True, transform=transform)
110 |
111 | memory_dataset = TinyImagenet(base_path() + 'TINYIMG',
112 | train=True, download=True, transform=test_transform)
113 | if self.args.validation:
114 | train_dataset, test_dataset = get_train_val(train_dataset,
115 | test_transform, self.NAME)
116 | memory_dataset, _ = get_train_val(memory_dataset, test_transform, self.NAME)
117 | else:
118 | if args.server:
119 | test_dataset = TinyImagenet('/tinyimg_data', train=False,
120 | download=False, transform=test_transform)
121 | else:
122 | test_dataset = TinyImagenet(base_path() + 'TINYIMG',
123 | train=False, download=True, transform=test_transform)
124 |
125 | train, memory, test = store_masked_loaders(train_dataset, test_dataset, memory_dataset, self)
126 | return train, memory, test
127 |
128 | def get_transform(self, args):
129 | imagenet_norm = [[0.4802, 0.4480, 0.3975], [0.2770, 0.2691, 0.2821]]
130 | if args.cl_default:
131 | transform = transforms.Compose(
132 | [transforms.ToPILImage(),
133 | transforms.RandomCrop(64, padding=4),
134 | transforms.RandomHorizontalFlip(),
135 | transforms.ToTensor(),
136 | transforms.Normalize(*imagenet_norm)
137 | ])
138 | else:
139 | transform = transforms.Compose(
140 | [transforms.ToPILImage(),
141 | transforms.RandomResizedCrop(64, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC),
142 | transforms.RandomHorizontalFlip(),
143 | transforms.ToTensor(),
144 | transforms.Normalize(*imagenet_norm)
145 | ])
146 |
147 | return transform
148 |
149 | def not_aug_dataloader(self, batch_size):
150 | imagenet_norm = [[0.4802, 0.4480, 0.3975], [0.2770, 0.2691, 0.2821]]
151 | transform = transforms.Compose([transforms.ToTensor(),
152 | transforms.Normalize(*imagenet_norm)])
153 |
154 | train_dataset = TinyImagenet(base_path() + 'TINYIMG',
155 | train=True, download=True, transform=transform)
156 | train_loader = get_previous_train_loader(train_dataset, batch_size, self)
157 |
158 | return train_loader
159 |
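160 | # Data-layout note (inferred from the loader above): the processed split is
161 | # expected under root/processed/ as 20 shards per split, x_{train|val}_01.npy
162 | # through x_{train|val}_20.npy plus matching y_*.npy label files, with images
163 | # stored as floats in [0, 1] (hence the 255 * img conversion in __getitem__).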
--------------------------------------------------------------------------------
/datasets/test/seq-cifar10.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/test/seq-cifar10.pt
--------------------------------------------------------------------------------
/datasets/test/seq-cifar100.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/test/seq-cifar100.pt
--------------------------------------------------------------------------------
/datasets/test/seq-domainnet.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/test/seq-domainnet.pt
--------------------------------------------------------------------------------
/datasets/test/seq-tinyimg.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/test/seq-tinyimg.pt
--------------------------------------------------------------------------------
/datasets/transforms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/transforms/__init__.py
--------------------------------------------------------------------------------
/datasets/transforms/denormalization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 |
7 | class DeNormalize(object):
8 | def __init__(self, mean, std):
9 | self.mean = mean
10 | self.std = std
11 |
12 | def __call__(self, tensor):
13 | """
14 | Args:
15 | tensor (Tensor): normalized tensor image of size (C, H, W).
16 | Returns:
17 | Tensor: denormalized image (modified in place).
18 | """
19 | for t, m, s in zip(tensor, self.mean, self.std):
20 | t.mul_(s).add_(m)
21 | return tensor
22 |
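23 | # Usage sketch: DeNormalize is the channel-wise inverse of torchvision's
24 | # Normalize. Note that __call__ mutates the tensor in place, so clone first
25 | # if the normalized tensor is still needed:
26 | #
27 | #     mean, std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]
28 | #     restored = DeNormalize(mean, std)(img.clone())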
--------------------------------------------------------------------------------
/datasets/transforms/permutation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 |
8 |
9 | class Permutation(object):
10 | """
11 | Defines a fixed permutation for a numpy array.
12 | """
13 | def __init__(self) -> None:
14 | """
15 | Initializes the permutation.
16 | """
17 | self.perm = None
18 |
19 | def __call__(self, sample: np.ndarray) -> np.ndarray:
20 | """
21 | Randomly defines the permutation and applies the transformation.
22 | :param sample: image to be permuted
23 | :return: permuted image
24 | """
25 | old_shape = sample.shape
26 | if self.perm is None:
27 | self.perm = np.random.permutation(len(sample.flatten()))
28 |
29 | return sample.flatten()[self.perm].reshape(old_shape)
30 |
31 |
32 | class FixedPermutation(object):
33 | """
34 | Defines a fixed permutation (given the seed) for a numpy array.
35 | """
36 | def __init__(self, seed: int) -> None:
37 | """
38 | Defines the seed.
39 | :param seed: seed of the permutation
40 | """
41 | self.perm = None
42 | self.seed = seed
43 |
44 | def __call__(self, sample: np.ndarray) -> np.ndarray:
45 | """
46 | Defines the permutation and applies the transformation.
47 | :param sample: image to be permuted
48 | :return: permuted image
49 | """
50 | old_shape = sample.shape
51 | if self.perm is None:
52 | np.random.seed(self.seed)
53 | self.perm = np.random.permutation(len(sample.flatten()))
54 |
55 | return sample.flatten()[self.perm].reshape(old_shape)
56 |
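57 | # Usage sketch: both classes sample their permutation lazily on the first
58 | # call and reuse it afterwards, so every image in a stream is shuffled
59 | # identically:
60 | #
61 | #     import numpy as np
62 | #     p = FixedPermutation(seed=0)
63 | #     a = p(np.arange(12).reshape(3, 4))
64 | #     b = p(np.arange(12).reshape(3, 4))  # self.perm is reused, so (a == b).all()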
--------------------------------------------------------------------------------
/datasets/transforms/rotation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torchvision.transforms.functional as F
8 |
9 |
10 | class Rotation(object):
11 | """
12 | Defines a fixed rotation for a numpy array.
13 | """
14 |
15 | def __init__(self, deg_min: int = 0, deg_max: int = 180) -> None:
16 | """
17 | Initializes the rotation with a random angle.
18 | :param deg_min: lower extreme of the possible random angle
19 | :param deg_max: upper extreme of the possible random angle
20 | """
21 | self.deg_min = deg_min
22 | self.deg_max = deg_max
23 | self.degrees = np.random.uniform(self.deg_min, self.deg_max)
24 |
25 | def __call__(self, x: np.ndarray) -> np.ndarray:
26 | """
27 | Applies the rotation.
28 | :param x: image to be rotated
29 | :return: rotated image
30 | """
31 | return F.rotate(x, self.degrees)
32 |
33 |
34 | class FixedRotation(object):
35 | """
36 | Defines a fixed rotation for a numpy array.
37 | """
38 |
39 | def __init__(self, seed: int, deg_min: int = 0, deg_max: int = 180) -> None:
40 | """
41 | Initializes the rotation with a random angle.
42 | :param seed: seed of the rotation
43 | :param deg_min: lower extreme of the possible random angle
44 | :param deg_max: upper extreme of the possible random angle
45 | """
46 | self.seed = seed
47 | self.deg_min = deg_min
48 | self.deg_max = deg_max
49 |
50 | np.random.seed(seed)
51 | self.degrees = np.random.uniform(self.deg_min, self.deg_max)
52 |
53 | def __call__(self, x: np.ndarray) -> np.ndarray:
54 | """
55 | Applies the rotation.
56 | :param x: image to be rotated
57 | :return: rotated image
58 | """
59 | return F.rotate(x, self.degrees)
60 |
61 |
62 | class IncrementalRotation(object):
63 | """
64 | Defines an incremental rotation for a numpy array.
65 | """
66 |
67 | def __init__(self, init_deg: int = 0, increase_per_iteration: float = 0.006) -> None:
68 | """
69 | Defines the initial angle as well as the increase for each rotation
70 | :param init_deg: initial rotation angle in degrees
71 | :param increase_per_iteration: additional degrees applied at each call
72 | """
73 | self.increase_per_iteration = increase_per_iteration
74 | self.iteration = 0
75 | self.degrees = init_deg
76 |
77 | def __call__(self, x: np.ndarray) -> np.ndarray:
78 | """
79 | Applies the rotation.
80 | :param x: image to be rotated
81 | :return: rotated image
82 | """
83 | degs = (self.iteration * self.increase_per_iteration + self.degrees) % 360
84 | self.iteration += 1
85 | return F.rotate(x, degs)
86 |
87 | def set_iteration(self, x: int) -> None:
88 | """
89 | Set the iteration to a given integer
90 | :param x: iteration index
91 | """
92 | self.iteration = x
93 |
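94 | # Schedule sketch for IncrementalRotation: call i applies an angle of
95 | # (i * increase_per_iteration + init_deg) mod 360, so the default
96 | # 0.006 deg/iteration completes a full revolution every 60000 calls:
97 | #
98 | #     rot = IncrementalRotation()
99 | #     rot.set_iteration(50000)  # the next call rotates by 300 degrees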
--------------------------------------------------------------------------------
/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/utils/__init__.py
--------------------------------------------------------------------------------
/datasets/utils/continual_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from abc import abstractmethod
7 | from argparse import Namespace
8 | from torch import nn as nn
9 | from torchvision.transforms import transforms
10 | from torch.utils.data import DataLoader
11 | from typing import Tuple
12 | from torchvision import datasets
13 | import numpy as np
14 |
15 |
16 | class ContinualDataset:
17 | """
18 | Continual learning evaluation setting.
19 | """
20 | NAME = None
21 | SETTING = None
22 | N_CLASSES_PER_TASK = None
23 | N_TASKS = None
24 | TRANSFORM = None
25 |
26 | def __init__(self, args: Namespace) -> None:
27 | """
28 | Initializes the train and test lists of dataloaders.
29 | :param args: the arguments which contains the hyperparameters
30 | """
31 | self.train_loader = None
32 | self.test_loaders = []
33 | self.memory_loaders = []
34 | self.train_loaders = []
35 | self.i = 0
36 | self.args = args
37 |
38 | @abstractmethod
39 | def get_data_loaders(self, args) -> Tuple[DataLoader, DataLoader, DataLoader]:
40 | """
41 | Creates and returns the training, memory and test loaders for the current task.
42 | The current training loader and all test loaders are stored in self.
43 | :return: the current training, memory and test loaders
44 | """
45 | pass
46 |
47 | @abstractmethod
48 | def not_aug_dataloader(self, batch_size: int) -> DataLoader:
49 | """
50 | Returns the dataloader of the current task,
51 | not applying data augmentation.
52 | :param batch_size: the batch size of the loader
53 | :return: the current training loader
54 | """
55 | pass
56 |
57 | @staticmethod
58 | @abstractmethod
59 | def get_backbone() -> nn.Module:
60 | """
61 | Returns the backbone to be used for the current dataset.
62 | """
63 | pass
64 |
65 | @staticmethod
66 | @abstractmethod
67 | def get_transform() -> transforms:
68 | """
69 | Returns the transform to be used for the current dataset.
70 | """
71 | pass
72 |
73 | @staticmethod
74 | @abstractmethod
75 | def get_loss() -> nn.functional:
76 | """
77 | Returns the loss to be used for the current dataset.
78 | """
79 | pass
80 |
81 | @staticmethod
82 | @abstractmethod
83 | def get_normalization_transform() -> transforms:
84 | """
85 | Returns the transform used for normalizing the current dataset.
86 | """
87 | pass
88 |
89 | @staticmethod
90 | @abstractmethod
91 | def get_denormalization_transform() -> transforms:
92 | """
93 | Returns the transform used for denormalizing the current dataset.
94 | """
95 | pass
96 |
97 |
98 | def store_masked_loaders(train_dataset: datasets, test_dataset: datasets, memory_dataset: datasets,
99 | setting: ContinualDataset) -> Tuple[DataLoader, DataLoader, DataLoader]:
100 | """
101 | Divides the dataset into tasks.
102 | :param train_dataset: train dataset
103 | :param test_dataset: test dataset
104 | :param setting: continual learning setting
105 | :return: train, memory and test loaders
106 | """
107 | train_mask = np.logical_and(np.array(train_dataset.targets) >= setting.i,
108 | np.array(train_dataset.targets) < setting.i + setting.N_CLASSES_PER_TASK)
109 | test_mask = np.logical_and(np.array(test_dataset.targets) >= setting.i,
110 | np.array(test_dataset.targets) < setting.i + setting.N_CLASSES_PER_TASK)
111 |
112 | train_dataset.data = train_dataset.data[train_mask]
113 | test_dataset.data = test_dataset.data[test_mask]
114 |
115 | train_dataset.targets = np.array(train_dataset.targets)[train_mask]
116 | test_dataset.targets = np.array(test_dataset.targets)[test_mask]
117 |
118 | memory_dataset.data = memory_dataset.data[train_mask]
119 | memory_dataset.targets = np.array(memory_dataset.targets)[train_mask]
120 |
121 | train_loader = DataLoader(train_dataset,
122 | batch_size=setting.args.train.batch_size, shuffle=True, num_workers=4, pin_memory=True)
123 | test_loader = DataLoader(test_dataset,
124 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4, pin_memory=True)
125 | memory_loader = DataLoader(memory_dataset,
126 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4)
127 |
128 | setting.test_loaders.append(test_loader)
129 | setting.train_loaders.append(train_loader)
130 | setting.memory_loaders.append(memory_loader)
131 | setting.train_loader = train_loader
132 |
133 | setting.i += setting.N_CLASSES_PER_TASK
134 | return train_loader, memory_loader, test_loader
135 |
136 |
137 | def store_masked_label_loaders(train_dataset: datasets, test_dataset: datasets, memory_dataset: datasets,
138 | setting: ContinualDataset) -> Tuple[DataLoader, DataLoader, DataLoader]:
139 | """
140 | Divides the dataset into tasks.
141 | :param train_dataset: train dataset
142 | :param test_dataset: test dataset
143 | :param setting: continual learning setting
144 | :return: train and test loaders
145 | """
146 | train_mask = np.logical_and(np.array(train_dataset.labels) >= setting.i,
147 | np.array(train_dataset.labels) < setting.i + setting.N_CLASSES_PER_TASK)
148 | test_mask = np.logical_and(np.array(test_dataset.labels) >= setting.i,
149 | np.array(test_dataset.labels) < setting.i + setting.N_CLASSES_PER_TASK)
150 |
151 | train_dataset.data = train_dataset.data[train_mask]
152 | test_dataset.data = test_dataset.data[test_mask]
153 |
154 | train_dataset.targets = np.array(train_dataset.labels)[train_mask]
155 | test_dataset.targets = np.array(test_dataset.labels)[test_mask]
156 |
157 | memory_dataset.data = memory_dataset.data[train_mask]
158 | memory_dataset.targets = np.array(memory_dataset.labels)[train_mask]
159 |
160 | train_loader = DataLoader(train_dataset,
161 | batch_size=setting.args.train.batch_size, shuffle=True, num_workers=4)
162 | test_loader = DataLoader(test_dataset,
163 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4)
164 | memory_loader = DataLoader(memory_dataset,
165 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4)
166 |
167 | setting.test_loaders.append(test_loader)
168 | setting.train_loaders.append(train_loader)
169 | setting.memory_loaders.append(memory_loader)
170 | setting.train_loader = train_loader
171 |
172 | setting.i += setting.N_CLASSES_PER_TASK
173 | return train_loader, memory_loader, test_loader
174 |
175 | def store_domain_loaders(train_dataset: datasets, test_dataset: datasets, memory_dataset: datasets,
176 | setting: ContinualDataset) -> Tuple[DataLoader, DataLoader, DataLoader]:
177 | """
178 | Divides the dataset into tasks.
179 | :param train_dataset: train dataset
180 | :param test_dataset: test dataset
181 | :param setting: continual learning setting
183 | :return: train, memory and test loaders
183 | """
184 | train_loader = DataLoader(train_dataset,
185 | batch_size=setting.args.train.batch_size, shuffle=True, num_workers=4, pin_memory=True)
186 | test_loader = DataLoader(test_dataset,
187 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4, pin_memory=True)
188 | memory_loader = DataLoader(memory_dataset,
189 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4)
190 |
191 | setting.test_loaders.append(test_loader)
192 | setting.train_loaders.append(train_loader)
193 | setting.memory_loaders.append(memory_loader)
194 | setting.train_loader = train_loader
195 |
196 | # setting.i += setting.N_CLASSES_PER_TASK
197 | return train_loader, memory_loader, test_loader
198 |
199 |
200 |
201 |
202 | def get_previous_train_loader(train_dataset: datasets, batch_size: int,
203 | setting: ContinualDataset) -> DataLoader:
204 | """
205 | Creates a dataloader for the previous task.
206 | :param train_dataset: the entire training set
207 | :param batch_size: the desired batch size
208 | :param setting: the continual dataset at hand
209 | :return: a dataloader
210 | """
211 | train_mask = np.logical_and(np.array(train_dataset.targets) >=
212 | setting.i - setting.N_CLASSES_PER_TASK,
213 | np.array(train_dataset.targets) < setting.i)
214 |
215 | train_dataset.data = train_dataset.data[train_mask]
216 | train_dataset.targets = np.array(train_dataset.targets)[train_mask]
217 |
218 | return DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
219 |
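220 | # Masking sketch: store_masked_loaders keeps only the samples whose targets
221 | # fall in [setting.i, setting.i + N_CLASSES_PER_TASK) and then advances
222 | # setting.i, so for seq-cifar100 (5 classes per task) successive calls see
223 | # the label windows [0, 5), [5, 10), ..., [95, 100).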
--------------------------------------------------------------------------------
/datasets/utils/validation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from PIL import Image
8 | import numpy as np
9 | import os
10 | from utils import create_if_not_exists
11 | import torchvision.transforms.transforms as transforms
12 | from torchvision import datasets
13 |
14 |
15 | class ValidationDataset(torch.utils.data.Dataset):
16 | def __init__(self, data: torch.Tensor, targets: np.ndarray,
17 | transform: transforms=None, target_transform: transforms=None) -> None:
18 | self.data = data
19 | self.targets = targets
20 | self.transform = transform
21 | self.target_transform = target_transform
22 |
23 | def __len__(self):
24 | return self.data.shape[0]
25 |
26 | def __getitem__(self, index):
27 | img, target = self.data[index], self.targets[index]
28 |
29 | # doing this so that it is consistent with all other datasets
30 | # to return a PIL Image
31 | if isinstance(img, np.ndarray):
32 | if np.max(img) < 2:
33 | img = Image.fromarray(np.uint8(img * 255))
34 | else:
35 | img = Image.fromarray(img)
36 | else:
37 | img = Image.fromarray(img.numpy())
38 |
39 | if self.transform is not None:
40 | img = self.transform(img)
41 |
42 | if self.target_transform is not None:
43 | target = self.target_transform(target)
44 |
45 | return img, target
46 |
47 | def get_train_val(train: datasets, test_transform: transforms,
48 | dataset: str, val_perc: float=0.1):
49 | """
50 | Extracts a fraction val_perc of the training set as the validation set.
51 | :param train: training dataset
52 | :param test_transform: transformation of the test dataset
53 | :param dataset: dataset name
54 | :param val_perc: fraction of the training set to be extracted
55 | :return: the training set and the validation set
56 | """
57 | dataset_length = train.data.shape[0]
58 | directory = 'datasets/val_permutations/'
59 | create_if_not_exists(directory)
60 | file_name = dataset + '.pt'
61 | if os.path.exists(directory + file_name):
62 | perm = torch.load(directory + file_name)
63 | else:
64 | perm = torch.randperm(dataset_length)
65 | torch.save(perm, directory + file_name)
66 | train.data = train.data[perm]
67 | train.targets = np.array(train.targets)[perm]
68 | test_dataset = ValidationDataset(train.data[:int(val_perc * dataset_length)],
69 | train.targets[:int(val_perc * dataset_length)],
70 | transform=test_transform)
71 | train.data = train.data[int(val_perc * dataset_length):]
72 | train.targets = train.targets[int(val_perc * dataset_length):]
73 |
74 | return train, test_dataset
75 |
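76 | # Split sketch: the shuffling permutation is cached at
77 | # datasets/val_permutations/<dataset>.pt, so repeated runs reproduce the same
78 | # split. With val_perc=0.1 on a 50000-image training set, the first 5000
79 | # permuted examples become the validation set and the other 45000 stay in train.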
--------------------------------------------------------------------------------
/linear_eval_alltasks.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluation script
3 |
4 | Originated from https://github.com/divyam3897/UCL/blob/main/linear_eval_alltasks.py
5 |
6 | Hacked together by / Copyright 2023 Divyam Madaan (https://github.com/divyam3897)
7 | """
8 | import os
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torchvision
13 | from tqdm import tqdm
14 | from arguments import get_args
15 | from augmentations import get_aug
16 | from models import get_model, get_backbone
17 | from tools import AverageMeter, knn_monitor
18 | from datasets import get_dataset
19 | from models.optimizers import get_optimizer, LR_Scheduler
20 | from utils.loggers import *
21 | from utils.metrics import forgetting, mask_classes
22 | from typing import Tuple
23 | import numpy as np
24 | def evaluate_single(model, dataset, test_loader, memory_loader, device, k, last=False) -> float:
25 |     # single-task kNN evaluation; relies on the module-level `args` assigned
26 |     # in the __main__ block below. `last` is unused but kept for signature
27 |     # parity with evaluate().
28 | knn_acc, knn_acc_mask = knn_monitor(model.net.module.backbone, dataset, memory_loader, test_loader, device, args.cl_default, task_id=k, k=min(args.train.knn_k, len(dataset.memory_loaders[k].dataset)))
29 |
30 | return knn_acc
31 |
32 |
33 | def evaluate(model, dataset, device, classifier=None, last=False) -> Tuple[list, list]:
34 | """
35 | Evaluates the accuracy of the model for each past task.
36 | :param model: the model to be evaluated
37 | :param dataset: the continual dataset at hand
38 | :return: a tuple of lists, containing the class-il
39 | and task-il accuracy for each task
40 | """
41 | status = model.training
42 | model.eval()
43 | accs, accs_mask_classes = [], []
44 | for k, test_loader in enumerate(dataset.test_loaders):
45 | if last and k < len(dataset.test_loaders) - 1:
46 | continue
47 | correct, correct_mask_classes, total = 0.0, 0.0, 0.0
48 | for data in test_loader:
49 | inputs, labels = data
50 | inputs, labels = inputs.to(device), labels.to(device)
51 | outputs = model(inputs)
52 | if classifier is not None:
53 | outputs = classifier(outputs)
54 |
55 | _, pred = torch.max(outputs.data, 1)
56 | correct += torch.sum(pred == labels).item()
57 | total += labels.shape[0]
58 |
59 | if dataset.SETTING == 'class-il':
60 | mask_classes(outputs, dataset, k)
61 | _, pred = torch.max(outputs.data, 1)
62 | correct_mask_classes += torch.sum(pred == labels).item()
63 |
64 | accs.append(correct / total * 100)
65 | accs_mask_classes.append(correct_mask_classes / total * 100)
66 |
67 | model.train(status)
68 | return accs, accs_mask_classes
69 |
70 |
71 | def main(device, args):
72 |
73 | dataset = get_dataset(args)
74 |
75 | results, results_mask_classes = [], []
76 | for t in tqdm(range(0, dataset.N_TASKS), desc='Evaluating'):
77 | train_loader, memory_loader, test_loader = dataset.get_data_loaders(args)
78 | model_path = os.path.join(args.ckpt_dir, f"{args.model.cl_model}_{args.name}_{t}.pth")
79 | save_dict = torch.load(model_path, map_location='cpu')
80 | mean_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]]
81 | model = get_model(args, device, len(train_loader), get_aug(train=False, train_classifier=False, mean_std=mean_norm), task_id=t)
82 |
83 | msg = model.net.module.backbone.load_state_dict({k[16:]:v for k, v in save_dict['state_dict'].items() if 'backbone.' in k}, strict=True)  # k[16:] strips the 16-character 'module.backbone.' prefix
84 | model = model.to(args.device)
85 |
86 | accs = evaluate(model.net.module.backbone, dataset, device)
87 | results.append(accs[0])
88 | results_mask_classes.append(accs[1])
89 | mean_acc = np.mean(accs, axis=1)
90 | print_mean_accuracy(mean_acc, t + 1, dataset.SETTING)
91 |
92 | ci_mean_fgt = forgetting(results)
93 | ti_mean_fgt = forgetting(results_mask_classes)
94 | print(f'CI Forgetting: {ci_mean_fgt} \t TI Forgetting: {ti_mean_fgt}')
95 |
96 |
97 | if __name__ == "__main__":
98 | args = get_args()
99 | main(device=args.device, args=args)
100 |
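101 | # Flow sketch: for each task t the script reloads the checkpoint written by
102 | # main.py ({cl_model}_{name}_{t}.pth), restores only the backbone weights,
103 | # and re-runs evaluate() over every test loader seen so far before reporting
104 | # class-il and task-il forgetting.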
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Training script
3 |
4 | Originated from https://github.com/divyam3897/UCL/blob/main/main.py
5 |
6 | Hacked together by / Copyright 2023 Divyam Madaan (https://github.com/divyam3897)
7 | """
8 | import os
9 | import copy
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torchvision
14 | import numpy as np
15 | from tqdm import tqdm
16 | from arguments import get_args
17 | from augmentations import get_aug
18 | from models import get_model
19 | from tools import AverageMeter, Logger, file_exist_check
20 | from datasets import get_dataset
21 | from datetime import datetime
22 | from utils.loggers import *
23 | from utils.metrics import mask_classes
24 | from utils.loggers import CsvLogger
25 | from datasets.utils.continual_dataset import ContinualDataset
26 | from models.utils.continual_model import ContinualModel
27 | from utils.tb_logger import TensorboardLogger
28 | from typing import Tuple
29 | from datasets import BACKBONES
30 | import wandb
31 | from pytorch_model_summary import summary
32 |
33 |
34 | def evaluate(model: ContinualModel, dataset: ContinualDataset, device, classifier=None, last=False) -> Tuple[list, list]:
35 | """
36 | Evaluates the accuracy of the model for each past task.
37 | :param model: the model to be evaluated
38 | :param dataset: the continual dataset at hand
39 | :return: a tuple of lists, containing the class-il
40 | and task-il accuracy for each task
41 | """
42 | status = model.training
43 | model.eval()
44 | accs, accs_mask_classes = [], []
45 | for k, test_loader in enumerate(dataset.test_loaders):
46 | if last and k < len(dataset.test_loaders) - 1:
47 | continue
48 | correct, correct_mask_classes, total = 0.0, 0.0, 0.0
49 | for data in test_loader:
50 | inputs, labels = data
51 | inputs, labels = inputs.to(device), labels.to(device)
52 | outputs = model(inputs)
53 | if classifier is not None:
54 | outputs = classifier(outputs)
55 |
56 | _, pred = torch.max(outputs.data, 1)
57 | correct += torch.sum(pred == labels).item()
58 | total += labels.shape[0]
59 |
60 | if dataset.SETTING == 'class-il':
61 | mask_classes(outputs, dataset, k)
62 | _, pred = torch.max(outputs.data, 1)
63 | correct_mask_classes += torch.sum(pred == labels).item()
64 |
65 | accs.append(correct / total * 100)
66 | accs_mask_classes.append(correct_mask_classes / total * 100)
67 |
68 | model.train(status)
69 | return accs, accs_mask_classes
70 |
71 |
72 | def main(device, args):
73 |
74 | dataset = get_dataset(args)
75 | dataset_copy = get_dataset(args)
76 | train_loader, memory_loader, test_loader = dataset_copy.get_data_loaders(args)
77 | wandb.init(project="poc_lwf", sync_tensorboard=True)
78 | wandb.run.name = f"{args.model.cl_model}_{args.dataset.name}_n_alpha_{args.alpha}"
79 |
80 | # define model
81 | global_model = get_model(args, device, dataset_copy, dataset.get_transform(args), global_model=None)
82 | model = get_model(args, device, dataset_copy, dataset.get_transform(args), global_model=global_model)
83 |
84 | logger = Logger(matplotlib=args.logger.matplotlib, log_dir=args.log_dir)
85 | tb_logger = TensorboardLogger(args, dataset.SETTING)
86 | csv_logger = CsvLogger(dataset.SETTING, dataset.NAME, args.model.backbone)
87 | accuracy = 0
88 | results, results_mask_classes = [], []
89 |
90 | for t in range(dataset.N_TASKS):
91 | train_loader, memory_loader, test_loader = dataset.get_data_loaders(args)
92 |
93 | global_progress = tqdm(range(0, args.train.stop_at_epoch), desc='Training')
94 | prev_mean_acc = 0.
95 | best_epoch = 0.
96 |
97 | if args.hcl and t > 0 and BACKBONES[args.dataset.name][t] != BACKBONES[args.dataset.name][t - 1]:
98 | model = get_model(args, device, dataset_copy, dataset.get_transform(args), task_id=t, global_model=global_model)
99 | print(summary(model.net.module.backbone, torch.zeros((1, 3, args.dataset.image_size, args.dataset.image_size)).to(device), show_input=True))
100 |
101 | if hasattr(model, 'begin_task'):
102 | model.begin_task(t, dataset)
103 |
104 | if t:
105 | accs = evaluate(model, dataset, device, last=True)
106 | results[t-1] = results[t-1] + accs[0]
107 | results_mask_classes[t-1] = results_mask_classes[t-1] + accs[1]
108 |
109 | for epoch in global_progress:
110 | model.train()
111 |
112 | local_progress=tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}', disable=args.hide_progress)
113 | for idx, data in enumerate(local_progress):
114 | (images1, images2, notaug_images), labels = data
115 | data_dict = model.observe(images1, labels, images2, notaug_images, t)
116 |
117 | logger.update_scalers(data_dict)
118 | tb_logger.log_loss(data_dict['loss'], args, epoch, t, idx)
119 | tb_logger.log_penalty(data_dict['penalty'], args, epoch, t, idx)
120 | tb_logger.log_lr(data_dict['lr'], args, epoch, t, idx)
121 |
122 | global_progress.set_postfix(data_dict)
123 |
124 | accs = evaluate(model.net.module.backbone, dataset, device)
125 | mean_acc = np.mean(accs, axis=1)
126 |
127 | epoch_dict = {"epoch":epoch, "accuracy": mean_acc}
128 | global_progress.set_postfix(epoch_dict)
129 | logger.update_scalers(epoch_dict)
130 | tb_logger.log_accuracy(accs, mean_acc, args, t)
131 |
132 | if (sum(mean_acc)/2.) - prev_mean_acc < -0.2:
133 | continue
134 | if args.cl_default:
135 | best_model = copy.deepcopy(model.net.module.backbone)
136 | else:
137 | best_model = copy.deepcopy(model.net.module)
138 | prev_mean_acc = sum(mean_acc)/2.
139 | best_epoch = epoch
140 |
141 | accs = evaluate(best_model, dataset, device)
142 | results.append(accs[0])
143 | results_mask_classes.append(accs[1])
144 | mean_acc = np.mean(accs, axis=1)
145 | print_mean_accuracy(mean_acc, t + 1, dataset.SETTING)
146 |
147 | if args.cl_default:
148 | model.global_model.net.module.backbone = copy.deepcopy(best_model)
149 | else:
150 | model.global_model.net.module = copy.deepcopy(best_model)
151 | print(f"Updated global model at epoch {best_epoch} with accuracy {prev_mean_acc}.")
152 |
153 | model_path = os.path.join(args.ckpt_dir, f"{args.model.cl_model}_{args.name}_{t}.pth")
154 | torch.save({
155 | 'epoch': best_epoch+1,
156 | 'state_dict': model.global_model.net.state_dict(),
157 | }, model_path)
158 | print(f"Task Model saved to {model_path}")
159 | with open(os.path.join(args.log_dir, "checkpoint_path.txt"), 'w+') as f:
160 | f.write(f'{model_path}')
161 |
162 | if hasattr(model, 'end_task'):
163 | model.end_task(dataset)
164 |
165 | csv_logger.add_bwt(results, results_mask_classes)
166 | csv_logger.add_forgetting(results, results_mask_classes)
167 | csv_logger.write(args.ckpt_dir, vars(args))
168 | tb_logger.close()
169 | if args.eval is not False and args.cl_default is False:
170 | args.eval_from = model_path
171 |
172 | if __name__ == "__main__":
173 | args = get_args()
174 | main(device=args.device, args=args)
175 | completed_log_dir = args.log_dir.replace('in-progress', 'debug' if args.debug else 'completed')
176 | os.rename(args.log_dir, completed_log_dir)
177 | print(f'Log file has been saved to {completed_log_dir}')
178 |
179 |
180 |
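181 | # Model-selection sketch: within each task, an epoch whose mean accuracy
182 | # drops more than 0.2 below the last accepted value is skipped; otherwise the
183 | # snapshot becomes `best_model`, is copied into `global_model` (the frozen
184 | # copy handed to the CL methods), and is saved as {cl_model}_{name}_{t}.pth
185 | # for linear_eval_alltasks.py.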
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | from .simsiam import SimSiam
4 | import torch
5 | from .backbones import resnet18, lenet, vgg16, alexnet, densenet, senet, regnet, inception, swin, resnext
6 | from datasets import N_CLASSES, BACKBONES
7 | from utils.losses import LabelSmoothing, KL_div_Loss
8 |
9 |
10 | def get_backbone(args, task_id=0):
11 | if args.hcl:
12 | backbone = BACKBONES[args.dataset.name][task_id]
13 | else:
14 | backbone = args.model.backbone
15 |
16 | net = eval(f"{backbone}(num_classes=N_CLASSES[args.dataset.name], args=args)")
17 | print("Backbone changed to ", backbone)
18 |
19 | net.n_classes = N_CLASSES[args.dataset.name]
20 | net.output_dim = net.fc.in_features
21 | if not args.cl_default:
22 | net.fc = torch.nn.Identity()
23 |
24 | return net
25 |
26 |
27 | def get_all_models():
28 |     return [model.split('.')[0] for model in os.listdir('models')
29 |             if '__' not in model and model.endswith('.py')]
30 |
31 | def get_model(args, device, dataset, transform, global_model=None, task_id=0):
32 | allowed_models = ["distil", "qdi", "distilbuf"]
33 | if args.model.cl_model in allowed_models:
34 | loss = LabelSmoothing(smoothing=0.1)
35 | else:
36 | loss = torch.nn.CrossEntropyLoss()
37 | if args.model.name == 'simsiam':
38 | backbone = SimSiam(get_backbone(args, task_id=task_id)).to(device)
39 | if args.model.proj_layers is not None:
40 | backbone.projector.set_layers(args.model.proj_layers)
41 |
42 | names = {}
43 | for model in get_all_models():
44 | mod = importlib.import_module('models.' + model)
45 | class_name = {x.lower():x for x in mod.__dir__()}[model.replace('_', '')]
46 | names[model] = getattr(mod, class_name)
47 |
48 | return names[args.model.cl_model](backbone, loss, args, dataset, transform, global_model)
49 |
50 |
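51 | # Resolution sketch: get_all_models() maps e.g. models/distil.py to the key
52 | # 'distil', and get_model() then instantiates the class in that module whose
53 | # lowercased name matches ('Distil'). Adding a new CL method therefore only
54 | # requires dropping models/<name>.py with a matching class into the package.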
--------------------------------------------------------------------------------
/models/backbones/Alexnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class AlexNet(nn.Module):
4 | def __init__(self, num_classes, args):
5 | super(AlexNet, self).__init__()
6 | self.features = nn.Sequential(
7 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
8 | nn.ReLU(inplace=True),
9 | nn.MaxPool2d(kernel_size=2),
10 | nn.Conv2d(64, 192, kernel_size=3, padding=1),
11 | nn.ReLU(inplace=True),
12 | nn.MaxPool2d(kernel_size=2),
13 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
14 | nn.ReLU(inplace=True),
15 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
18 | nn.ReLU(inplace=True),
19 | nn.MaxPool2d(kernel_size=2),
20 | )
21 | self.classifier = nn.Sequential(
22 | nn.Dropout(),
23 | nn.Linear(256 * 2 * 2, 4096),
24 | nn.ReLU(inplace=True),
25 | nn.Dropout(),
26 | nn.Linear(4096, 4096),
27 | nn.ReLU(inplace=True),
28 | )
29 | self.fc = nn.Linear(4096, num_classes)
30 |
31 | def forward(self, x, return_features=False):
32 | x = self.features(x)
33 | x = x.view(x.size(0), 256 * 2 * 2)
34 | x = self.classifier(x)
35 | if return_features:
36 | return x
37 | x = self.fc(x)
38 | return x
39 |
40 |
41 | def alexnet(num_classes, args):
42 | return AlexNet(num_classes, args)
43 |
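44 | # Shape sanity check (illustrative sketch): with 32x32 inputs the stride-2
45 | # stem plus three MaxPool2d layers leave 2x2 feature maps, which is where the
46 | # 256 * 2 * 2 classifier input comes from:
47 | #
48 | #     import torch
49 | #     net = alexnet(num_classes=10, args=None)  # args is accepted but unused
50 | #     assert net(torch.randn(2, 3, 32, 32)).shape == (2, 10)
51 | #     assert net(torch.randn(2, 3, 32, 32), return_features=True).shape == (2, 4096)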
--------------------------------------------------------------------------------
/models/backbones/Densenet.py:
--------------------------------------------------------------------------------
1 | # Originated from from https://github.com/kuangliu/pytorch-cifar/blob/master/models/densenet.py
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class Bottleneck(nn.Module):
10 | def __init__(self, in_planes, growth_rate):
11 | super(Bottleneck, self).__init__()
12 | self.bn1 = nn.BatchNorm2d(in_planes)
13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(4*growth_rate)
15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
16 |
17 | def forward(self, x):
18 | out = self.conv1(F.relu(self.bn1(x)))
19 | out = self.conv2(F.relu(self.bn2(out)))
20 | out = torch.cat([out,x], 1)
21 | return out
22 |
23 |
24 | class Transition(nn.Module):
25 | def __init__(self, in_planes, out_planes):
26 | super(Transition, self).__init__()
27 | self.bn = nn.BatchNorm2d(in_planes)
28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
29 |
30 | def forward(self, x):
31 | out = self.conv(F.relu(self.bn(x)))
32 | out = F.avg_pool2d(out, 2)
33 | return out
34 |
35 |
36 | class DenseNet(nn.Module):
37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10, args=None):
38 | super(DenseNet, self).__init__()
39 | self.growth_rate = growth_rate
40 |
41 | num_planes = 2*growth_rate
42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
43 |
44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
45 | num_planes += nblocks[0]*growth_rate
46 | out_planes = int(math.floor(num_planes*reduction))
47 | self.trans1 = Transition(num_planes, out_planes)
48 | num_planes = out_planes
49 |
50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
51 | num_planes += nblocks[1]*growth_rate
52 | out_planes = int(math.floor(num_planes*reduction))
53 | self.trans2 = Transition(num_planes, out_planes)
54 | num_planes = out_planes
55 |
56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
57 | num_planes += nblocks[2]*growth_rate
58 | out_planes = int(math.floor(num_planes*reduction))
59 | self.trans3 = Transition(num_planes, out_planes)
60 | num_planes = out_planes
61 |
62 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
63 | num_planes += nblocks[3]*growth_rate
64 |
65 | self.bn = nn.BatchNorm2d(num_planes)
66 | self.fc = nn.Linear(num_planes, num_classes)
67 |
68 | def _make_dense_layers(self, block, in_planes, nblock):
69 | layers = []
70 | for i in range(nblock):
71 | layers.append(block(in_planes, self.growth_rate))
72 | in_planes += self.growth_rate
73 | return nn.Sequential(*layers)
74 |
75 | def forward(self, x, return_features=False):
76 | out = self.conv1(x)
77 | out = self.trans1(self.dense1(out))
78 | out = self.trans2(self.dense2(out))
79 | out = self.trans3(self.dense3(out))
80 | out = self.dense4(out)
81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4)
82 | out = out.view(out.size(0), -1)
83 | if return_features:
84 | return out
85 | out = self.fc(out)
86 | return out
87 |
88 | def DenseNet121(num_classes, args):
89 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32, num_classes=num_classes, args=args)
90 |
91 | def DenseNet169(num_classes, args):
92 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32, num_classes=num_classes, args=args)
93 |
94 | def DenseNet201(num_classes, args):
95 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32, num_classes=num_classes, args=args)
96 |
97 | def DenseNet161(num_classes, args):
98 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48, num_classes=num_classes, args=args)
99 |
100 | def densenet_cifar(num_classes, args):
101 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12, num_classes=num_classes, args=args)
102 |
103 | def densenet(num_classes, args):
104 | return DenseNet121(num_classes, args)
105 |
--------------------------------------------------------------------------------
/models/backbones/Inception.py:
--------------------------------------------------------------------------------
1 | # Originated from https://github.com/kuangliu/pytorch-cifar/blob/master/models/googlenet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class Inception(nn.Module):
8 | def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
9 | super(Inception, self).__init__()
10 | # 1x1 conv branch
11 | self.b1 = nn.Sequential(
12 | nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
13 | nn.BatchNorm2d(kernel_1_x),
14 | nn.ReLU(True),
15 | )
16 |
17 | # 1x1 conv -> 3x3 conv branch
18 | self.b2 = nn.Sequential(
19 | nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
20 | nn.BatchNorm2d(kernel_3_in),
21 | nn.ReLU(True),
22 | nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
23 | nn.BatchNorm2d(kernel_3_x),
24 | nn.ReLU(True),
25 | )
26 |
27 | # 1x1 conv -> 5x5 conv branch (the 5x5 is factorized into two stacked 3x3 convs)
28 | self.b3 = nn.Sequential(
29 | nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
30 | nn.BatchNorm2d(kernel_5_in),
31 | nn.ReLU(True),
32 | nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
33 | nn.BatchNorm2d(kernel_5_x),
34 | nn.ReLU(True),
35 | nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
36 | nn.BatchNorm2d(kernel_5_x),
37 | nn.ReLU(True),
38 | )
39 |
40 | # 3x3 pool -> 1x1 conv branch
41 | self.b4 = nn.Sequential(
42 | nn.MaxPool2d(3, stride=1, padding=1),
43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1),
44 | nn.BatchNorm2d(pool_planes),
45 | nn.ReLU(True),
46 | )
47 |
48 | def forward(self, x):
49 | y1 = self.b1(x)
50 | y2 = self.b2(x)
51 | y3 = self.b3(x)
52 | y4 = self.b4(x)
53 | return torch.cat([y1,y2,y3,y4], 1)
54 |
55 |
56 | class GoogleNet(nn.Module):
57 | def __init__(self, num_classes, args):
58 | super(GoogleNet, self).__init__()
59 | self.pre_layers = nn.Sequential(
60 | nn.Conv2d(3, 192, kernel_size=3, padding=1),
61 | nn.BatchNorm2d(192),
62 | nn.ReLU(True),
63 | )
64 |
65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
67 |
68 | self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
69 |
70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
75 |
76 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
78 |
79 | self.avgpool = nn.AvgPool2d(8, stride=1)
80 | self.fc = nn.Linear(1024, num_classes)
81 |
82 | def forward(self, x, return_features=False):
83 | x = self.pre_layers(x)
84 | x = self.a3(x)
85 | x = self.b3(x)
86 | x = self.max_pool(x)
87 | x = self.a4(x)
88 | x = self.b4(x)
89 | x = self.c4(x)
90 | x = self.d4(x)
91 | x = self.e4(x)
92 | x = self.max_pool(x)
93 | x = self.a5(x)
94 | x = self.b5(x)
95 | x = self.avgpool(x)
96 | x = x.view(x.size(0), -1)
97 | if return_features:
98 | return x
99 | x = self.fc(x)
100 | return x
101 |
102 | def inception(num_classes, args):
103 | return GoogleNet(num_classes, args)
104 |
--------------------------------------------------------------------------------
/models/backbones/Lenet.py:
--------------------------------------------------------------------------------
1 | ## Originated from https://github.com/kuangliu/pytorch-cifar/blob/master/models/lenet.py
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as func
5 |
6 |
7 | class LeNet(nn.Module):
8 | def __init__(self, num_classes, args):
9 | super(LeNet, self).__init__()
10 | self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
11 | self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
12 | if args.dataset.image_size == 32:
13 | self.fc1 = nn.Linear(16*5*5, 120)  # 32x32 input -> 5x5 maps after two conv/pool stages
14 | else:
15 | self.fc1 = nn.Linear(2704, 120)  # 64x64 input -> 13x13 maps: 16 * 13 * 13 = 2704
16 | self.fc2 = nn.Linear(120, 84)
17 | self.fc = nn.Linear(84, num_classes)
18 |
19 | def forward(self, x, return_features=False):
20 | x = func.relu(self.conv1(x))
21 | x = func.max_pool2d(x, 2)
22 | x = func.relu(self.conv2(x))
23 | x = func.max_pool2d(x, 2)
24 | x = x.view(x.size(0), -1)
25 | x = func.relu(self.fc1(x))
26 | x = func.relu(self.fc2(x))
27 | if return_features:
28 | return x
29 | x = self.fc(x)
30 | return x
31 |
32 | def lenet(num_classes, args):
33 | return LeNet(num_classes, args=args)
34 |
--------------------------------------------------------------------------------
/models/backbones/Regnet.py:
--------------------------------------------------------------------------------
1 | '''RegNet in PyTorch.
2 |
3 | Paper: "Designing Network Design Spaces".
4 |
5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py
6 |
7 | Originated from https://github.com/kuangliu/pytorch-cifar/blob/master/models/regnet.py
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class SE(nn.Module):
15 | '''Squeeze-and-Excitation block.'''
16 |
17 | def __init__(self, in_planes, se_planes):
18 | super(SE, self).__init__()
19 | self.se1 = nn.Conv2d(in_planes, se_planes, kernel_size=1, bias=True)
20 | self.se2 = nn.Conv2d(se_planes, in_planes, kernel_size=1, bias=True)
21 |
22 | def forward(self, x):
23 | out = F.adaptive_avg_pool2d(x, (1, 1))
24 | out = F.relu(self.se1(out))
25 | out = self.se2(out).sigmoid()
26 | out = x * out
27 | return out
28 |
29 |
30 | class Block(nn.Module):
31 | def __init__(self, w_in, w_out, stride, group_width, bottleneck_ratio, se_ratio):
32 | super(Block, self).__init__()
33 | # 1x1
34 | w_b = int(round(w_out * bottleneck_ratio))
35 | self.conv1 = nn.Conv2d(w_in, w_b, kernel_size=1, bias=False)
36 | self.bn1 = nn.BatchNorm2d(w_b)
37 | # 3x3
38 | num_groups = w_b // group_width
39 | self.conv2 = nn.Conv2d(w_b, w_b, kernel_size=3,
40 | stride=stride, padding=1, groups=num_groups, bias=False)
41 | self.bn2 = nn.BatchNorm2d(w_b)
42 | # se
43 | self.with_se = se_ratio > 0
44 | if self.with_se:
45 | w_se = int(round(w_in * se_ratio))
46 | self.se = SE(w_b, w_se)
47 | # 1x1
48 | self.conv3 = nn.Conv2d(w_b, w_out, kernel_size=1, bias=False)
49 | self.bn3 = nn.BatchNorm2d(w_out)
50 |
51 | self.shortcut = nn.Sequential()
52 | if stride != 1 or w_in != w_out:
53 | self.shortcut = nn.Sequential(
54 | nn.Conv2d(w_in, w_out,
55 | kernel_size=1, stride=stride, bias=False),
56 | nn.BatchNorm2d(w_out)
57 | )
58 |
59 | def forward(self, x):
60 | out = F.relu(self.bn1(self.conv1(x)))
61 | out = F.relu(self.bn2(self.conv2(out)))
62 | if self.with_se:
63 | out = self.se(out)
64 | out = self.bn3(self.conv3(out))
65 | out += self.shortcut(x)
66 | out = F.relu(out)
67 | return out
68 |
69 |
70 | class RegNet(nn.Module):
71 | def __init__(self, cfg, num_classes=10):
72 | super(RegNet, self).__init__()
73 | self.cfg = cfg
74 | self.in_planes = 64
75 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
76 | stride=1, padding=1, bias=False)
77 | self.bn1 = nn.BatchNorm2d(64)
78 | self.layer1 = self._make_layer(0)
79 | self.layer2 = self._make_layer(1)
80 | self.layer3 = self._make_layer(2)
81 | self.layer4 = self._make_layer(3)
82 | self.fc = nn.Linear(self.cfg['widths'][-1], num_classes)
83 |
84 | def _make_layer(self, idx):
85 | depth = self.cfg['depths'][idx]
86 | width = self.cfg['widths'][idx]
87 | stride = self.cfg['strides'][idx]
88 | group_width = self.cfg['group_width']
89 | bottleneck_ratio = self.cfg['bottleneck_ratio']
90 | se_ratio = self.cfg['se_ratio']
91 |
92 | layers = []
93 | for i in range(depth):
94 | s = stride if i == 0 else 1
95 | layers.append(Block(self.in_planes, width,
96 | s, group_width, bottleneck_ratio, se_ratio))
97 | self.in_planes = width
98 | return nn.Sequential(*layers)
99 |
100 | def forward(self, x, return_features=False):
101 | out = F.relu(self.bn1(self.conv1(x)))
102 | out = self.layer1(out)
103 | out = self.layer2(out)
104 | out = self.layer3(out)
105 | out = self.layer4(out)
106 | out = F.adaptive_avg_pool2d(out, (1, 1))
107 | out = out.view(out.size(0), -1)
108 | if return_features:
109 | return out
110 | out = self.fc(out)
111 | return out
112 |
113 |
114 | def RegNetX_200MF(num_classes, args):
115 | cfg = {
116 | 'depths': [1, 1, 4, 7],
117 | 'widths': [24, 56, 152, 368],
118 | 'strides': [1, 1, 2, 2],
119 | 'group_width': 8,
120 | 'bottleneck_ratio': 1,
121 | 'se_ratio': 0,
122 | }
123 |
124 | return RegNet(cfg, num_classes)
125 |
126 |
127 | def RegNetX_400MF(num_classes, args):
128 | cfg = {
129 | 'depths': [1, 2, 7, 12],
130 | 'widths': [32, 64, 160, 384],
131 | 'strides': [1, 1, 2, 2],
132 | 'group_width': 16,
133 | 'bottleneck_ratio': 1,
134 | 'se_ratio': 0,
135 | }
136 | return RegNet(cfg, num_classes)
137 |
138 |
139 | def regnet(num_classes, args):
140 | return RegNetX_200MF(num_classes, args)
141 |
142 |
--------------------------------------------------------------------------------
/models/backbones/ResNext.py:
--------------------------------------------------------------------------------
1 | '''ResNeXt in PyTorch.
2 |
3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details.
4 |
5 | Originated from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnext.py
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class Block(nn.Module):
13 | '''Grouped convolution block.'''
14 | expansion = 2
15 |
16 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):
17 | super(Block, self).__init__()
18 | group_width = cardinality * bottleneck_width
19 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False)
20 | self.bn1 = nn.BatchNorm2d(group_width)
21 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
22 | self.bn2 = nn.BatchNorm2d(group_width)
23 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False)
24 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width)
25 |
26 | self.shortcut = nn.Sequential()
27 | if stride != 1 or in_planes != self.expansion*group_width:
28 | self.shortcut = nn.Sequential(
29 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False),
30 | nn.BatchNorm2d(self.expansion*group_width)
31 | )
32 |
33 | def forward(self, x):
34 | out = F.relu(self.bn1(self.conv1(x)))
35 | out = F.relu(self.bn2(self.conv2(out)))
36 | out = self.bn3(self.conv3(out))
37 | out += self.shortcut(x)
38 | out = F.relu(out)
39 | return out
40 |
41 |
42 | class ResNeXt(nn.Module):
43 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes, args):
44 | super(ResNeXt, self).__init__()
45 | self.cardinality = cardinality
46 | self.bottleneck_width = bottleneck_width
47 | self.in_planes = 64
48 | layer1_stride = 2 if args.dataset.image_size == 64 else 1
49 |
50 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False)
51 | self.bn1 = nn.BatchNorm2d(64)
52 | self.layer1 = self._make_layer(num_blocks[0], layer1_stride)
53 | self.layer2 = self._make_layer(num_blocks[1], 2)
54 | self.layer3 = self._make_layer(num_blocks[2], 2)
55 | # self.layer4 = self._make_layer(num_blocks[3], 2)
56 | self.fc = nn.Linear(cardinality*bottleneck_width*8, num_classes)
57 |
58 | def _make_layer(self, num_blocks, stride):
59 | strides = [stride] + [1]*(num_blocks-1)
60 | layers = []
61 | for stride in strides:
62 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride))
63 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width
64 |         # Double bottleneck_width after each stage.
65 | self.bottleneck_width *= 2
66 | return nn.Sequential(*layers)
67 |
68 | def forward(self, x, return_features=False):
69 | out = F.relu(self.bn1(self.conv1(x)))
70 | out = self.layer1(out)
71 | out = self.layer2(out)
72 | out = self.layer3(out)
73 | # out = self.layer4(out)
74 | out = F.avg_pool2d(out, 8)
75 | out = out.view(out.size(0), -1)
76 | if return_features:
77 | return out
78 | out = self.fc(out)
79 | return out
80 |
81 |
82 | def ResNeXt29_2x64d(num_classes, args):
83 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64, num_classes=num_classes, args=args)
84 |
85 | def ResNeXt29_4x64d(num_classes, args):
86 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64, num_classes=num_classes, args=args)
87 |
88 | def ResNeXt29_8x64d(num_classes, args):
89 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64, num_classes=num_classes, args=args)
90 |
91 | def ResNeXt29_32x4d(num_classes, args):
92 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4, num_classes=num_classes, args=args)
93 |
94 | def resnext(num_classes, args):
95 | return ResNeXt29_2x64d(num_classes, args=args)
96 |
--------------------------------------------------------------------------------
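A minimal forward-pass sketch for the ResNeXt-29 (32x4d) variant above (illustrative; the `args` namespace is a stand-in for the project's config object, of which only `dataset.image_size` is used here):

import torch
from types import SimpleNamespace
from models.backbones.ResNext import ResNeXt29_32x4d

args = SimpleNamespace(dataset=SimpleNamespace(image_size=32))  # CIFAR-style input size
model = ResNeXt29_32x4d(num_classes=100, args=args)
x = torch.randn(2, 3, 32, 32)
print(model(x).shape)                        # torch.Size([2, 100])
print(model(x, return_features=True).shape)  # torch.Size([2, 1024]) = cardinality*bottleneck_width*8

--------------------------------------------------------------------------------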
/models/backbones/Senet.py:
--------------------------------------------------------------------------------
1 | '''SENet in PyTorch.
2 |
3 | SENet won the ImageNet-2017 classification challenge; see "Squeeze-and-Excitation Networks" (Hu et al., CVPR 2018).
4 |
5 | Originated from https://github.com/kuangliu/pytorch-cifar/blob/master/models/senet.py
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class BasicBlock(nn.Module):
13 | def __init__(self, in_planes, planes, stride=1):
14 | super(BasicBlock, self).__init__()
15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
16 | self.bn1 = nn.BatchNorm2d(planes)
17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
18 | self.bn2 = nn.BatchNorm2d(planes)
19 |
20 | self.shortcut = nn.Sequential()
21 | if stride != 1 or in_planes != planes:
22 | self.shortcut = nn.Sequential(
23 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
24 | nn.BatchNorm2d(planes)
25 | )
26 |
27 | # SE layers
28 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear
29 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)
30 |
31 | def forward(self, x):
32 | out = F.relu(self.bn1(self.conv1(x)))
33 | out = self.bn2(self.conv2(out))
34 |
35 | # Squeeze
36 | w = F.avg_pool2d(out, out.size(2))
37 | w = F.relu(self.fc1(w))
38 |         w = torch.sigmoid(self.fc2(w))  # F.sigmoid is deprecated
39 |         # Excitation
40 |         out = out * w  # channel-wise reweighting via broadcasting
41 |
42 | out += self.shortcut(x)
43 | out = F.relu(out)
44 | return out
45 |
46 |
47 | class PreActBlock(nn.Module):
48 | def __init__(self, in_planes, planes, stride=1):
49 | super(PreActBlock, self).__init__()
50 | self.bn1 = nn.BatchNorm2d(in_planes)
51 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
54 |
55 | if stride != 1 or in_planes != planes:
56 | self.shortcut = nn.Sequential(
57 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
58 | )
59 |
60 | # SE layers
61 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
62 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)
63 |
64 | def forward(self, x):
65 | out = F.relu(self.bn1(x))
66 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
67 | out = self.conv1(out)
68 | out = self.conv2(F.relu(self.bn2(out)))
69 |
70 | # Squeeze
71 | w = F.avg_pool2d(out, out.size(2))
72 | w = F.relu(self.fc1(w))
73 |         w = torch.sigmoid(self.fc2(w))  # F.sigmoid is deprecated
74 | # Excitation
75 | out = out * w
76 |
77 | out += shortcut
78 | return out
79 |
80 |
81 | class SENet(nn.Module):
82 | def __init__(self, block, num_blocks, num_classes, args):
83 | super(SENet, self).__init__()
84 | self.in_planes = 64
85 | layer1_stride = 2 if args.dataset.image_size == 64 else 1
86 |
87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
88 | self.bn1 = nn.BatchNorm2d(64)
89 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=layer1_stride)
90 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
91 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
92 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
93 | self.fc = nn.Linear(512, num_classes)
94 |
95 | def _make_layer(self, block, planes, num_blocks, stride):
96 | strides = [stride] + [1]*(num_blocks-1)
97 | layers = []
98 | for stride in strides:
99 | layers.append(block(self.in_planes, planes, stride))
100 | self.in_planes = planes
101 | return nn.Sequential(*layers)
102 |
103 | def forward(self, x, return_features=False):
104 | out = F.relu(self.bn1(self.conv1(x)))
105 | out = self.layer1(out)
106 | out = self.layer2(out)
107 | out = self.layer3(out)
108 | out = self.layer4(out)
109 | out = F.avg_pool2d(out, 4)
110 | out = out.view(out.size(0), -1)
111 | if return_features:
112 | return out
113 | out = self.fc(out)
114 | return out
115 |
116 |
117 | def senet18(num_classes, args):
118 | return SENet(PreActBlock, [2,2,2,2], num_classes, args)
119 |
--------------------------------------------------------------------------------
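The squeeze-and-excitation gating used in the blocks above, written out standalone (an illustrative sketch; the dimensions are arbitrary):

import torch
import torch.nn as nn
import torch.nn.functional as F

planes = 64
fc1 = nn.Conv2d(planes, planes // 16, kernel_size=1)  # squeeze: reduce to planes/16 channels
fc2 = nn.Conv2d(planes // 16, planes, kernel_size=1)  # excite: expand back to planes channels

out = torch.randn(2, planes, 8, 8)
w = F.avg_pool2d(out, out.size(2))       # global average pool -> [2, 64, 1, 1]
w = torch.sigmoid(fc2(F.relu(fc1(w))))   # per-channel gates in (0, 1)
print((out * w).shape)                   # broadcast reweighting, torch.Size([2, 64, 8, 8])

--------------------------------------------------------------------------------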
/models/backbones/Swin.py:
--------------------------------------------------------------------------------
1 | # https://github.com/berniwal/swin-transformer-pytorch
2 |
3 | import torch
4 | from torch import nn, einsum
5 | import numpy as np
6 | from einops import rearrange, repeat
7 |
8 |
9 | class CyclicShift(nn.Module):
10 | def __init__(self, displacement):
11 | super().__init__()
12 | self.displacement = displacement
13 |
14 | def forward(self, x):
15 | return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))
16 |
17 |
18 | class Residual(nn.Module):
19 | def __init__(self, fn):
20 | super().__init__()
21 | self.fn = fn
22 |
23 | def forward(self, x, **kwargs):
24 | return self.fn(x, **kwargs) + x
25 |
26 |
27 | class PreNorm(nn.Module):
28 | def __init__(self, dim, fn):
29 | super().__init__()
30 | self.norm = nn.LayerNorm(dim)
31 | self.fn = fn
32 |
33 | def forward(self, x, **kwargs):
34 | return self.fn(self.norm(x), **kwargs)
35 |
36 |
37 | class FeedForward(nn.Module):
38 | def __init__(self, dim, hidden_dim):
39 | super().__init__()
40 | self.net = nn.Sequential(
41 | nn.Linear(dim, hidden_dim),
42 | nn.GELU(),
43 | nn.Linear(hidden_dim, dim),
44 | )
45 |
46 | def forward(self, x):
47 | return self.net(x)
48 |
49 |
50 | def create_mask(window_size, displacement, upper_lower, left_right):
51 | mask = torch.zeros(window_size ** 2, window_size ** 2)
52 |
53 | if upper_lower:
54 | mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
55 | mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')
56 |
57 | if left_right:
58 | mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
59 | mask[:, -displacement:, :, :-displacement] = float('-inf')
60 | mask[:, :-displacement, :, -displacement:] = float('-inf')
61 | mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')
62 |
63 | return mask
64 |
65 |
66 | def get_relative_distances(window_size):
67 | indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
68 | distances = indices[None, :, :] - indices[:, None, :]
69 | return distances
70 |
71 |
72 | class WindowAttention(nn.Module):
73 | def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
74 | super().__init__()
75 | inner_dim = head_dim * heads
76 |
77 | self.heads = heads
78 | self.scale = head_dim ** -0.5
79 | self.window_size = window_size
80 | self.relative_pos_embedding = relative_pos_embedding
81 | self.shifted = shifted
82 |
83 | if self.shifted:
84 | displacement = window_size // 2
85 | self.cyclic_shift = CyclicShift(-displacement)
86 | self.cyclic_back_shift = CyclicShift(displacement)
87 | self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
88 | upper_lower=True, left_right=False), requires_grad=False)
89 | self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
90 | upper_lower=False, left_right=True), requires_grad=False)
91 |
92 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
93 |
94 | if self.relative_pos_embedding:
95 | self.relative_indices = get_relative_distances(window_size) + window_size - 1
96 | self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
97 | else:
98 | self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))
99 |
100 | self.to_out = nn.Linear(inner_dim, dim)
101 |
102 | def forward(self, x):
103 | if self.shifted:
104 | x = self.cyclic_shift(x)
105 |
106 | b, n_h, n_w, _, h = *x.shape, self.heads
107 |
108 | qkv = self.to_qkv(x).chunk(3, dim=-1)
109 | nw_h = n_h // self.window_size
110 | nw_w = n_w // self.window_size
111 |
112 | q, k, v = map(
113 | lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
114 | h=h, w_h=self.window_size, w_w=self.window_size), qkv)
115 |
116 | dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale
117 |
118 | if self.relative_pos_embedding:
119 | dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
120 | else:
121 | dots += self.pos_embedding
122 |
123 | if self.shifted:
124 | dots[:, :, -nw_w:] += self.upper_lower_mask
125 | dots[:, :, nw_w - 1::nw_w] += self.left_right_mask
126 |
127 | attn = dots.softmax(dim=-1)
128 |
129 | out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
130 | out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
131 | h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
132 | out = self.to_out(out)
133 |
134 | if self.shifted:
135 | out = self.cyclic_back_shift(out)
136 | return out
137 |
138 |
139 | class SwinBlock(nn.Module):
140 | def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
141 | super().__init__()
142 | self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim,
143 | heads=heads,
144 | head_dim=head_dim,
145 | shifted=shifted,
146 | window_size=window_size,
147 | relative_pos_embedding=relative_pos_embedding)))
148 | self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))
149 |
150 | def forward(self, x):
151 | x = self.attention_block(x)
152 | x = self.mlp_block(x)
153 | return x
154 |
155 |
156 | class PatchMerging(nn.Module):
157 | def __init__(self, in_channels, out_channels, downscaling_factor):
158 | super().__init__()
159 | self.downscaling_factor = downscaling_factor
160 | self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
161 | self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)
162 |
163 | def forward(self, x):
164 | b, c, h, w = x.shape
165 | new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
166 | x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
167 | x = self.linear(x)
168 | return x
169 |
170 |
171 | class StageModule(nn.Module):
172 | def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
173 | relative_pos_embedding):
174 | super().__init__()
175 | assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'
176 |
177 | self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
178 | downscaling_factor=downscaling_factor)
179 |
180 | self.layers = nn.ModuleList([])
181 | for _ in range(layers // 2):
182 | self.layers.append(nn.ModuleList([
183 | SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
184 | shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
185 | SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
186 | shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
187 | ]))
188 |
189 | def forward(self, x):
190 | x = self.patch_partition(x)
191 | for regular_block, shifted_block in self.layers:
192 | x = regular_block(x)
193 | x = shifted_block(x)
194 | return x.permute(0, 3, 1, 2)
195 |
196 |
197 | class SwinTransformer(nn.Module):
198 | def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=10, head_dim=32, window_size=4,
199 | downscaling_factors=(2, 2, 2, 1), relative_pos_embedding=True, args=None):
200 | super().__init__()
201 |
202 | self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim, layers=layers[0],
203 | downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim,
204 | window_size=window_size, relative_pos_embedding=relative_pos_embedding)
205 | self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim * 2, layers=layers[1],
206 | downscaling_factor=downscaling_factors[1], num_heads=heads[1], head_dim=head_dim,
207 | window_size=window_size, relative_pos_embedding=relative_pos_embedding)
208 | self.stage3 = StageModule(in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4, layers=layers[2],
209 | downscaling_factor=downscaling_factors[2], num_heads=heads[2], head_dim=head_dim,
210 | window_size=window_size, relative_pos_embedding=relative_pos_embedding)
211 | self.stage4 = StageModule(in_channels=hidden_dim * 4, hidden_dimension=hidden_dim * 8, layers=layers[3],
212 | downscaling_factor=downscaling_factors[3], num_heads=heads[3], head_dim=head_dim,
213 | window_size=window_size, relative_pos_embedding=relative_pos_embedding)
214 |
215 | self.final_layer_norm = nn.LayerNorm(hidden_dim * 8)
216 | self.fc = nn.Linear(hidden_dim * 8, num_classes)
217 |
218 | def forward(self, img, return_features=False):
219 | x = self.stage1(img)
220 | x = self.stage2(x)
221 | x = self.stage3(x)
222 | x = self.stage4(x)
223 | x = x.mean(dim=[2, 3])
224 | x = self.final_layer_norm(x)
225 | if return_features:
226 | return x
227 | return self.fc(x)
228 |
229 |
230 | def swin_t(num_classes, args, hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
231 | return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, num_classes=num_classes, args=args, **kwargs)
232 |
233 |
234 | def swin_s(num_classes, hidden_dim=96, layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
235 | return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, num_classes=num_classes, **kwargs)
236 |
237 |
238 | def swin_b(num_classes, hidden_dim=128, layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
239 |     return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, num_classes=num_classes, **kwargs)
240 |
241 |
242 | def swin_l(num_classes, hidden_dim=192, layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs):
243 |     return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, num_classes=num_classes, **kwargs)
244 |
--------------------------------------------------------------------------------
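A minimal shape-check sketch for the tiny Swin variant above (illustrative; note Swin.py also needs einops, which is not pinned in requirements.txt). With the default window_size=4 and downscaling_factors=(2, 2, 2, 1), a 32x32 image is reduced to a 4x4 token grid before the global average pool:

import torch
from models.backbones.Swin import swin_t

model = swin_t(num_classes=10, args=None)      # `args` is accepted but unused here
img = torch.randn(2, 3, 32, 32)
print(model(img).shape)                        # torch.Size([2, 10])
print(model(img, return_features=True).shape)  # torch.Size([2, 768]) = hidden_dim * 8

--------------------------------------------------------------------------------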
/models/backbones/Vgg.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from https://github.com/pytorch/vision.git
3 | '''
4 | import math
5 |
6 | import torch.nn as nn
7 | import torch.nn.init as init
8 |
9 | __all__ = [
10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
11 | 'vgg19_bn', 'vgg19',
12 | ]
13 |
14 |
15 | class VGG(nn.Module):
16 | '''
17 | VGG model
18 | '''
19 | def __init__(self, features, num_classes, args):
20 | super(VGG, self).__init__()
21 | self.features = features
22 | self.classifier = nn.Sequential(
23 | nn.Dropout(),
24 | nn.Linear(512, 512),
25 | nn.ReLU(True),
26 | nn.Dropout(),
27 | nn.Linear(512, 512),
28 | nn.ReLU(True),
29 | )
30 | self.fc = nn.Linear(512, num_classes)
31 | # Initialize weights
32 | for m in self.modules():
33 | if isinstance(m, nn.Conv2d):
34 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
35 | m.weight.data.normal_(0, math.sqrt(2. / n))
36 | m.bias.data.zero_()
37 |
38 |
39 | def forward(self, x, return_features=False):
40 | x = self.features(x)
41 | x = x.view(x.size(0), -1)
42 | x = self.classifier(x)
43 | if return_features:
44 | return x
45 | x = self.fc(x)
46 | return x
47 |
48 |
49 | def make_layers(cfg, batch_norm=False):
50 | layers = []
51 | in_channels = 3
52 | for v in cfg:
53 | if v == 'M':
54 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
55 | else:
56 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
57 | if batch_norm:
58 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
59 | else:
60 | layers += [conv2d, nn.ReLU(inplace=True)]
61 | in_channels = v
62 | return nn.Sequential(*layers)
63 |
64 |
65 | cfg = {
66 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
67 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
68 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
69 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
70 | 512, 512, 512, 512, 'M'],
71 | }
72 |
73 |
74 | def vgg11(num_classes, args):
75 | """VGG 11-layer model (configuration "A")"""
76 | return VGG(make_layers(cfg['A']), num_classes, args)
77 |
78 |
79 | def vgg11_bn(num_classes, args):
80 | """VGG 11-layer model (configuration "A") with batch normalization"""
81 | return VGG(make_layers(cfg['A'], batch_norm=True), num_classes, args)
82 |
83 |
84 | def vgg13(num_classes, args):
85 | """VGG 13-layer model (configuration "B")"""
86 | return VGG(make_layers(cfg['B']), num_classes, args)
87 |
88 |
89 | def vgg13_bn(num_classes, args):
90 | """VGG 13-layer model (configuration "B") with batch normalization"""
91 | return VGG(make_layers(cfg['B'], batch_norm=True), num_classes, args)
92 |
93 |
94 | def vgg16(num_classes, args):
95 | """VGG 16-layer model (configuration "D")"""
96 | return VGG(make_layers(cfg['D']), num_classes, args)
97 |
98 |
99 | def vgg16_bn(num_classes, args):
100 | """VGG 16-layer model (configuration "D") with batch normalization"""
101 | return VGG(make_layers(cfg['D'], batch_norm=True), num_classes, args)
102 |
103 |
104 | def vgg19(num_classes, args):
105 | """VGG 19-layer model (configuration "E")"""
106 | return VGG(make_layers(cfg['E']), num_classes, args)
107 |
108 |
109 | def vgg19_bn(num_classes, args):
110 | """VGG 19-layer model (configuration 'E') with batch normalization"""
111 | return VGG(make_layers(cfg['E'], batch_norm=True), num_classes, args)
112 |
--------------------------------------------------------------------------------
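A minimal sketch for vgg11_bn above (illustrative): the five 'M' entries in cfg['A'] halve a 32x32 input five times, down to 1x1, so the flattened vector entering the classifier is 512-dimensional, matching nn.Linear(512, 512):

import torch
from models.backbones.Vgg import vgg11_bn

model = vgg11_bn(num_classes=10, args=None)  # `args` is accepted but unused here
x = torch.randn(2, 3, 32, 32)
print(model(x).shape)  # torch.Size([2, 10])

--------------------------------------------------------------------------------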
/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .ResNet18 import resnet18 as resnet18
2 | from .Lenet import lenet as lenet
3 | from .Vgg import vgg16_bn as vgg16
4 | from .Alexnet import alexnet as alexnet
5 | from .Densenet import densenet as densenet
6 | from .Senet import senet18 as senet
7 | from .Regnet import regnet as regnet
8 | from .Inception import inception as inception
9 | from .Swin import swin_t as swin
10 | from .ResNext import resnext as resnext
11 |
12 |
--------------------------------------------------------------------------------
/models/backbones/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/models/backbones/utils/__init__.py
--------------------------------------------------------------------------------
/models/backbones/utils/modules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.parameter import Parameter
9 |
10 |
11 | class AlphaModule(nn.Module):
12 | def __init__(self, shape):
13 | super(AlphaModule, self).__init__()
14 | if not isinstance(shape, tuple):
15 | shape = (shape,)
16 | self.alpha = Parameter(torch.rand(tuple([1] + list(shape))) * 0.1,
17 | requires_grad=True)
18 |
19 | def forward(self, x):
20 | return x * self.alpha
21 |
22 | def parameters(self, recurse: bool = True):
23 | yield self.alpha
24 |
25 |
26 | class ListModule(nn.Module):
27 | def __init__(self, *args):
28 | super(ListModule, self).__init__()
29 | self.idx = 0
30 | for module in args:
31 | self.add_module(str(self.idx), module)
32 | self.idx += 1
33 |
34 | def append(self, module):
35 | self.add_module(str(self.idx), module)
36 | self.idx += 1
37 |
38 | def __getitem__(self, idx):
39 | if idx < 0:
40 | idx += self.idx
41 | if idx >= len(self._modules):
42 | raise IndexError('index {} is out of range'.format(idx))
43 | it = iter(self._modules.values())
44 | for i in range(idx):
45 | next(it)
46 | return next(it)
47 |
48 | def __iter__(self):
49 | return iter(self._modules.values())
50 |
51 | def __len__(self):
52 | return len(self._modules)
53 |
--------------------------------------------------------------------------------
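A minimal usage sketch for ListModule above (illustrative): it behaves like a Python list of sub-modules that stays registered with the parent nn.Module:

import torch
import torch.nn as nn
from models.backbones.utils.modules import ListModule

layers = ListModule(nn.Linear(4, 8), nn.Linear(8, 2))
layers.append(nn.ReLU())
print(len(layers))   # 3
print(layers[-1])    # ReLU()
x = torch.randn(1, 4)
for module in list(layers)[:2]:  # apply the two Linear layers in order
    x = module(x)
print(x.shape)       # torch.Size([1, 2])

--------------------------------------------------------------------------------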
/models/distil.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | from utils.buffer import Buffer
12 | from torch.nn import functional as F
13 | from models.utils.continual_model import ContinualModel
14 | from augmentations import get_aug
15 | import torch
16 | from utils.losses import LabelSmoothing, KL_div_Loss
17 | from datasets import get_dataset
18 |
19 | def smooth(logits, temp, dim):
20 | log = logits ** (1 / temp)
21 | return log / torch.sum(log, dim).unsqueeze(1)
22 |
23 |
24 | def modified_kl_div(old, new):
25 | return -torch.mean(torch.sum(old * torch.log(new), 1))
26 |
27 |
28 |
29 | class Distil(ContinualModel):
30 | NAME = 'distil'
31 | COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
32 |
33 | def __init__(self, backbone, loss, args, len_train_loader, transform, global_model):
34 | super(Distil, self).__init__(backbone, loss, args, len_train_loader, transform)
35 | self.global_model = global_model
36 | self.buffer = Buffer(self.args.model.buffer_size, self.device)
37 |
38 | self.criterion_kl = KL_div_Loss(temperature=1.0).cuda()
39 | self.soft = torch.nn.Softmax(dim=1)
40 |
41 | def observe(self, inputs1, labels, inputs2, notaug_inputs, task_id):
42 |
43 | self.opt.zero_grad()
44 | inputs1, labels = inputs1.to(self.device), labels.to(self.device)
45 | inputs2 = inputs2.to(self.device)
46 | notaug_inputs = notaug_inputs.to(self.device)
47 | real_batch_size = inputs1.shape[0]
48 |
49 | if task_id:
50 | self.global_model.eval()
51 | outputs = self.net.module.backbone(inputs1)
52 | with torch.no_grad():
53 | outputs_teacher = self.global_model.net.module.backbone(inputs1)
54 |
55 | penalty = self.args.train.alpha * self.criterion_kl(outputs, outputs_teacher)
56 | loss = self.loss(outputs, labels) + penalty
57 | else:
58 | outputs = self.net.module.backbone(inputs1)
59 | loss = self.loss(outputs, labels)
60 |
61 | if task_id:
62 | data_dict = {'loss': loss, 'penalty': penalty}
63 | else:
64 | data_dict = {'loss': loss, 'penalty': 0.}
65 |
66 | loss.backward()
67 | self.opt.step()
68 | data_dict.update({'lr': self.args.train.base_lr})
69 |
70 | return data_dict
71 |
--------------------------------------------------------------------------------
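An illustrative sketch of the distillation penalty computed in Distil.observe, with a plain temperature-softened KL written out as a stand-in for utils.losses.KL_div_Loss (whose exact implementation lives elsewhere in the repo):

import torch
import torch.nn.functional as F

def kl_div_loss(student_logits, teacher_logits, temperature=1.0):
    # stand-in: KL(teacher || student) between softened distributions
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * temperature ** 2

student_logits = torch.randn(4, 10, requires_grad=True)
teacher_logits = torch.randn(4, 10)  # frozen teacher outputs (no grad needed)
labels = torch.randint(0, 10, (4,))

alpha = 0.1                          # plays the role of args.train.alpha
penalty = alpha * kl_div_loss(student_logits, teacher_logits)
loss = F.cross_entropy(student_logits, labels) + penalty
loss.backward()

--------------------------------------------------------------------------------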
/models/distilbuf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | from utils.buffer import Buffer
12 | from torch.nn import functional as F
13 | from models.utils.continual_model import ContinualModel
14 | from augmentations import get_aug
15 | import torch
16 | from utils.losses import LabelSmoothing, KL_div_Loss
17 | from datasets import get_dataset
18 |
19 | def smooth(logits, temp, dim):
20 | log = logits ** (1 / temp)
21 | return log / torch.sum(log, dim).unsqueeze(1)
22 |
23 |
24 | def modified_kl_div(old, new):
25 | return -torch.mean(torch.sum(old * torch.log(new), 1))
26 |
27 |
28 |
29 | class DistilBuf(ContinualModel):
30 | NAME = 'distilbuf'
31 | COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
32 |
33 | def __init__(self, backbone, loss, args, len_train_loader, transform, global_model):
34 | super(DistilBuf, self).__init__(backbone, loss, args, len_train_loader, transform)
35 | self.global_model = global_model
36 | self.buffer = Buffer(self.args.model.buffer_size, self.device)
37 |
38 | self.criterion_kl = KL_div_Loss(temperature=1.0).cuda()
39 | self.soft = torch.nn.Softmax(dim=1)
40 |
41 | def observe(self, inputs1, labels, inputs2, notaug_inputs, task_id):
42 |
43 | self.opt.zero_grad()
44 | inputs1, labels = inputs1.to(self.device), labels.to(self.device)
45 | inputs2 = inputs2.to(self.device)
46 | notaug_inputs = notaug_inputs.to(self.device)
47 | real_batch_size = inputs1.shape[0]
48 |
49 | if task_id:
50 | self.global_model.eval()
51 | outputs = self.net.module.backbone(inputs1)
52 | with torch.no_grad():
53 | outputs_teacher = self.global_model.net.module.backbone(inputs1)
54 |
55 | penalty = self.args.train.alpha * self.criterion_kl(outputs, outputs_teacher)
56 | loss = self.loss(outputs, labels) + penalty
57 | else:
58 | outputs = self.net.module.backbone(inputs1)
59 | loss = self.loss(outputs, labels)
60 |
61 | if not self.global_model.buffer.is_empty():
62 | buf_inputs, buf_logits = self.global_model.buffer.get_data(
63 | self.args.train.batch_size, transform=self.transform)
64 | buf_outputs = self.net.module.backbone(buf_inputs)
65 | penalty = 0.3 * self.loss(buf_outputs, buf_logits.long())
66 | loss += penalty
67 |
68 | if task_id:
69 | data_dict = {'loss': loss, 'penalty': penalty}
70 | else:
71 | data_dict = {'loss': loss, 'penalty': 0.}
72 |
73 | loss.backward()
74 | self.opt.step()
75 | data_dict.update({'lr': self.args.train.base_lr})
76 | self.global_model.buffer.add_data(examples=notaug_inputs, labels=labels[:real_batch_size])
77 |
78 | return data_dict
79 |
--------------------------------------------------------------------------------
/models/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .lars import LARS
2 | import torch
3 | from .lr_scheduler import LR_Scheduler
4 |
5 |
6 | def get_optimizer(name, model, lr, momentum, weight_decay, cl_default):
7 |
8 | predictor_prefix = ('module.predictor', 'predictor')
9 | parameters = [{
10 | 'name': 'base',
11 | 'params': [param for name, param in model.named_parameters() if not name.startswith(predictor_prefix)],
12 | 'lr': lr
13 | },{
14 | 'name': 'predictor',
15 | 'params': [param for name, param in model.named_parameters() if name.startswith(predictor_prefix)],
16 | 'lr': lr
17 | }]
18 | if name == 'lars':
19 | optimizer = LARS(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
20 | elif name == 'sgd':
21 | if cl_default:
22 | optimizer = torch.optim.SGD(parameters, lr=lr)
23 | else:
24 | optimizer = torch.optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
25 | elif name == 'adam':
26 | optimizer = torch.optim.Adam(parameters, lr=lr)
27 | else:
28 | raise NotImplementedError
29 | return optimizer
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
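A minimal sketch of how get_optimizer splits parameters (illustrative; the model below is a stand-in): any parameter whose name starts with 'predictor' (or 'module.predictor' under DataParallel) lands in the second group, so a scheduler can treat its learning rate differently:

import torch.nn as nn
from models.optimizers import get_optimizer

model = nn.Module()
model.backbone = nn.Linear(8, 8)
model.predictor = nn.Linear(8, 8)  # parameter names start with 'predictor'

opt = get_optimizer('sgd', model, lr=0.03, momentum=0.9, weight_decay=5e-4, cl_default=False)
print([g['name'] for g in opt.param_groups])  # ['base', 'predictor']

--------------------------------------------------------------------------------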
/models/optimizers/lars.py:
--------------------------------------------------------------------------------
1 | """ Layer-wise adaptive rate scaling for SGD in PyTorch! """
2 | import torch
3 | from torch.optim.optimizer import Optimizer, required
4 |
5 | class LARS(Optimizer):
6 | r"""Implements layer-wise adaptive rate scaling for SGD.
7 |
8 | Args:
9 | params (iterable): iterable of parameters to optimize or dicts defining
10 | parameter groups
11 | lr (float): base learning rate (\gamma_0)
12 | momentum (float, optional): momentum factor (default: 0) ("m")
13 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
14 | ("\beta")
15 | eta (float, optional): LARS coefficient
16 | max_epoch: maximum training epoch to determine polynomial LR decay.
17 |
18 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
19 | Large Batch Training of Convolutional Networks:
20 | https://arxiv.org/abs/1708.03888
21 |
22 | Example:
23 | >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3)
24 | >>> optimizer.zero_grad()
25 | >>> loss_fn(model(input), target).backward()
26 | >>> optimizer.step()
27 | """
28 | def __init__(self, params, lr=required, momentum=.9,
29 | weight_decay=.0005, eta=0.001, max_epoch=200):
30 | if lr is not required and lr < 0.0:
31 | raise ValueError("Invalid learning rate: {}".format(lr))
32 | if momentum < 0.0:
33 | raise ValueError("Invalid momentum value: {}".format(momentum))
34 | if weight_decay < 0.0:
35 | raise ValueError("Invalid weight_decay value: {}"
36 | .format(weight_decay))
37 | if eta < 0.0:
38 | raise ValueError("Invalid LARS coefficient value: {}".format(eta))
39 |
40 | self.epoch = 0
41 | defaults = dict(lr=lr, momentum=momentum,
42 | weight_decay=weight_decay,
43 | eta=eta, max_epoch=max_epoch)
44 | super(LARS, self).__init__(params, defaults)
45 |
46 | def step(self, epoch=None, closure=None):
47 | """Performs a single optimization step.
48 |
49 | Arguments:
50 | closure (callable, optional): A closure that reevaluates the model
51 | and returns the loss.
52 | epoch: current epoch to calculate polynomial LR decay schedule.
53 | if None, uses self.epoch and increments it.
54 | """
55 | loss = None
56 | if closure is not None:
57 | loss = closure()
58 |
59 | if epoch is None:
60 | epoch = self.epoch
61 | self.epoch += 1
62 |
63 | for group in self.param_groups:
64 | weight_decay = group['weight_decay']
65 | momentum = group['momentum']
66 | eta = group['eta']
67 | lr = group['lr']
68 | max_epoch = group['max_epoch']
69 |
70 | for p in group['params']:
71 | if p.grad is None:
72 | continue
73 |
74 | param_state = self.state[p]
75 | d_p = p.grad.data
76 |
77 | weight_norm = torch.norm(p.data)
78 | grad_norm = torch.norm(d_p)
79 |
80 | # Global LR computed on polynomial decay schedule
81 | decay = (1 - float(epoch) / max_epoch) ** 2
82 | global_lr = lr * decay
83 |
84 | # Compute local learning rate for this layer
85 | local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm)
86 |
87 | # Update the momentum term
88 | actual_lr = local_lr * global_lr
89 |
90 | if 'momentum_buffer' not in param_state:
91 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
92 | else:
93 | buf = param_state['momentum_buffer']
94 | buf.mul_(momentum).add_(d_p + weight_decay * p.data, alpha=actual_lr)
95 | p.data.add_(-buf)
96 |
97 | return loss
--------------------------------------------------------------------------------
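A worked numeric sketch of the per-layer trust ratio computed inside LARS.step (illustrative values):

import torch

p = torch.ones(10) * 2.0   # weights:  ||w|| = 2 * sqrt(10) ~ 6.32
g = torch.ones(10) * 0.1   # gradient: ||g|| = 0.1 * sqrt(10) ~ 0.32
eta, weight_decay = 0.001, 0.0005

local_lr = eta * torch.norm(p) / (torch.norm(g) + weight_decay * torch.norm(p))
print(local_lr)  # ~0.0198: layers with large weights and small gradients take bigger steps

--------------------------------------------------------------------------------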
/models/optimizers/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class LR_Scheduler(object):
6 | def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
7 | self.base_lr = base_lr
8 | self.constant_predictor_lr = constant_predictor_lr
9 | warmup_iter = iter_per_epoch * warmup_epochs
10 | warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
11 | decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
12 | cosine_lr_schedule = final_lr+0.5*(base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter))
13 |
14 | self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
15 | self.optimizer = optimizer
16 | self.iter = 0
17 | self.current_lr = 0
18 | def step(self):
19 | for param_group in self.optimizer.param_groups:
20 |
21 | if self.constant_predictor_lr and param_group['name'] == 'predictor':
22 | param_group['lr'] = self.base_lr
23 | else:
24 | lr = param_group['lr'] = self.lr_schedule[self.iter]
25 |
26 | self.iter += 1
27 | self.current_lr = lr
28 | return lr
29 |
30 | def reset(self):
31 | self.iter = 0
32 | self.current_lr = 0
33 |
34 | def get_lr(self):
35 | return self.current_lr
36 |
37 | if __name__ == "__main__":
38 | import torchvision
39 | model = torchvision.models.resnet50()
40 | optimizer = torch.optim.SGD(model.parameters(), lr=999)
41 | epochs = 100
42 | n_iter = 1000
43 | scheduler = LR_Scheduler(optimizer, 10, 1, epochs, 3, 0, n_iter)
44 | import matplotlib.pyplot as plt
45 | lrs = []
46 | for epoch in range(epochs):
47 | for it in range(n_iter):
48 | lr = scheduler.step()
49 | lrs.append(lr)
50 | plt.plot(lrs)
51 | plt.show()
52 |
--------------------------------------------------------------------------------
/models/qdi.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | import torch
12 | import numpy as np
13 | import random
14 | from utils.buffer import Buffer
15 | from torch.nn import functional as F
16 | from models.utils.continual_model import ContinualModel
17 | from augmentations import get_aug
18 | from utils.deep_inversion import DeepInversionFeatureHook
19 | from utils.losses import LabelSmoothing, KL_div_Loss
20 | import torchvision.utils as vutils
21 | from datasets import get_dataset
22 |
23 |
24 | def lr_policy(lr_fn):
25 | def _alr(optimizer, epoch):
26 | lr = lr_fn(epoch)
27 | for param_group in optimizer.param_groups:
28 | param_group['lr'] = lr
29 | return _alr
30 |
31 | def lr_cosine_policy(base_lr, warmup_length, epochs):
32 | def _lr_fn(epoch):
33 | if epoch < warmup_length:
34 | lr = base_lr * (epoch + 1) / warmup_length
35 | else:
36 | e = epoch - warmup_length
37 | es = epochs - warmup_length
38 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
39 | print(lr)
40 | return lr
41 | return lr_policy(_lr_fn)
42 |
43 |
44 | class Qdi(ContinualModel):
45 | NAME = 'qdi'
46 | COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
47 |
48 | def __init__(self, backbone, loss, args, len_train_loader, transform, global_model=None):
49 | super(Qdi, self).__init__(backbone, loss, args, len_train_loader, transform)
50 | self.num_classes = 10
51 | im_size = (32, 32) if args.dataset.name == "seq-cifar10" or args.dataset.name == "seq-cifar100" else (64, 64)
52 | images_per_class = 20
53 | self.buffer = Buffer(self.args.model.buffer_size, self.device)
54 | self.global_model = global_model
55 | self.criterion_kl = KL_div_Loss(temperature=1.0).cuda()
56 | self.lr_scheduler = lr_cosine_policy(args.train.di_lr, 100, args.train.di_itrs)
57 | self.args = args
58 | self.cpt = get_dataset(args).N_CLASSES_PER_TASK
59 | self.current_step = 0
60 |
61 | def begin_task(self, task_id, dataset=None):
62 | if task_id:
63 | self.sample_inputs = []
64 | if dataset is not None:
65 | for i in range(0, dataset.train_loader.dataset.data.shape[0], self.args.train.batch_size):
66 | inputs = torch.stack([dataset.train_loader.dataset.__getitem__(j)[0][0]
67 | for j in range(i, min(i + self.args.train.batch_size, len(dataset.train_loader.dataset)))])
68 | self.sample_inputs.append(inputs)
69 |
70 | self.sample_inputs = torch.cat(self.sample_inputs)
71 |
72 | rand_idx = torch.randperm(self.sample_inputs.shape[0])
73 | sample_inputs = self.sample_inputs[rand_idx].to(self.device)
74 | sample_batch = sample_inputs[:self.args.model.buffer_size * 4].to(self.device)
75 | statistics = []
76 |
77 |         has_batchnorm = any(isinstance(module, torch.nn.BatchNorm2d)
78 |                             for module in self.global_model.net.module.backbone.modules())
79 |         if has_batchnorm:
80 | for module in self.global_model.net.module.backbone.modules():
81 | if isinstance(module, torch.nn.BatchNorm2d):
82 | statistics.append(DeepInversionFeatureHook(module))
83 |
84 | for item in statistics:
85 | item.capture_bn_stats = False
86 | item.use_stored_stats = False
87 | else:
88 | for module in self.global_model.net.module.backbone.modules():
89 | if isinstance(module, torch.nn.Conv2d):
90 | statistics.append(DeepInversionFeatureHook(module))
91 |
92 | for item in statistics:
93 | item.capture_bn_stats = True
94 | item.use_stored_stats = True
95 |
96 | _ = self.global_model.net.module.backbone(sample_batch)
97 | print('Finished capturing post conv2d stats. Freezing the stats.')
98 |
99 | for item in statistics:
100 | item.capture_bn_stats = False
101 | item.use_stored_stats = True
102 |
103 | rand_idx = torch.randperm(self.sample_inputs.shape[0])
104 | sample_inputs = self.sample_inputs[rand_idx].to(self.device)
105 | sample_batch = sample_inputs[:self.args.model.buffer_size].to(self.device)
106 | vutils.save_image(sample_batch.data.clone(),
107 | f'./di_images_{self.args.dataset.name}/sample_batch_{task_id}.png',
108 | normalize=True, scale_each=True, nrow=5)
109 | sample_batch_size, im_size = sample_batch.shape[0], sample_batch.shape[2]
110 | cls_per_task = task_id * self.cpt
111 | self.label_syn = torch.tensor([np.ones(sample_batch_size//cls_per_task) * i for i in range(cls_per_task)], dtype=torch.long, requires_grad=False, device=self.device).view(-1)
112 | rand_idx = torch.randperm(len(self.label_syn))
113 | label_syn = self.label_syn[rand_idx]
114 | image_syn = torch.randn(size=(self.label_syn.shape[0], 3, im_size, im_size), dtype=torch.float, requires_grad=True, device=self.device)
115 | sample_batch = sample_batch[:self.label_syn.shape[0]]
116 | image_syn.data = sample_batch.data.clone()
117 |         image_opt = torch.optim.Adam([image_syn], lr=self.args.train.di_lr, betas=(0.5, 0.9), eps=1e-8)
118 |
119 |
120 | self.global_model.eval()
121 | self.net.eval()
122 |
123 |         for step in range(self.args.train.di_itrs + 1):
124 | self.lr_scheduler(image_opt, step)
125 | image_opt.zero_grad()
126 | self.global_model.zero_grad()
127 | outputs = self.global_model.net.module.backbone(image_syn)
128 | loss_ce = self.loss(outputs, label_syn.long())
129 |
130 | diff1 = image_syn[:,:,:,:-1] - image_syn[:,:,:,1:]
131 | diff2 = image_syn[:,:,:-1,:] - image_syn[:,:,1:,:]
132 | diff3 = image_syn[:,:,1:,:-1] - image_syn[:,:,:-1,1:]
133 | diff4 = image_syn[:,:,:-1,:-1] - image_syn[:,:,1:,1:]
134 | loss_var = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4)
135 |
136 | loss_distr = self.args.train.di_feature * sum([mod.r_feature for mod in statistics])
137 | loss_var = self.args.train.di_var * loss_var
138 | loss_l2 = self.args.train.di_l2 * torch.norm(image_syn, 2)
139 | loss = loss_ce + loss_distr + loss_l2 + loss_var
140 |
141 | if step % 5 == 0:
142 | print('\t step', step, '\t ce', loss_ce.item(), '\t r feature', loss_distr.item(), '\tr var', loss_var.item(), '\tr l2', loss_l2.item(), '\t total', loss.item())
143 |
144 | loss.backward()
145 | image_opt.step()
146 | if step % 5 == 0:
147 | vutils.save_image(image_syn.data.clone(),
148 | f'./di_images_{self.args.dataset.name}/di_generated_{task_id}_{step//5}.png',
149 | normalize=True, scale_each=True, nrow=5)
150 |
151 | self.global_model.buffer.add_data(examples=image_syn, labels=label_syn)
152 | self.image_syn = image_syn.detach().clone()
153 | self.label_syn = label_syn.detach().clone()
154 | self.net.train()
155 |
156 | def observe(self, inputs1, labels, inputs2, notaug_inputs, task_id):
157 | inputs1, labels = inputs1.to(self.device), labels.to(self.device)
158 | real_batch_size = inputs1.shape[0]
159 |
160 | if task_id:
161 | outputs_clean = self.net.module.backbone(inputs1)
162 |
163 | outputs = self.net.module.backbone(self.image_syn)
164 | outputs_teacher = self.global_model.net.module.backbone(self.image_syn)
165 | outputs_teacher_clean = self.global_model.net.module.backbone(inputs1)
166 |
167 | penalty = self.criterion_kl(outputs_clean, outputs_teacher_clean) + self.criterion_kl(outputs, outputs_teacher)
168 | loss = self.loss(outputs_clean, labels) + self.args.train.alpha * penalty
169 | else:
170 | outputs = self.net.module.backbone(inputs1)
171 | loss = self.loss(outputs, labels)
172 |
173 | if task_id:
174 | data_dict = {'loss': loss, 'penalty': penalty}
175 | else:
176 | data_dict = {'loss': loss, 'penalty': 0.}
177 |
178 | self.opt.zero_grad()
179 | loss.backward()
180 | self.opt.step()
181 | data_dict.update({'lr': self.args.train.base_lr})
182 | self.current_step += 1
183 |
184 | return data_dict
185 |
--------------------------------------------------------------------------------
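The four diff terms in Qdi.begin_task implement a total-variation-style smoothness prior over the synthesized images; a standalone sketch (illustrative shapes):

import torch

image_syn = torch.randn(4, 3, 32, 32)
diff1 = image_syn[:, :, :, :-1] - image_syn[:, :, :, 1:]     # horizontal neighbours
diff2 = image_syn[:, :, :-1, :] - image_syn[:, :, 1:, :]     # vertical neighbours
diff3 = image_syn[:, :, 1:, :-1] - image_syn[:, :, :-1, 1:]  # one diagonal
diff4 = image_syn[:, :, :-1, :-1] - image_syn[:, :, 1:, 1:]  # the other diagonal
loss_var = sum(torch.norm(d) for d in (diff1, diff2, diff3, diff4))
print(loss_var)  # penalizing this encourages smooth, natural-looking images

--------------------------------------------------------------------------------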
/models/simsiam.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/divyam3897/UCL/blob/main/models/simsiam.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torchvision.models import resnet50
7 |
8 |
9 | def D(p, z, version='simplified'): # negative cosine similarity
10 | if version == 'original':
11 | z = z.detach() # stop gradient
12 | p = F.normalize(p, dim=1) # l2-normalize
13 | z = F.normalize(z, dim=1) # l2-normalize
14 | return -(p*z).sum(dim=1).mean()
15 |
16 | elif version == 'simplified':# same thing, much faster. Scroll down, speed test in __main__
17 | return - F.cosine_similarity(p, z.detach(), dim=-1).mean()
18 | else:
19 | raise Exception
20 |
21 |
22 |
23 | class projection_MLP(nn.Module):
24 | def __init__(self, in_dim, hidden_dim=2048, out_dim=2048):
25 | super().__init__()
26 | ''' page 3 baseline setting
27 | Projection MLP. The projection MLP (in f) has BN ap-
28 | plied to each fully-connected (fc) layer, including its out-
29 | put fc. Its output fc has no ReLU. The hidden fc is 2048-d.
30 | This MLP has 3 layers.
31 | '''
32 | self.layer1 = nn.Sequential(
33 | nn.Linear(in_dim, hidden_dim),
34 | nn.BatchNorm1d(hidden_dim),
35 | nn.ReLU(inplace=True)
36 | )
37 | self.layer2 = nn.Sequential(
38 | nn.Linear(hidden_dim, hidden_dim),
39 | nn.BatchNorm1d(hidden_dim),
40 | nn.ReLU(inplace=True)
41 | )
42 | self.layer3 = nn.Sequential(
43 | nn.Linear(hidden_dim, out_dim),
44 |             nn.BatchNorm1d(out_dim)  # normalize the output width (equals hidden_dim by default)
45 | )
46 | self.num_layers = 3
47 | def set_layers(self, num_layers):
48 | self.num_layers = num_layers
49 |
50 | def forward(self, x):
51 | if self.num_layers == 3:
52 | x = self.layer1(x)
53 | x = self.layer2(x)
54 | x = self.layer3(x)
55 | elif self.num_layers == 2:
56 | x = self.layer1(x)
57 | x = self.layer3(x)
58 | else:
59 | raise Exception
60 | return x
61 |
62 |
63 | class prediction_MLP(nn.Module):
64 | def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048): # bottleneck structure
65 | super().__init__()
66 | ''' page 3 baseline setting
67 | Prediction MLP. The prediction MLP (h) has BN applied
68 | to its hidden fc layers. Its output fc does not have BN
69 | (ablation in Sec. 4.4) or ReLU. This MLP has 2 layers.
70 | The dimension of h’s input and output (z and p) is d = 2048,
71 | and h’s hidden layer’s dimension is 512, making h a
72 | bottleneck structure (ablation in supplement).
73 | '''
74 | self.layer1 = nn.Sequential(
75 | nn.Linear(in_dim, hidden_dim),
76 | nn.BatchNorm1d(hidden_dim),
77 | nn.ReLU(inplace=True)
78 | )
79 | self.layer2 = nn.Linear(hidden_dim, out_dim)
80 | """
81 | Adding BN to the output of the prediction MLP h does not work
82 | well (Table 3d). We find that this is not about collapsing.
83 | The training is unstable and the loss oscillates.
84 | """
85 |
86 | def forward(self, x):
87 | x = self.layer1(x)
88 | x = self.layer2(x)
89 | return x
90 |
91 | class SimSiam(nn.Module):
92 | def __init__(self, backbone=resnet50()):
93 | super().__init__()
94 |
95 | self.backbone = backbone
96 | self.projector = projection_MLP(backbone.output_dim)
97 |
98 | self.encoder = nn.Sequential( # f encoder
99 | self.backbone,
100 | self.projector
101 | )
102 | self.predictor = prediction_MLP()
103 | self.distil_predictor = prediction_MLP()
104 |
105 | def forward(self, x1, x2):
106 |
107 | f, h = self.encoder, self.predictor
108 | z1, z2 = f(x1), f(x2)
109 | p1, p2 = h(z1), h(z2)
110 | L = D(p1, z2) / 2 + D(p2, z1) / 2
111 | return {'loss': L, 'z1': z1, 'z2': z2}
112 |
113 | if __name__ == "__main__":
114 | model = SimSiam()
115 | model = torch.nn.DataParallel(model).cuda()
116 | x1 = torch.randn((128, 3, 32, 32))
117 | x2 = torch.randn_like(x1)
118 |
119 | for i in range(50):
120 |         model.forward(x1, x2)['loss'].backward()  # forward returns a dict; backprop its loss entry
121 |     print("forward backward check")
122 |
123 | z1 = torch.randn((200, 2560))
124 | z2 = torch.randn_like(z1)
125 | import time
126 | tic = time.time()
127 | print(D(z1, z2, version='original'))
128 | toc = time.time()
129 | print(toc - tic)
130 | tic = time.time()
131 | print(D(z1, z2, version='simplified'))
132 | toc = time.time()
133 | print(toc - tic)
134 |
135 | # Output:
136 | # tensor(-0.0010)
137 | # 0.005159854888916016
138 | # tensor(-0.0010)
139 | # 0.0014872550964355469
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
--------------------------------------------------------------------------------
/models/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/models/utils/__init__.py
--------------------------------------------------------------------------------
/models/utils/continual_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch.nn as nn
7 | from torch.optim import SGD
8 | import torch
9 | import torchvision
10 | from argparse import Namespace
11 | from utils.conf import get_device
12 | import numpy as np
13 | from ..optimizers import get_optimizer, LR_Scheduler
14 |
15 |
16 | class ContinualModel(nn.Module):
17 | """
18 | Continual learning model.
19 | """
20 | NAME = None
21 | COMPATIBILITY = []
22 |
23 | def __init__(self, backbone: nn.Module, loss: nn.Module,
24 | args: Namespace, dataset, transform: torchvision.transforms) -> None:
25 | super(ContinualModel, self).__init__()
26 |
27 | self.net = backbone
28 | self.net = nn.DataParallel(self.net)
29 | self.loss = loss
30 | self.args = args
31 | self.transform = transform
32 | self.dataset = dataset
33 |
34 | if args.cl_default:
35 | self.opt = get_optimizer(
36 | args.train.optimizer.name, self.net,
37 | lr=args.train.base_lr,
38 | momentum=args.train.optimizer.momentum,
39 | weight_decay=args.train.optimizer.weight_decay,
40 | cl_default=args.cl_default)
41 | else:
42 | self.opt = get_optimizer(
43 | args.train.optimizer.name, self.net,
44 | lr=args.train.base_lr*args.train.batch_size/256,
45 | momentum=args.train.optimizer.momentum,
46 | weight_decay=args.train.optimizer.weight_decay,
47 | cl_default=args.cl_default)
48 |
49 | # self.lr_scheduler = LR_Scheduler(
50 | # self.opt,
51 | # args.train.warmup_epochs, args.train.warmup_lr*args.train.batch_size/256,
52 | # args.train.num_epochs, args.train.base_lr*args.train.batch_size/256, args.train.final_lr*args.train.batch_size/256,
53 | # len_train_lodaer,
54 | # constant_predictor_lr=True # see the end of section 4.2 predictor
55 | # )
56 | self.device = get_device()
57 |
58 | def forward(self, x: torch.Tensor) -> torch.Tensor:
59 | """
60 | Computes a forward pass.
61 | :param x: batch of inputs
62 | :param task_label: some models require the task label
63 | :return: the result of the computation
64 | """
65 | return self.net.module.backbone.forward(x)
66 |
67 | def observe(self, inputs: torch.Tensor, labels: torch.Tensor,
68 | not_aug_inputs: torch.Tensor) -> float:
69 | """
70 | Compute a training step over a given batch of examples.
71 | :param inputs: batch of examples
72 | :param labels: ground-truth labels
73 | :param kwargs: some methods could require additional parameters
74 | :return: the value of the loss function
75 | """
76 | pass
77 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler==0.10.0
2 | kiwisolver==1.3.2
3 | matplotlib==3.4.3
4 | numpy==1.21.2
5 | Pillow==8.3.2
6 | pyparsing==2.4.7
7 | python-dateutil==2.8.2
8 | PyYAML==5.4.1
9 | quadprog==0.1.10
10 | six==1.16.0
11 | torch==1.9.1
12 | torchvision==0.10.1
13 | tqdm==4.62.3
14 | typing-extensions==3.10.0.2
15 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .average_meter import AverageMeter
2 | from .accuracy import accuracy
3 | from .knn_monitor import knn_monitor
4 | from .logger import Logger
5 | from .file_exist_fn import file_exist_check
6 |
--------------------------------------------------------------------------------
/tools/accuracy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def accuracy(output, target, topk=(1,)):
5 |     """Computes the accuracy over the k top predictions for the specified values of k"""
6 |     with torch.no_grad():
7 |         maxk = max(topk)
8 |         batch_size = target.size(0)
9 |
10 |         _, pred = output.topk(maxk, 1, True, True)
11 |         pred = pred.t()
12 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
13 |
14 |         res = []
15 |         for k in topk:
16 |             # reshape: correct[:k] is non-contiguous after the transpose, so view() would fail
17 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
18 |             res.append(correct_k.mul_(100.0 / batch_size))
19 |         return res
20 |
--------------------------------------------------------------------------------
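A minimal usage sketch for accuracy above (illustrative):

import torch
from tools.accuracy import accuracy

output = torch.randn(8, 10)          # logits for 8 samples over 10 classes
target = torch.randint(0, 10, (8,))
top1, top5 = accuracy(output, target, topk=(1, 5))
print(top1.item(), top5.item())      # percentages in [0, 100]

--------------------------------------------------------------------------------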
/tools/average_meter.py:
--------------------------------------------------------------------------------
1 | class AverageMeter():
2 | """Computes and stores the average and current value"""
3 | def __init__(self, name, fmt=':f'):
4 | self.name = name
5 | self.fmt = fmt
6 | self.log = []
7 | self.val = 0
8 | self.avg = 0
9 | self.sum = 0
10 | self.count = 0
11 |
12 | def reset(self):
13 | self.log.append(self.avg)
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
24 |
25 | def __str__(self):
26 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
27 | return fmtstr.format(**self.__dict__)
28 |
29 | if __name__ == "__main__":
30 | meter = AverageMeter('sldk')
31 | print(meter.log)
32 |
33 |
--------------------------------------------------------------------------------
/tools/file_exist_fn.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 |
5 | def file_exist_check(file_dir):
6 |
7 | if os.path.isdir(file_dir):
8 | for i in range(2, 1000):
9 | if not os.path.isdir(file_dir + f'({i})'):
10 | file_dir += f'({i})'
11 | break
12 | return file_dir
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tools/knn_monitor.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch.nn.functional as F
3 | import torch
4 | import numpy as np
5 | import copy
6 | from utils.metrics import mask_classes
7 |
8 | # code copied from https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb#scrollTo=RI1Y8bSImD7N
9 | # test using a knn monitor
10 | def knn_monitor(net, dataset, memory_data_loader, test_data_loader, device, cl_default, task_id, k=200, t=0.1, hide_progress=False):
11 | net.eval()
12 | try:
13 | classes = len(memory_data_loader.dataset.classes)
14 |     except AttributeError:
15 |         classes = 200  # fallback when the dataset exposes no `classes` attribute
16 | total_top1 = total_top1_mask = total_top5 = total_num = 0.0
17 | feature_bank = []
18 | with torch.no_grad():
19 | # generate feature bank
20 | for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=True):
21 | if cl_default:
22 | feature = net(data.cuda(non_blocking=True), return_features=True)
23 | else:
24 | feature = net(data.cuda(non_blocking=True))
25 | feature = F.normalize(feature, dim=1)
26 | feature_bank.append(feature)
27 | # [D, N]
28 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
29 | # [N]
30 | # feature_labels = torch.tensor(memory_data_loader.dataset.targets - np.amin(memory_data_loader.dataset.targets), device=feature_bank.device)
31 | feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
32 | # loop test data to predict the label by weighted knn search
33 | test_bar = tqdm(test_data_loader, desc='kNN', disable=True)
34 | for data, target in test_bar:
35 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
36 | if cl_default:
37 | feature = net(data, return_features=True)
38 | else:
39 | feature = net(data)
40 | feature = F.normalize(feature, dim=1)
41 | pred_scores = knn_predict(feature, feature_bank, feature_labels, classes, k, t)
42 |
43 | total_num += data.shape[0]
44 | _, preds = torch.max(pred_scores.data, 1)
45 | total_top1 += torch.sum(preds == target).item()
46 |
47 | pred_scores_mask = mask_classes(copy.deepcopy(pred_scores), dataset, task_id)
48 | _, preds_mask = torch.max(pred_scores_mask.data, 1)
49 | total_top1_mask += torch.sum(preds_mask == target).item()
50 |
51 | return total_top1 / total_num * 100, total_top1_mask / total_num * 100
52 |
53 |
54 | # knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
55 | # implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
56 | def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
57 | # compute cos similarity between each feature vector and feature bank ---> [B, N]
58 | sim_matrix = torch.mm(feature, feature_bank)
59 | # [B, K]
60 | sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
61 | # [B, K]
62 | sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
63 | sim_weight = (sim_weight / knn_t).exp()
64 |
65 | # counts for each class
66 | one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
67 | # [B*K, C]
68 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
69 | # weighted score ---> [B, C]
70 | pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
71 |
72 | return pred_scores
73 |
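74 |
75 | if __name__ == "__main__":
76 |     # Illustrative sketch (not exercised by the training code): classify random
77 |     # query features against a random feature bank with the weighted kNN above.
78 |     B, D, N, C = 4, 16, 100, 10
79 |     bank = F.normalize(torch.randn(D, N), dim=0)      # [D, N], unit-norm columns
80 |     labels = torch.randint(0, C, (N,))                # [N]
81 |     queries = F.normalize(torch.randn(B, D), dim=1)   # [B, D]
82 |     scores = knn_predict(queries, bank, labels, classes=C, knn_k=5, knn_t=0.1)
83 |     print(scores.argmax(dim=1))                       # predicted class per query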
--------------------------------------------------------------------------------
/tools/logger.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 | from collections import OrderedDict
3 | import os
4 | from .plotter import Plotter
5 |
6 |
7 | class Logger(object):
8 | def __init__(self, log_dir, matplotlib=True):
9 |
10 |         self.reset(log_dir, matplotlib=matplotlib)  # keyword: reset's second positional argument is `tensorboard`
11 |
12 | def reset(self, log_dir=None, tensorboard=True, matplotlib=True):
13 |
14 | if log_dir is not None: self.log_dir=log_dir
15 | self.plotter = Plotter() if matplotlib else None
16 | self.counter = OrderedDict()
17 |
18 | def update_scalers(self, ordered_dict):
19 |
20 | for key, value in ordered_dict.items():
21 | if isinstance(value, Tensor):
22 | try:
23 | ordered_dict[key] = value.item()
24 | except:
25 | pass
26 | if self.counter.get(key) is None:
27 | self.counter[key] = 1
28 | else:
29 | self.counter[key] += 1
30 |
31 | # if self.plotter:
32 | # self.plotter.update(ordered_dict)
33 | # self.plotter.save(os.path.join(self.log_dir, 'plotter.svg'))
34 |
35 |
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/tools/plotter.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg') #https://stackoverflow.com/questions/49921721/runtimeerror-main-thread-is-not-in-main-loop-with-matplotlib-and-flask
3 | import matplotlib.pyplot as plt
4 | from collections import OrderedDict
5 | from torch import Tensor
6 |
7 | class Plotter(object):
8 | def __init__(self):
9 | self.logger = OrderedDict()
10 | def update(self, ordered_dict):
11 | for key, value in ordered_dict.items():
12 | if isinstance(value, Tensor):
13 | try:
14 | ordered_dict[key] = value.item()
15 | except:
16 | pass
17 | if self.logger.get(key) is None:
18 | self.logger[key] = [value]
19 | else:
20 | self.logger[key].append(value)
21 |
22 |     def save(self, file, **kwargs):
23 |         if not self.logger:
24 |             return
25 |         # squeeze=False keeps `axes` two-dimensional even for a single row,
26 |         # so the zip below also works when only one scalar is tracked
27 |         fig, axes = plt.subplots(nrows=len(self.logger), ncols=1, squeeze=False,
28 |                                  figsize=(8, 2 * len(self.logger)))
29 |         for ax, (key, value) in zip(axes[:, 0], self.logger.items()):
30 |             ax.plot(value)
31 |             ax.set_title(key)
32 |
33 |         fig.tight_layout()
34 |         plt.savefig(file, **kwargs)
35 |         plt.close()
36 |
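37 |
38 | if __name__ == "__main__":
39 |     # Illustrative sketch: log two scalar series and write the figure to disk
40 |     # ('plotter_demo.svg' is a placeholder output path).
41 |     plotter = Plotter()
42 |     for step in range(10):
43 |         plotter.update(OrderedDict(loss=1.0 / (step + 1), acc=step * 10.0))
44 |     plotter.save('plotter_demo.svg')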
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import os
7 |
8 |
9 | def create_if_not_exists(path: str) -> None:
10 | """
11 | Creates the specified folder if it does not exist.
12 | :param path: the complete path of the folder to be created
13 | """
14 | if not os.path.exists(path):
15 | os.makedirs(path)
16 |
--------------------------------------------------------------------------------
/utils/args.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from argparse import ArgumentParser
7 | from datasets import NAMES as DATASET_NAMES
8 | from models import get_all_models
9 |
10 |
11 | def add_experiment_args(parser: ArgumentParser) -> None:
12 | """
13 | Adds the arguments used by all the models.
14 | :param parser: the parser instance
15 | """
16 | parser.add_argument('--dataset', type=str, required=True,
17 | choices=DATASET_NAMES,
18 | help='Which dataset to perform experiments on.')
19 | parser.add_argument('--model', type=str, required=True,
20 | help='Model name.', choices=get_all_models())
21 |
22 | parser.add_argument('--lr', type=float, required=True,
23 | help='Learning rate.')
24 |     parser.add_argument('--warmup_lr', default=0.0, type=float,
25 |                         help='Warmup learning rate.')
26 |     parser.add_argument('--warmup_epochs', default=0, type=int,
27 |                         help='Number of warmup epochs.')
28 |     parser.add_argument('--final_lr', default=0.0, type=float,
29 |                         help='Final learning rate.')
30 | parser.add_argument('--batch_size', type=int, required=True,
31 | help='Batch size.')
32 | parser.add_argument('--n_epochs', type=int, required=True,
33 | help='The number of epochs for each task.')
34 | parser.add_argument('--sim_siam', action='store_true',
35 | help='Use SimSiam')
36 |
37 |
38 | def add_management_args(parser: ArgumentParser) -> None:
39 | parser.add_argument('--seed', type=int, default=None,
40 | help='The random seed.')
41 | parser.add_argument('--notes', type=str, default=None,
42 | help='Notes for this run.')
43 |
44 | parser.add_argument('--csv_log', action='store_true',
45 | help='Enable csv logging')
46 | parser.add_argument('--tensorboard', action='store_true',
47 | help='Enable tensorboard logging')
48 | parser.add_argument('--validation', action='store_true',
49 | help='Test on the validation set')
50 |
51 |
52 | def add_rehearsal_args(parser: ArgumentParser) -> None:
53 | """
54 | Adds the arguments used by all the rehearsal-based methods
55 | :param parser: the parser instance
56 | """
57 | parser.add_argument('--buffer_size', type=int, required=True,
58 | help='The size of the memory buffer.')
59 | parser.add_argument('--minibatch_size', type=int,
60 | help='The batch size of the memory buffer.')
61 |
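62 |
63 | if __name__ == "__main__":
64 |     # Illustrative sketch of how these helpers might be composed (the argument
65 |     # values below are placeholders, not recommended settings).
66 |     parser = ArgumentParser(description='Continual learning experiment')
67 |     add_experiment_args(parser)
68 |     add_management_args(parser)
69 |     add_rehearsal_args(parser)
70 |     args = parser.parse_args(['--dataset', DATASET_NAMES[0], '--model', get_all_models()[0],
71 |                               '--lr', '0.03', '--batch_size', '32', '--n_epochs', '1',
72 |                               '--buffer_size', '256'])
73 |     print(args)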
--------------------------------------------------------------------------------
/utils/batch_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | class bn_track_stats:
10 | def __init__(self, module: nn.Module, condition=True):
11 | self.module = module
12 | self.enable = condition
13 |
14 | def __enter__(self):
15 | if not self.enable:
16 | for m in self.module.modules():
17 | if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
18 | m.track_running_stats = False
19 |
20 |     def __exit__(self, exc_type, exc_value, traceback):
21 | if not self.enable:
22 | for m in self.module.modules():
23 | if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
24 | m.track_running_stats = True
25 |
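26 |
27 | if __name__ == "__main__":
28 |     # Illustrative sketch: BN running stats are frozen inside the context
29 |     # whenever `condition` is False, and restored on exit.
30 |     net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
31 |     with bn_track_stats(net, condition=False):
32 |         net(torch.randn(4, 3, 32, 32))  # running_mean/var stay untouched here
33 |     print(net[1].track_running_stats)   # True again after the context exits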
--------------------------------------------------------------------------------
/utils/buffer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import numpy as np
8 | from typing import Tuple
9 | from torchvision import transforms
10 | from copy import deepcopy  # required by icarl_replay's val_loader copy
11 | def icarl_replay(self, dataset, val_set_split=0):
12 | """
13 | Merge the replay buffer with the current task data.
14 | Optionally split the replay buffer into a validation set.
15 | :param self: the model instance
16 | :param dataset: the dataset
17 | :param val_set_split: the fraction of the replay buffer to be used as validation set
18 | """
19 | if self.task > 0:
20 | buff_val_mask = torch.rand(len(self.buffer)) < val_set_split
21 | val_train_mask = torch.zeros(len(dataset.train_loader.dataset.data)).bool()
22 | val_train_mask[torch.randperm(len(dataset.train_loader.dataset.data))[:buff_val_mask.sum()]] = True
23 |
24 | if val_set_split > 0:
25 | self.val_loader = deepcopy(dataset.train_loader)
26 | data_concatenate = torch.cat if type(dataset.train_loader.dataset.data) == torch.Tensor else np.concatenate
27 | need_aug = hasattr(dataset.train_loader.dataset, 'not_aug_transform')
28 | if not need_aug:
29 | refold_transform = lambda x: x.cpu()
30 | else:
31 | data_shape = len(dataset.train_loader.dataset.data[0].shape)
32 | if data_shape == 3:
33 | refold_transform = lambda x: (x.cpu()*255).permute([0, 2, 3, 1]).numpy().astype(np.uint8)
34 | elif data_shape == 2:
35 | refold_transform = lambda x: (x.cpu()*255).squeeze(1).type(torch.uint8)
36 |
37 | # REDUCE AND MERGE TRAINING SET
38 | dataset.train_loader.dataset.targets = np.concatenate([
39 | dataset.train_loader.dataset.targets[~val_train_mask],
40 | self.buffer.labels.cpu().numpy()[:len(self.buffer)][~buff_val_mask]
41 | ])
42 | dataset.train_loader.dataset.data = data_concatenate([
43 | dataset.train_loader.dataset.data[~val_train_mask],
44 | refold_transform((self.buffer.examples)[:len(self.buffer)][~buff_val_mask])
45 | ])
46 |
47 | if val_set_split > 0:
48 | # REDUCE AND MERGE VALIDATION SET
49 | self.val_loader.dataset.targets = np.concatenate([
50 | self.val_loader.dataset.targets[val_train_mask],
51 | self.buffer.labels.cpu().numpy()[:len(self.buffer)][buff_val_mask]
52 | ])
53 | self.val_loader.dataset.data = data_concatenate([
54 | self.val_loader.dataset.data[val_train_mask],
55 | refold_transform((self.buffer.examples)[:len(self.buffer)][buff_val_mask])
56 | ])
57 |
58 | def reservoir(num_seen_examples: int, buffer_size: int) -> int:
59 | """
60 | Reservoir sampling algorithm.
61 | :param num_seen_examples: the number of seen examples
62 | :param buffer_size: the maximum buffer size
63 | :return: the target index if the current image is sampled, else -1
64 | """
65 | if num_seen_examples < buffer_size:
66 | return num_seen_examples
67 |
68 | rand = np.random.randint(0, num_seen_examples + 1)
69 | if rand < buffer_size:
70 | return rand
71 | else:
72 | return -1
73 |
74 |
75 | def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int:
76 | return num_seen_examples % buffer_portion_size + task * buffer_portion_size
77 |
78 |
79 | class Buffer:
80 | """
81 | The memory buffer of rehearsal method.
82 | """
83 | def __init__(self, buffer_size, device, n_tasks=None, mode='reservoir'):
84 | assert mode in ['ring', 'reservoir']
85 | self.buffer_size = buffer_size
86 | self.device = device
87 | self.num_seen_examples = 0
88 |         self.functional_index = reservoir if mode == 'reservoir' else ring
89 | if mode == 'ring':
90 | assert n_tasks is not None
91 | self.task_number = n_tasks
92 | self.buffer_portion_size = buffer_size // n_tasks
93 | self.attributes = ['examples', 'labels', 'logits', 'task_labels']
94 |
95 | def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor,
96 | logits: torch.Tensor, task_labels: torch.Tensor) -> None:
97 | """
98 | Initializes just the required tensors.
99 | :param examples: tensor containing the images
100 | :param labels: tensor containing the labels
101 | :param logits: tensor containing the outputs of the network
102 | :param task_labels: tensor containing the task labels
103 | """
104 |         for attr_str, attr in zip(self.attributes, [examples, labels, logits, task_labels]):
105 |             if attr is not None and not hasattr(self, attr_str):
106 |                 # 'labels' and 'task_labels' hold integer targets; the rest are float
107 |                 typ = torch.int64 if attr_str.endswith('els') else torch.float32
108 |                 setattr(self, attr_str, torch.zeros((self.buffer_size,
109 |                                                      *attr.shape[1:]), dtype=typ, device=self.device))
110 |
111 | def add_data(self, examples, labels=None, logits=None, task_labels=None):
112 | """
113 | Adds the data to the memory buffer according to the reservoir strategy.
114 | :param examples: tensor containing the images
115 | :param labels: tensor containing the labels
116 | :param logits: tensor containing the outputs of the network
117 | :param task_labels: tensor containing the task labels
118 | :return:
119 | """
120 | if not hasattr(self, 'examples'):
121 | self.init_tensors(examples, labels, logits, task_labels)
122 |
123 | for i in range(examples.shape[0]):
124 | index = reservoir(self.num_seen_examples, self.buffer_size)
125 | self.num_seen_examples += 1
126 | if index >= 0:
127 | self.examples[index] = examples[i].to(self.device)
128 | if labels is not None:
129 | self.labels[index] = labels[i].to(self.device)
130 | if logits is not None:
131 | self.logits[index] = logits[i].to(self.device)
132 | if task_labels is not None:
133 | self.task_labels[index] = task_labels[i].to(self.device)
134 |
135 | def get_data(self, size: int, transform: transforms=None) -> Tuple:
136 | """
137 | Random samples a batch of size items.
138 | :param size: the number of requested items
139 | :param transform: the transformation to be applied (data augmentation)
140 | :return:
141 | """
142 | if size > min(self.num_seen_examples, self.examples.shape[0]):
143 | size = min(self.num_seen_examples, self.examples.shape[0])
144 |
145 | choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]),
146 | size=size, replace=False)
147 | if transform is None: transform = lambda x: x
148 |         # stack the (optionally transformed) sampled examples and move them
149 |         # back to the buffer's device
150 | ret_tuple = (torch.stack([transform(ee.cpu())
151 | for ee in self.examples[choice]]).to(self.device),)
152 | for attr_str in self.attributes[1:]:
153 | if hasattr(self, attr_str):
154 | attr = getattr(self, attr_str)
155 | ret_tuple += (attr[choice],)
156 |
157 | return ret_tuple
158 |
159 | def is_empty(self) -> bool:
160 | """
161 | Returns true if the buffer is empty, false otherwise.
162 | """
163 | if self.num_seen_examples == 0:
164 | return True
165 | else:
166 | return False
167 |
168 | def get_all_data(self, transform: transforms=None) -> Tuple:
169 | """
170 | Return all the items in the memory buffer.
171 | :param transform: the transformation to be applied (data augmentation)
172 | :return: a tuple with all the items in the memory buffer
173 | """
174 | if transform is None: transform = lambda x: x
175 | ret_tuple = (torch.stack([transform(ee.cpu())
176 | for ee in self.examples]).to(self.device),)
177 | for attr_str in self.attributes[1:]:
178 | if hasattr(self, attr_str):
179 | attr = getattr(self, attr_str)
180 | ret_tuple += (attr,)
181 | return ret_tuple
182 |
183 | def empty(self) -> None:
184 | """
185 | Set all the tensors to None.
186 | """
187 | for attr_str in self.attributes:
188 | if hasattr(self, attr_str):
189 | delattr(self, attr_str)
190 | self.num_seen_examples = 0
191 |
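192 |
193 | if __name__ == "__main__":
194 |     # Illustrative sketch of reservoir-buffer usage (shapes/values arbitrary):
195 |     # every example is kept while the buffer has room, then replaced with
196 |     # probability buffer_size / num_seen_examples.
197 |     buf = Buffer(buffer_size=8, device='cpu')
198 |     for step in range(4):
199 |         buf.add_data(examples=torch.randn(4, 3, 32, 32),
200 |                      labels=torch.randint(0, 10, (4,)))
201 |     images, labels = buf.get_data(size=2)
202 |     print(images.shape, labels.shape)  # torch.Size([2, 3, 32, 32]) torch.Size([2])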
--------------------------------------------------------------------------------
/utils/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import random
7 | import torch
8 | import numpy as np
9 |
10 | def get_device() -> torch.device:
11 | """
12 | Returns the GPU device if available else CPU.
13 | """
14 | return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15 |
16 |
17 | def base_path() -> str:
18 | """
19 |     Returns the base path where accuracies and tensorboard data are logged.
20 | """
21 | return './data/'
22 |
23 |
24 | def set_random_seed(seed: int) -> None:
25 | """
26 | Sets the seeds at a certain value.
27 | :param seed: the value to be set
28 | """
29 | random.seed(seed)
30 | np.random.seed(seed)
31 | torch.manual_seed(seed)
32 | torch.cuda.manual_seed_all(seed)
33 |
--------------------------------------------------------------------------------
/utils/continual_training.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | from datasets import get_gcl_dataset
8 | from models import get_model
9 | from utils.status import progress_bar
10 | from utils.tb_logger import *
11 | from utils.status import create_fake_stash
12 | from models.utils.continual_model import ContinualModel
13 | from argparse import Namespace
14 |
15 |
16 | def evaluate(model: ContinualModel, dataset) -> float:
17 | """
18 | Evaluates the final accuracy of the model.
19 | :param model: the model to be evaluated
20 | :param dataset: the GCL dataset at hand
21 | :return: a float value that indicates the accuracy
22 | """
23 | model.net.eval()
24 | correct, total = 0, 0
25 | while not dataset.test_over:
26 | inputs, labels = dataset.get_test_data()
27 | inputs, labels = inputs.to(model.device), labels.to(model.device)
28 | outputs = model(inputs)
29 | _, predicted = torch.max(outputs.data, 1)
30 | correct += torch.sum(predicted == labels).item()
31 | total += labels.shape[0]
32 |
33 | acc = correct / total * 100
34 | return acc
35 |
36 |
37 | def train(args: Namespace):
38 | """
39 | The training process, including evaluations and loggers.
40 | :param model: the module to be trained
41 | :param dataset: the continual dataset at hand
42 | :param args: the arguments of the current execution
43 | """
44 | if args.csv_log:
45 | from utils.loggers import CsvLogger
46 |
47 | dataset = get_gcl_dataset(args)
48 | backbone = dataset.get_backbone()
49 |     loss_fn = dataset.get_loss()  # task criterion; the loop below reuses `loss` for the scalar value
50 |     model = get_model(args, backbone, loss_fn, dataset.get_transform())
51 | model.net.to(model.device)
52 |
53 | model_stash = create_fake_stash(model, args)
54 |
55 | if args.csv_log:
56 | csv_logger = CsvLogger(dataset.SETTING, dataset.NAME, model.NAME)
57 |
58 | model.net.train()
59 | epoch, i = 0, 0
60 | while not dataset.train_over:
61 | inputs, labels, not_aug_inputs = dataset.get_train_data()
62 | inputs, labels = inputs.to(model.device), labels.to(model.device)
63 | not_aug_inputs = not_aug_inputs.to(model.device)
64 | loss = model.observe(inputs, labels, not_aug_inputs)
65 | progress_bar(i, dataset.LENGTH // args.batch_size, epoch, 'C', loss)
66 | i += 1
67 |
68 | if model.NAME == 'joint_gcl':
69 | model.end_task(dataset)
70 |
71 | acc = evaluate(model, dataset)
72 | print('Accuracy:', acc)
73 |
74 | if args.csv_log:
75 | csv_logger.log(acc)
76 |         csv_logger.write(args.ckpt_dir, vars(args))  # CsvLogger.write expects the output dir first (assumes args carries ckpt_dir, as in tb_logger)
77 |
--------------------------------------------------------------------------------
/utils/deep_inversion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import torch
10 |
11 | class DeepInversionFeatureHook():
12 | '''
13 | Implementation of the forward hook to track feature statistics and compute a loss on them.
14 | Will compute mean and variance, and will use l2 as a loss
15 | '''
16 |
17 | def __init__(self, module):
18 | self.hook = module.register_forward_hook(self.hook_fn)
19 | self.mean = None
20 | self.var = None
21 | self.use_stored_stats = False
22 | self.capture_bn_stats = False
23 |
24 |
25 |     def hook_fn(self, module, input, output):
26 |         # hook to compute DeepInversion's feature-distribution regularization
27 |         nch_in = input[0].shape[1]
28 |         nch_out = output.shape[1]
29 |
30 |         # per-channel batch statistics of the BN input; matched below against
31 |         # either the layer's running statistics (default) or the stored
32 |         # output statistics when `use_stored_stats` is set
33 |         in_mean = input[0].mean([0, 2, 3])
34 |         in_var = input[0].permute(1, 0, 2, 3).contiguous().view([nch_in, -1]).var(1, unbiased=False)
35 |
36 |         out_mean = output.mean([0, 2, 3])
37 |         out_var = output.permute(1, 0, 2, 3).contiguous().view([nch_out, -1]).var(1, unbiased=False)
38 |
39 | if self.capture_bn_stats:
40 | self.out_mean = out_mean.clone().detach()
41 | self.out_var = out_var.clone().detach()
42 |
43 | if not self.use_stored_stats:
44 | r_feature = torch.norm(module.running_var.data.type(in_var.type()) - in_var, 2) + torch.norm(
45 | module.running_mean.data.type(in_mean.type()) - in_mean, 2)
46 | else:
47 | r_feature = torch.norm(self.out_var - out_var, 2) + torch.norm(self.out_mean - out_mean, 2)
48 |
49 | self.r_feature = r_feature
50 | # must have no output
51 |
52 | def close(self):
53 | self.hook.remove()
54 |
55 |
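56 |
57 | if __name__ == "__main__":
58 |     # Illustrative sketch: attach hooks to every BN layer of a small net and
59 |     # read the feature-statistics regularizer after a forward pass.
60 |     import torch.nn as nn
61 |     net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
62 |                         nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8))
63 |     net.eval()  # match against the layers' running statistics
64 |     hooks = [DeepInversionFeatureHook(m) for m in net.modules()
65 |              if isinstance(m, nn.BatchNorm2d)]
66 |     net(torch.randn(4, 3, 32, 32))
67 |     r_loss = sum(h.r_feature for h in hooks)
68 |     print(r_loss)
69 |     for h in hooks:
70 |         h.close()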
--------------------------------------------------------------------------------
/utils/loggers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import csv
7 | import os
8 | import sys
9 | from typing import Dict, Any
10 | from utils.metrics import *
11 |
12 | from utils import create_if_not_exists
13 | from utils.conf import base_path
14 | import numpy as np
15 |
16 | useless_args = ['dataset', 'tensorboard', 'validation', 'model',
17 | 'csv_log', 'notes', 'load_best_args']
18 |
19 |
20 | def print_mean_accuracy(mean_acc: np.ndarray, task_number: int,
21 | setting: str) -> None:
22 | """
23 | Prints the mean accuracy on stderr.
24 | :param mean_acc: mean accuracy value
25 | :param task_number: task index
26 | :param setting: the setting of the benchmark
27 | """
28 | if setting == 'domain-il':
29 | mean_acc, _ = mean_acc
30 | print('\nAccuracy for {} task(s): {} %'.format(
31 | task_number, round(mean_acc, 2)), file=sys.stderr)
32 | else:
33 | mean_acc_class_il, mean_acc_task_il = mean_acc
34 | print('\nAccuracy for {} task(s): \t [Class-IL]: {} %'
35 | ' \t [Task-IL]: {} %\n'.format(task_number, round(
36 | mean_acc_class_il, 2), round(mean_acc_task_il, 2)), file=sys.stderr)
37 |
38 |
39 | class CsvLogger:
40 | def __init__(self, setting_str: str, dataset_str: str,
41 | model_str: str) -> None:
42 | self.accs = []
43 | if setting_str == 'class-il':
44 | self.accs_mask_classes = []
45 | self.setting = setting_str
46 | self.dataset = dataset_str
47 | self.model = model_str
48 | self.fwt = None
49 | self.fwt_mask_classes = None
50 | self.bwt = None
51 | self.bwt_mask_classes = None
52 | self.forgetting = None
53 | self.forgetting_mask_classes = None
54 |
55 | def add_fwt(self, results, accs, results_mask_classes, accs_mask_classes):
56 | self.fwt = forward_transfer(results, accs)
57 | if self.setting == 'class-il':
58 | self.fwt_mask_classes = forward_transfer(results_mask_classes, accs_mask_classes)
59 |
60 | def add_bwt(self, results, results_mask_classes):
61 | self.bwt = backward_transfer(results)
62 | self.bwt_mask_classes = backward_transfer(results_mask_classes)
63 |
64 | def add_forgetting(self, results, results_mask_classes):
65 | self.forgetting = forgetting(results)
66 | self.forgetting_mask_classes = forgetting(results_mask_classes)
67 |
68 | def log(self, mean_acc: np.ndarray) -> None:
69 | """
70 | Logs a mean accuracy value.
71 | :param mean_acc: mean accuracy value
72 | """
73 | if self.setting == 'general-continual':
74 | self.accs.append(mean_acc)
75 | elif self.setting == 'domain-il':
76 | mean_acc, _ = mean_acc
77 | self.accs.append(mean_acc)
78 | else:
79 | mean_acc_class_il, mean_acc_task_il = mean_acc
80 | self.accs.append(mean_acc_class_il)
81 | self.accs_mask_classes.append(mean_acc_task_il)
82 |
83 | def write(self, ckpt_dir, args: Dict[str, Any]) -> None:
84 | """
85 | writes out the logged value along with its arguments.
86 | :param args: the namespace of the current experiment
87 | """
88 | for cc in useless_args:
89 | if cc in args:
90 | del args[cc]
91 |
92 | columns = list(args.keys())
93 |
94 | new_cols = []
95 | for i, acc in enumerate(self.accs):
96 | args['task' + str(i + 1)] = acc
97 | new_cols.append('task' + str(i + 1))
98 |
99 | args['forward_transfer'] = self.fwt
100 | new_cols.append('forward_transfer')
101 |
102 | args['backward_transfer'] = self.bwt
103 | new_cols.append('backward_transfer')
104 |
105 | args['forgetting'] = self.forgetting
106 | new_cols.append('forgetting')
107 |
108 | columns = new_cols + columns
109 |
110 | create_if_not_exists(ckpt_dir + "results/" + self.setting)
111 | create_if_not_exists(ckpt_dir + "results/" + self.setting +
112 | "/" + self.dataset)
113 | create_if_not_exists(ckpt_dir + "results/" + self.setting +
114 | "/" + self.dataset + "/" + self.model)
115 |
116 | write_headers = False
117 | path = ckpt_dir + "results/" + self.setting + "/" + self.dataset\
118 | + "/" + self.model + "/mean_accs.csv"
119 | if not os.path.exists(path):
120 | write_headers = True
121 | with open(path, 'a') as tmp:
122 | writer = csv.DictWriter(tmp, fieldnames=columns)
123 | if write_headers:
124 | writer.writeheader()
125 | writer.writerow(args)
126 |
127 | if self.setting == 'class-il':
128 | create_if_not_exists(ckpt_dir + "results/task-il/"
129 | + self.dataset)
130 | create_if_not_exists(ckpt_dir + "results/task-il/"
131 | + self.dataset + "/" + self.model)
132 |
133 | for i, acc in enumerate(self.accs_mask_classes):
134 | args['task' + str(i + 1)] = acc
135 |
136 | args['forward_transfer'] = self.fwt_mask_classes
137 | args['backward_transfer'] = self.bwt_mask_classes
138 | args['forgetting'] = self.forgetting_mask_classes
139 |
140 | write_headers = False
141 | path = ckpt_dir + "results/task-il" + "/" + self.dataset + "/"\
142 | + self.model + "/mean_accs.csv"
143 | if not os.path.exists(path):
144 | write_headers = True
145 | with open(path, 'a') as tmp:
146 | writer = csv.DictWriter(tmp, fieldnames=columns)
147 | if write_headers:
148 | writer.writeheader()
149 | writer.writerow(args)
150 |
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | # Originated from https://github.com/sutd-visual-computing-group/LS-KD-compatibility/blob/master/src/image_classification/imagenet/utils.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.models as models
6 | import torch.nn.functional as F
7 |
8 |
9 | # Define Smooth Loss
10 | class LabelSmoothing(nn.Module):
11 | """
12 | NLL loss with label smoothing.
13 | https://github.com/NVIDIA/DeepLearningExamples/blob/8d8b21a933fff3defb692e0527fca15532da5dc6/PyTorch/Classification/ConvNets/image_classification/smoothing.py
14 | """
15 |
16 | def __init__(self, smoothing=0.0):
17 | """
18 | Constructor for the LabelSmoothing module.
19 | :param smoothing: label smoothing factor
20 | """
21 | super(LabelSmoothing, self).__init__()
22 | self.confidence = 1.0 - smoothing
23 | self.smoothing = smoothing
24 |
25 | def forward(self, x, target):
26 | logprobs = torch.nn.functional.log_softmax(x, dim=-1)
27 |
28 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
29 | nll_loss = nll_loss.squeeze(1)
30 | smooth_loss = -logprobs.mean(dim=-1)
31 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
32 | return loss.mean()
33 |
34 |
35 | # Define KL divergence loss
36 | class KL_div_Loss(nn.Module):
37 | """
38 | We use formulation of Hinton et. for KD loss.
39 | $T^2$ scaling is implemented to avoid gradient rescaling when using T!=1
40 | """
41 |
42 | def __init__(self, temperature):
43 | """
44 | Constructor for the LabelSmoothing module.
45 | :param smoothing: label smoothing factor
46 | """
47 | super(KL_div_Loss, self).__init__()
48 | self.temperature = temperature
49 | #print( "Setting temperature = {} for KD (Only Teacher)".format(self.temperature) )
50 | print( "Setting temperature = {} for KD".format(self.temperature) )
51 |
52 |
53 | def forward(self, y, teacher_scores):
54 | p = F.log_softmax(y / self.temperature, dim=1) # Hinton formulation
55 |
56 |         # p = F.log_softmax(y, dim=1)  # Müller et al. used this.
57 |
58 | q = F.softmax(teacher_scores / self.temperature, dim=1)
59 | l_kl = F.kl_div(p, q, reduction='batchmean')
60 | return l_kl*(self.temperature**2) # $T^2$ scaling is important
61 |
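62 |
63 | if __name__ == "__main__":
64 |     # Illustrative sketch: the KD loss between random student/teacher logits.
65 |     # Thanks to the T^2 factor, gradient magnitudes stay comparable across T.
66 |     student_logits = torch.randn(8, 10, requires_grad=True)
67 |     teacher_logits = torch.randn(8, 10)
68 |     kd = KL_div_Loss(temperature=4.0)
69 |     loss = kd(student_logits, teacher_logits)
70 |     loss.backward()
71 |     print(loss.item(), student_logits.grad.abs().mean().item())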
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import numpy as np
8 | from datasets.utils.continual_dataset import ContinualDataset
9 | from typing import Tuple
10 |
11 |
12 | def backward_transfer(results):
13 | n_tasks = len(results)
14 | li = list()
15 | for i in range(n_tasks - 1):
16 | li.append(results[-1][i] - results[i][i])
17 |
18 | return np.mean(li)
19 |
20 |
21 | def forward_transfer(results, random_results):
22 | n_tasks = len(results)
23 | li = list()
24 | for i in range(1, n_tasks):
25 | li.append(results[i-1][i] - random_results[i])
26 |
27 | return np.mean(li)
28 |
29 |
30 | def forgetting(results):
31 | n_tasks = len(results)
32 | li = list()
33 |     for i in range(n_tasks - 1):
34 |         results[i] += [0.0] * (n_tasks - len(results[i]))  # pad to a rectangular matrix (note: mutates the input lists)
35 | np_res = np.array(results)
36 | maxx = np.max(np_res, axis=0)
37 | for i in range(n_tasks - 1):
38 | li.append(maxx[i] - results[-1][i])
39 |
40 | return np.mean(li)
41 |
42 |
43 | def mask_classes(outputs: torch.Tensor, dataset: ContinualDataset, k: int):
44 | """
45 | Given the output tensor, the dataset at hand and the current task,
46 | masks the former by setting the responses for the other tasks at -inf.
47 | It is used to obtain the results for the task-il setting.
48 | :param outputs: the output tensor
49 | :param dataset: the continual dataset
50 | :param k: the task index
51 | """
52 | outputs[:, 0:k * dataset.N_CLASSES_PER_TASK] = -float('inf')
53 | outputs[:, (k + 1) * dataset.N_CLASSES_PER_TASK:
54 | dataset.N_TASKS * dataset.N_CLASSES_PER_TASK] = -float('inf')
55 |
56 | return outputs
57 |
58 |
59 |
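60 | if __name__ == "__main__":
61 |     # Illustrative sketch on a hand-made 3-task accuracy matrix:
62 |     # results[i][j] = accuracy on task j after training on task i.
63 |     results = [[90.0, 10.0, 10.0],
64 |                [80.0, 85.0, 10.0],
65 |                [70.0, 75.0, 88.0]]
66 |     print(backward_transfer([r[:] for r in results]))  # ((70-90)+(75-85))/2 = -15.0
67 |     print(forgetting([r[:] for r in results]))         # ((90-70)+(85-75))/2 = 15.0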
--------------------------------------------------------------------------------
/utils/status.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from datetime import datetime
7 | import sys
8 | import os
9 | from utils.conf import base_path
10 | from typing import Any, Dict, Union
11 | from torch import nn
12 | from argparse import Namespace
13 | from datasets.utils.continual_dataset import ContinualDataset
14 |
15 |
16 | def create_stash(model: nn.Module, args: Namespace,
17 | dataset: ContinualDataset) -> Dict[Any, str]:
18 | """
19 | Creates the dictionary where to save the model status.
20 | :param model: the model
21 | :param args: the current arguments
22 | :param dataset: the dataset at hand
23 | """
24 | now = datetime.now()
25 | model_stash = {'task_idx': 0, 'epoch_idx': 0, 'batch_idx': 0}
26 | name_parts = [args.dataset, model.NAME]
27 | if 'buffer_size' in vars(args).keys():
28 | name_parts.append('buf_' + str(args.buffer_size))
29 | name_parts.append(now.strftime("%Y%m%d_%H%M%S_%f"))
30 | model_stash['model_name'] = '/'.join(name_parts)
31 | model_stash['mean_accs'] = []
32 | model_stash['args'] = args
33 | model_stash['backup_folder'] = os.path.join(base_path(), 'backups',
34 | dataset.SETTING,
35 | model_stash['model_name'])
36 | return model_stash
37 |
38 |
39 | def create_fake_stash(model: nn.Module, args: Namespace) -> Dict[Any, str]:
40 | """
41 | Create a fake stash, containing just the model name.
42 | This is used in general continual, as it is useless to backup
43 | a lightweight MNIST-360 training.
44 | :param model: the model
45 | :param args: the arguments of the call
46 | :return: a dict containing a fake stash
47 | """
48 | now = datetime.now()
49 | model_stash = {'task_idx': 0, 'epoch_idx': 0}
50 | name_parts = [args.dataset, model.NAME]
51 | if 'buffer_size' in vars(args).keys():
52 | name_parts.append('buf_' + str(args.buffer_size))
53 | name_parts.append(now.strftime("%Y%m%d_%H%M%S_%f"))
54 | model_stash['model_name'] = '/'.join(name_parts)
55 |
56 | return model_stash
57 |
58 |
59 | def progress_bar(i: int, max_iter: int, epoch: Union[int, str],
60 |                  task_number: Union[int, str], loss: float) -> None:
61 |     """
62 |     Prints out the progress bar on the stderr file.
63 |     :param i: the current iteration
64 |     :param max_iter: the maximum number of iterations
65 | :param epoch: the epoch
66 | :param task_number: the task index
67 | :param loss: the current value of the loss function
68 | """
69 | if not (i + 1) % 10 or (i + 1) == max_iter:
70 | progress = min(float((i + 1) / max_iter), 1)
71 | progress_bar = ('█' * int(50 * progress)) + ('┈' * (50 - int(50 * progress)))
72 | print('\r[ {} ] Task {} | epoch {}: |{}| loss: {}'.format(
73 | datetime.now().strftime("%m-%d | %H:%M"),
74 | task_number + 1 if isinstance(task_number, int) else task_number,
75 | epoch,
76 | progress_bar,
77 | round(loss, 8)
78 | ), file=sys.stderr, end='', flush=True)
79 |
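80 |
81 | if __name__ == "__main__":
82 |     # Illustrative sketch: render the bar for a fake 100-iteration epoch.
83 |     for it in range(100):
84 |         progress_bar(it, 100, epoch=0, task_number=0, loss=1.0 / (it + 1))
85 |     print(file=sys.stderr)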
--------------------------------------------------------------------------------
/utils/tb_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
2 | # All rights reserved.
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from utils.conf import base_path
7 | import os
8 | from argparse import Namespace
9 | from typing import Dict, Any
10 | import numpy as np
11 | import torchvision
12 | import matplotlib.pyplot as plt
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | def img_denormalize(img):
17 |     """Scale and shift (de-normalize) a batch of images (NCHW).
18 |     """
19 |     mean = [0.4914, 0.4822, 0.4465]  # hardcoded CIFAR-10 channel means
20 |     std = [0.2470, 0.2435, 0.2615]   # hardcoded CIFAR-10 channel stds
21 | nch = img.shape[1]
22 |
23 | mean = torch.tensor(mean, device=img.device).reshape(1, nch, 1, 1)
24 | std = torch.tensor(std, device=img.device).reshape(1, nch, 1, 1)
25 |
26 | return img * std + mean
27 |
28 |
29 | def save_img(img, unnormalize=True, max_num=5, size=32, nrow=5, dataname='imagenet'):
30 | img = img[:max_num].detach()
31 | if unnormalize:
32 |         img = img_denormalize(img)
33 | images = torch.clamp(img, min=0., max=1.)
34 | images = torchvision.utils.make_grid(images, nrow=nrow, padding=2)
35 | # print(images.shape)
36 | # if img.shape[-1] > size:
37 | # img = F.interpolate(img, size)
38 |
39 | return images
40 |
41 | class TensorboardLogger:
42 | def __init__(self, args: Namespace, setting: str,
43 | stash: Dict[Any, str]=None) -> None:
44 | from torch.utils.tensorboard import SummaryWriter
45 |
46 | self.settings = [setting]
47 | if setting == 'class-il':
48 | self.settings.append('task-il')
49 | self.loggers = {}
50 | self.name = args.model.backbone
51 | for a_setting in self.settings:
52 | self.loggers[a_setting] = SummaryWriter(
53 | os.path.join(args.ckpt_dir, 'tensorboard_runs'))
54 | config_text = ', '.join(
55 | ["%s=%s" % (name, getattr(args, name)) for name in args.__dir__()
56 | if not name.startswith('_')])
57 | for a_logger in self.loggers.values():
58 | a_logger.add_text('config', config_text)
59 |
60 | def get_name(self) -> str:
61 | """
62 | :return: the name of the model
63 | """
64 | return self.name
65 |
66 | def log_accuracy(self, all_accs: np.ndarray, all_mean_accs: np.ndarray,
67 | args: Namespace, task_number: int) -> None:
68 | """
69 | Logs the current accuracy value for each task.
70 | :param all_accs: the accuracies (class-il, task-il) for each task
71 | :param all_mean_accs: the mean accuracies for (class-il, task-il)
72 | :param args: the arguments of the run
73 | :param task_number: the task index
74 | """
75 | mean_acc_common, mean_acc_task_il = all_mean_accs
76 | for setting, a_logger in self.loggers.items():
77 | mean_acc = mean_acc_task_il\
78 | if setting == 'task-il' else mean_acc_common
79 | index = 1 if setting == 'task-il' else 0
80 | accs = [all_accs[index][kk] for kk in range(len(all_accs[0]))]
81 | for kk, acc in enumerate(accs):
82 | a_logger.add_scalar('acc_task%02d' % (kk + 1), acc,
83 | task_number * args.train.num_epochs)
84 | a_logger.add_scalar('acc_mean', mean_acc, task_number * args.train.num_epochs)
85 |
86 | def log_loss(self, loss: float, args: Namespace, epoch: int,
87 | task_number: int, iteration: int) -> None:
88 | """
89 | Logs the loss value at each iteration.
90 | :param loss: the loss value
91 | :param args: the arguments of the run
92 | :param epoch: the epoch index
93 | :param task_number: the task index
94 | :param iteration: the current iteration
95 | """
96 | for a_logger in self.loggers.values():
97 | a_logger.add_scalar('loss', loss, task_number * args.train.num_epochs + epoch)
98 |
99 |
100 | def log_penalty(self, penalty: float, args: Namespace, epoch: int,
101 | task_number: int, iteration: int) -> None:
102 | """
103 | Logs the loss penalty value at each iteration.
104 |         :param penalty: the loss penalty value
105 | :param args: the arguments of the run
106 | :param epoch: the epoch index
107 | :param task_number: the task index
108 | :param iteration: the current iteration
109 | """
110 | for a_logger in self.loggers.values():
111 | a_logger.add_scalar('penalty', penalty, task_number * args.train.num_epochs + epoch)
112 |
113 |
114 | def log_lr(self, lr: float, args: Namespace, epoch: int,
115 | task_number: int, iteration: int) -> None:
116 | """
117 | Logs the lr value at each iteration.
118 | :param lr: the lr value
119 | :param iteration: the current iteration
120 | """
121 | for a_logger in self.loggers.values():
122 | a_logger.add_scalar('lr', lr, iteration)
123 | a_logger.add_scalar('lr', lr, task_number * args.train.num_epochs + epoch)
124 |
125 | def log_images(self, images, args: Namespace, epoch: int,
126 | task_number: int, iteration: int) -> None:
127 | """
128 | Logs the lr value at each iteration.
129 | :param lr: the lr value
130 | :param iteration: the current iteration
131 | """
132 | # img_grid = torchvision.utils.make_grid(images)
133 | # matplotlib_imshow(img_grid)
134 | images = save_img(images)
135 | for a_logger in self.loggers.values():
136 | a_logger.add_image('syn_images', images, task_number * args.train.num_epochs + epoch)
137 |
138 | def log_loss_gcl(self, loss: float, iteration: int) -> None:
139 | """
140 | Logs the loss value at each iteration.
141 | :param loss: the loss value
142 | :param iteration: the current iteration
143 | """
144 | for a_logger in self.loggers.values():
145 | a_logger.add_scalar('loss', loss, iteration)
146 |
147 | def close(self) -> None:
148 | """
149 | At the end of the execution, closes the logger.
150 | """
151 | for a_logger in self.loggers.values():
152 | a_logger.close()
153 |
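154 |
155 | if __name__ == "__main__":
156 |     # Illustrative sketch: de-normalize a random CIFAR-sized batch and build
157 |     # the image grid that log_images() would send to tensorboard.
158 |     batch = torch.randn(8, 3, 32, 32)
159 |     grid = save_img(batch, unnormalize=True, max_num=5, nrow=5)
160 |     print(grid.shape)  # a single CHW grid tensor containing the first 5 images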
--------------------------------------------------------------------------------