├── .gitignore ├── LICENSE ├── README.md ├── arguments.py ├── assets └── concept_figure.png ├── augmentations ├── __init__.py ├── eval_aug.py ├── gaussian_blur.py └── simsiam_aug.py ├── configs ├── __init__.py ├── cifar10 │ ├── distil.yaml │ ├── distilbuf.yaml │ └── qdi.yaml ├── cifar100 │ ├── distil.yaml │ ├── distilbuf.yaml │ └── qdi.yaml └── tinyimg │ ├── distil.yaml │ ├── distilbuf.yaml │ └── qdi.yaml ├── datasets ├── __init__.py ├── datasets_utils.py ├── random_dataset.py ├── seq_cifar10.py ├── seq_cifar100.py ├── seq_tinyimagenet.py ├── test │ ├── seq-cifar10.pt │ ├── seq-cifar100.pt │ ├── seq-domainnet.pt │ └── seq-tinyimg.pt ├── transforms │ ├── __init__.py │ ├── denormalization.py │ ├── permutation.py │ └── rotation.py └── utils │ ├── __init__.py │ ├── continual_dataset.py │ └── validation.py ├── linear_eval_alltasks.py ├── main.py ├── models ├── __init__.py ├── backbones │ ├── Alexnet.py │ ├── Densenet.py │ ├── Inception.py │ ├── Lenet.py │ ├── Regnet.py │ ├── ResNet18.py │ ├── ResNext.py │ ├── Senet.py │ ├── Swin.py │ ├── Vgg.py │ ├── __init__.py │ └── utils │ │ ├── __init__.py │ │ └── modules.py ├── distil.py ├── distilbuf.py ├── optimizers │ ├── __init__.py │ ├── lars.py │ └── lr_scheduler.py ├── qdi.py ├── simsiam.py └── utils │ ├── __init__.py │ └── continual_model.py ├── requirements.txt ├── tools ├── __init__.py ├── accuracy.py ├── average_meter.py ├── file_exist_fn.py ├── knn_monitor.py ├── logger.py └── plotter.py └── utils ├── __init__.py ├── args.py ├── batch_norm.py ├── buffer.py ├── conf.py ├── continual_training.py ├── deep_inversion.py ├── loggers.py ├── losses.py ├── metrics.py ├── status.py └── tb_logger.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.npy 2 | __pycache__/ 3 | checkpoints/ 4 | data/ 5 | logs/ 6 | wandb/ 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023, NVIDIA Corporation. All rights reserved. 2 | 3 | Nvidia Source Code License-NC 4 | 5 | 1. Definitions 6 | 7 | “Licensor” means any person or entity that distributes its Work. 8 | 9 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, 10 | or other files, and (b) any additions to or derivative works thereof that are made available under this license. 11 | 12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. 13 | copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that 14 | remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 15 | 16 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing 17 | the applicability of this license to the Work, or (b) a copy of this license. 18 | 19 | 2. License Grant 20 | 21 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, 22 | worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly 23 | display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 24 | 25 | 3. Limitations 26 | 27 | 3.1 Redistribution. 
You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a 28 | complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, 29 | trademark, or attribution notices that are present in the Work. 30 | 31 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution 32 | of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 33 | applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. 34 | Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply 35 | to the Work itself. 36 | 37 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. 38 | Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. 39 | As used herein, “non-commercially” means for research or evaluation purposes only. 40 | 41 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim 42 | or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under 43 | this license from such Licensor (including the grant in Section 2.1) will terminate immediately. 44 | 45 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, 46 | except as necessary to reproduce the notices described in this license. 47 | 48 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) 49 | will terminate immediately. 50 | 51 | 4. Disclaimer of Warranty. 52 | 53 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES 54 | OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING 55 | ANY ACTIVITIES UNDER THIS LICENSE. 56 | 57 | 5. Limitation of Liability. 58 | 59 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, 60 | OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL 61 | DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, 62 | BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR 63 | HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Heterogeneous Continual Learning 4 | [![Conference](http://img.shields.io/badge/CVPR-2023(Highlight)-FFD93D.svg)](https://cvpr.thecvf.com/) 5 | [![Paper](http://img.shields.io/badge/Paper-arxiv.2306.08593-FF6B6B.svg)](https://arxiv.org/abs/2306.08593) 6 |
7 | 8 | Official PyTorch implementation of the CVPR 2023 Highlight (Top 10%) paper [**Heterogeneous Continual Learning**](https://arxiv.org/abs/2306.08593). 9 | 10 | **Authors**: [Divyam Madaan](https://dmadaan.com/), [Hongxu Yin](https://hongxu-yin.github.io/), [Wonmin Byeon](https://wonmin-byeon.github.io/), [Pavlo Molchanov](https://research.nvidia.com/person/pavlo-molchanov) 11 | 12 | For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/) 13 | 14 | **TL;DR: The first continual learning approach in which the architecture continuously evolves with the data.** 15 | 16 | ![Concept figure](https://github.com/NVlabs/HCL/blob/main/assets/concept_figure.png) 17 | 18 | 19 | ## Abstract 20 | We propose a novel framework and a solution to tackle 21 | the continual learning (CL) problem with changing network 22 | architectures. Most CL methods focus on adapting a single 23 | architecture to a new task/class by modifying its weights. 24 | However, with rapid progress in architecture design, the 25 | problem of adapting existing solutions to novel architectures 26 | becomes relevant. To address this limitation, we propose 27 | Heterogeneous Continual Learning (HCL), where a wide 28 | range of evolving network architectures emerge continually 29 | together with novel data/tasks. As a solution, we build on 30 | top of the distillation family of techniques and modify it 31 | to a new setting where a weaker model takes the role of a 32 | teacher; meanwhile, a new stronger architecture acts as a 33 | student. Furthermore, we consider a setup of limited access 34 | to previous data and propose Quick Deep Inversion (QDI) to 35 | recover prior task visual features to support knowledge 36 | transfer. QDI significantly reduces computational costs compared 37 | to previous solutions and improves overall performance. In 38 | summary, we propose a new setup for CL with a modified 39 | knowledge distillation paradigm and design a quick data 40 | inversion method to enhance distillation. Our evaluation 41 | on various benchmarks shows a significant improvement in 42 | accuracy in comparison to state-of-the-art methods over 43 | various network architectures. 44 | 45 | __Contributions of this work__ 46 | 47 | - We propose a novel CL framework called Heterogeneous 48 | Continual Learning (HCL) to learn a stream of 49 | different architectures on a sequence of tasks while 50 | transferring the knowledge from past representations. 51 | - We revisit knowledge distillation and propose Quick 52 | Deep Inversion (QDI), which inverts the previous task 53 | parameters while interpolating the current task 54 | examples with minimal additional cost (see the sketch after this list). 55 | - We benchmark existing state-of-the-art solutions in the 56 | new setting and outperform them with our proposed 57 | method across a diverse stream of architectures for both 58 | task-incremental and class-incremental CL.
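For intuition, the snippet below is a minimal, self-contained sketch of the deep-inversion idea behind QDI: synthetic inputs are warm-started from current-task images and optimized so that the frozen previous-task model predicts old-task labels, regularized by image priors (the `di_var`/`di_l2` knobs in the configs). It is illustrative only, not the repository's implementation (see `utils/deep_inversion.py` and `models/qdi.py`); the function name `invert_batch` and its defaults are hypothetical.

```python
import torch
import torch.nn.functional as F

def tv_loss(x):
    # Total-variation prior: encourages spatially smooth synthetic images.
    dh = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
    dw = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
    return dh + dw

def invert_batch(teacher, cur_x, prev_labels, itrs=500, lr=0.03,
                 di_var=0.003, di_l2=0.003):
    """Hypothetical QDI-style inversion: warm-start from current-task images
    and optimize them so the frozen teacher (assumed to return logits)
    predicts previous-task labels."""
    teacher.eval()
    x = cur_x.clone().detach().requires_grad_(True)  # warm start = "quick"
    opt = torch.optim.Adam([x], lr=lr)
    for _ in range(itrs):
        opt.zero_grad()
        loss = F.cross_entropy(teacher(x), prev_labels)
        loss = loss + di_var * tv_loss(x) + di_l2 * x.pow(2).mean()
        loss.backward()
        opt.step()
    return x.detach()
```

The warm start from current-task examples is what makes this "quick": unlike inversion from random noise, far fewer optimization steps are needed to recover usable prior-task features.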
59 | 60 | ## Prerequisites 61 | 62 | ``` 63 | $ pip install -r requirements.txt 64 | ``` 65 | 66 | ## 🚀 Quick start 67 | 68 | ### Training 69 | 70 | ```bash 71 | python main.py --data_dir ../data/ --log_dir ./logs/scl/ -c configs/cifar10/distil.yaml --ckpt_dir ./checkpoints/c10/scl/distil/ --hide_progress --cl_default --validation --hcl 72 | 73 | ``` 74 | 75 | ### Evaluation 76 | 77 | ```bash 78 | python linear_eval_alltasks.py --data_dir ../data/ --log_dir ./logs/scl/ -c configs/cifar10/distil.yaml --ckpt_dir ./checkpoints/c10/scl/distil/ --hide_progress --cl_default --hcl 79 | 80 | ``` 81 | 82 | 83 | To change the dataset and method, use the configuration files from `./configs`. 84 | 85 | ## Contributing 86 | 87 | We'd love to accept your contributions to this project. Please feel free to open an issue, or submit a pull request as necessary. If you have implementations of this repository in other ML frameworks, please reach out so we may highlight them here. 88 | 89 | ## 🎗️ Acknowledgment 90 | 91 | The code is built upon [aimagelab/mammoth](https://github.com/aimagelab/mammoth), [divyam3897/UCL](https://github.com/divyam3897/UCL), [kuangliu/pytorch-cifar](https://github.com/kuangliu/pytorch-cifar/tree/master), [sutd-visual-computing-group/LS-KD-compatibility](https://github.com/sutd-visual-computing-group/LS-KD-compatibility), and [berniwal/swin-transformer-pytorch](https://github.com/berniwal/swin-transformer-pytorch). 92 | 93 | We thank the authors for their amazing work and for releasing the codebase. 94 | 95 | 96 | ## Licenses 97 | 98 | Copyright © 2023, NVIDIA Corporation. All rights reserved. 99 | 100 | This work is made available under the NVIDIA Source Code License-NC. Click [here](LICENSE) to view a copy of this license. 101 | 102 | For license information regarding the mammoth repository, please refer to its [repository](https://github.com/aimagelab/mammoth/blob/master/LICENSE). 103 | For license information regarding the UCL repository, please refer to its [repository](https://github.com/divyam3897/UCL/blob/main/LICENSE). 104 | For license information regarding the pytorch-cifar repository, please refer to its [repository](https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE). 105 | For license information regarding the LS-KD repository, please refer to its [repository](https://github.com/sutd-visual-computing-group/LS-KD-compatibility/blob/master/LICENSE). 106 | For license information regarding the swin-transformer repository, please refer to its [repository](https://github.com/berniwal/swin-transformer-pytorch/blob/master/LICENSE).
107 | 108 | 109 | ## 📌 Citation 110 | 111 | If you find this paper useful, please consider starring 🌟 this repo and citing 📑 our paper: 112 | 113 | ```bibtex 114 | @inproceedings{madaan2023heterogeneous, 115 | title={Heterogeneous Continual Learning}, 116 | author={Madaan, Divyam and Yin, Hongxu and Byeon, Wonmin and Kautz, Jan and Molchanov, Pavlo}, 117 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 118 | year={2023} 119 | } 120 | ``` 121 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | 5 | import numpy as np 6 | 7 | import random 8 | 9 | import re 10 | import yaml 11 | 12 | import shutil 13 | import warnings 14 | 15 | from datetime import datetime 16 | 17 | 18 | class Namespace(object): 19 | def __init__(self, somedict): 20 | for key, value in somedict.items(): 21 | assert isinstance(key, str) and re.match("[A-Za-z_-]", key) 22 | if isinstance(value, dict): 23 | self.__dict__[key] = Namespace(value) 24 | else: 25 | self.__dict__[key] = value 26 | 27 | def __getattr__(self, attribute): 28 | 29 | raise AttributeError(f"Cannot find {attribute} in namespace. Please add {attribute} to your config file (xxx.yaml)!") 30 | 31 | 32 | def set_deterministic(seed): 33 | # seed by default is None 34 | if seed is not None: 35 | print(f"Deterministic with seed = {seed}") 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed(seed) 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | 43 | def get_args(): 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('-c', '--config-file', required=True, type=str, help="xxx.yaml") 46 | parser.add_argument('--debug', action='store_true') 47 | parser.add_argument('--debug_subset_size', type=int, default=8) 48 | parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web") 49 | parser.add_argument('--data_dir', type=str, default=os.getenv('DATA')) 50 | parser.add_argument('--log_dir', type=str, default=os.getenv('LOG')) 51 | parser.add_argument('--ckpt_dir', type=str, default=os.getenv('CHECKPOINT')) 52 | parser.add_argument('--ckpt_dir_1', type=str, default=os.getenv('CHECKPOINT')) 53 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 54 | parser.add_argument('--eval_from', type=str, default=None) 55 | parser.add_argument('--hide_progress', action='store_true') 56 | parser.add_argument('--cl_default', action='store_true') 57 | parser.add_argument('--server', action='store_true') 58 | parser.add_argument('--hcl', action='store_true') 59 | parser.add_argument('--buffer_qdi', action='store_true') 60 | parser.add_argument('--validation', action='store_true', 61 | help='Test on the validation set') 62 | parser.add_argument('--ood_eval', action='store_true', 63 | help='Test on the OOD set') 64 | parser.add_argument('--alpha', type=float, default=0.3) 65 | args = parser.parse_args() 66 | 67 | 68 | with open(args.config_file, 'r') as f: 69 | for key, value in Namespace(yaml.load(f, Loader=yaml.FullLoader)).__dict__.items(): 70 | vars(args)[key] = value 71 | 72 | if args.debug: 73 | if args.train: 74 | args.train.batch_size = 2 75 | args.train.num_epochs = 1 76 | args.train.stop_at_epoch = 1 77 | if args.eval: 78 | args.eval.batch_size = 2
79 | args.eval.num_epochs = 1 # evaluate for only one epoch in debug mode 80 | args.dataset.num_workers = 0 81 | 82 | 83 | assert None not in [args.log_dir, args.data_dir, args.ckpt_dir, args.name] 84 | 85 | args.log_dir = os.path.join(args.log_dir, 'in-progress_'+datetime.now().strftime('%m%d%H%M%S_')+args.name) 86 | 87 | os.makedirs(args.log_dir, exist_ok=False) 88 | print(f'creating directory {args.log_dir}') 89 | os.makedirs(args.ckpt_dir, exist_ok=True) 90 | 91 | shutil.copy2(args.config_file, args.log_dir) 92 | set_deterministic(args.seed) 93 | 94 | 95 | vars(args)['aug_kwargs'] = { 96 | 'name':args.model.name, 97 | 'image_size': args.dataset.image_size, 98 | 'cl_default': args.cl_default 99 | } 100 | vars(args)['dataset_kwargs'] = { 101 | # 'name':args.model.name, 102 | # 'image_size': args.dataset.image_size, 103 | 'dataset':args.dataset.name, 104 | 'data_dir': args.data_dir, 105 | 'download':args.download, 106 | 'debug_subset_size': args.debug_subset_size if args.debug else None, 107 | # 'drop_last': True, 108 | # 'pin_memory': True, 109 | # 'num_workers': args.dataset.num_workers, 110 | } 111 | vars(args)['dataloader_kwargs'] = { 112 | 'drop_last': True, 113 | 'pin_memory': True, 114 | 'num_workers': args.dataset.num_workers, 115 | } 116 | 117 | return args 118 | -------------------------------------------------------------------------------- /assets/concept_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/assets/concept_figure.png -------------------------------------------------------------------------------- /augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .simsiam_aug import SimSiamTransform 2 | from .eval_aug import Transform_single 3 | 4 | 5 | def get_aug(name='simsiam', image_size=224, train=True, train_classifier=None, mean_std=None, **aug_kwargs): 6 | if train==True: 7 | augmentation = SimSiamTransform(image_size, mean_std=mean_std, **aug_kwargs) 8 | elif train==False: 9 | if train_classifier is None: 10 | raise Exception 11 | augmentation = Transform_single(image_size, train=train_classifier, mean_std=mean_std) 12 | else: 13 | raise Exception 14 | 15 | return augmentation 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /augmentations/eval_aug.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | 4 | 5 | class Transform_single(): 6 | def __init__(self, image_size, train, mean_std): 7 | if train == True: 8 | self.transform = transforms.Compose([ 9 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 10 | # transforms.RandomCrop(image_size, padding=4), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.ToTensor(), 13 | transforms.Normalize(*mean_std) 14 | ]) 15 | else: 16 | self.transform = transforms.Compose([ 17 | # transforms.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256 18 | # transforms.CenterCrop(image_size), 19 | transforms.ToTensor(), 20 | transforms.Normalize(*mean_std) 21 | ]) 22 | 23 | def __call__(self, x): 24 | return self.transform(x) 25 | -------------------------------------------------------------------------------- /augmentations/gaussian_blur.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torchvision.transforms.functional import to_pil_image, to_tensor 4 | from torch.nn.functional import conv2d, pad as torch_pad 5 | from typing import Any, List, Sequence, Optional 6 | import numbers 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | from typing import Tuple 11 | 12 | class GaussianBlur(torch.nn.Module): 13 | """Blurs image with randomly chosen Gaussian blur. 14 | The image can be a PIL Image or a Tensor, in which case it is expected 15 | to have [..., C, H, W] shape, where ... means an arbitrary number of leading 16 | dimensions 17 | 18 | Args: 19 | kernel_size (int or sequence): Size of the Gaussian kernel. 20 | sigma (float or tuple of float (min, max)): Standard deviation to be used for 21 | creating kernel to perform blurring. If float, sigma is fixed. If it is tuple 22 | of float (min, max), sigma is chosen uniformly at random to lie in the 23 | given range. 24 | 25 | Returns: 26 | PIL Image or Tensor: Gaussian blurred version of the input image. 27 | 28 | """ 29 | 30 | def __init__(self, kernel_size, sigma=(0.1, 2.0)): 31 | super().__init__() 32 | self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") 33 | for ks in self.kernel_size: 34 | if ks <= 0 or ks % 2 == 0: 35 | raise ValueError("Kernel size value should be an odd and positive number.") 36 | 37 | if isinstance(sigma, numbers.Number): 38 | if sigma <= 0: 39 | raise ValueError("If sigma is a single number, it must be positive.") 40 | sigma = (sigma, sigma) 41 | elif isinstance(sigma, Sequence) and len(sigma) == 2: 42 | if not 0. < sigma[0] <= sigma[1]: 43 | raise ValueError("sigma values should be positive and of the form (min, max).") 44 | else: 45 | raise ValueError("sigma should be a single number or a list/tuple with length 2.") 46 | 47 | self.sigma = sigma 48 | 49 | @staticmethod 50 | def get_params(sigma_min: float, sigma_max: float) -> float: 51 | """Choose sigma for random gaussian blurring. 52 | 53 | Args: 54 | sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel. 55 | sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel. 56 | 57 | Returns: 58 | float: Standard deviation to be passed to calculate kernel for gaussian blurring. 59 | """ 60 | return torch.empty(1).uniform_(sigma_min, sigma_max).item() 61 | 62 | def forward(self, img: Tensor) -> Tensor: 63 | """ 64 | Args: 65 | img (PIL Image or Tensor): image to be blurred. 
66 | 67 | Returns: 68 | PIL Image or Tensor: Gaussian blurred image 69 | """ 70 | sigma = self.get_params(self.sigma[0], self.sigma[1]) 71 | return gaussian_blur(img, self.kernel_size, [sigma, sigma]) 72 | 73 | def __repr__(self): 74 | s = '(kernel_size={}, '.format(self.kernel_size) 75 | s += 'sigma={})'.format(self.sigma) 76 | return self.__class__.__name__ + s 77 | 78 | @torch.jit.unused 79 | def _is_pil_image(img: Any) -> bool: 80 | return isinstance(img, Image.Image) 81 | def _setup_size(size, error_msg): 82 | if isinstance(size, numbers.Number): 83 | return int(size), int(size) 84 | 85 | if isinstance(size, Sequence) and len(size) == 1: 86 | return size[0], size[0] 87 | 88 | if len(size) != 2: 89 | raise ValueError(error_msg) 90 | 91 | return size 92 | def _is_tensor_a_torch_image(x: Tensor) -> bool: 93 | return x.ndim >= 2 94 | def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: 95 | ksize_half = (kernel_size - 1) * 0.5 96 | 97 | x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) 98 | pdf = torch.exp(-0.5 * (x / sigma).pow(2)) 99 | kernel1d = pdf / pdf.sum() 100 | 101 | return kernel1d 102 | 103 | def _cast_squeeze_in(img: Tensor, req_dtype: torch.dtype) -> Tuple[Tensor, bool, bool, torch.dtype]: 104 | need_squeeze = False 105 | # make image NCHW 106 | if img.ndim < 4: 107 | img = img.unsqueeze(dim=0) 108 | need_squeeze = True 109 | 110 | out_dtype = img.dtype 111 | need_cast = False 112 | if out_dtype != req_dtype: 113 | need_cast = True 114 | img = img.to(req_dtype) 115 | return img, need_cast, need_squeeze, out_dtype 116 | def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype): 117 | if need_squeeze: 118 | img = img.squeeze(dim=0) 119 | 120 | if need_cast: 121 | # it is better to round before cast 122 | img = torch.round(img).to(out_dtype) 123 | 124 | return img 125 | def _get_gaussian_kernel2d( 126 | kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device 127 | ) -> Tensor: 128 | kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype) 129 | kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype) 130 | kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :]) 131 | return kernel2d 132 | def _gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor: 133 | """PRIVATE METHOD. Performs Gaussian blurring on the img by given kernel. 134 | 135 | .. warning:: 136 | 137 | Module ``transforms.functional_tensor`` is private and should not be used in user application. 138 | Please, consider instead using methods from `transforms.functional` module. 139 | 140 | Args: 141 | img (Tensor): Image to be blurred 142 | kernel_size (sequence of int or int): Kernel size of the Gaussian kernel ``(kx, ky)``. 143 | sigma (sequence of float or float, optional): Standard deviation of the Gaussian kernel ``(sx, sy)``. 144 | 145 | Returns: 146 | Tensor: An image that is blurred using gaussian kernel of given parameters 147 | """ 148 | if not (isinstance(img, torch.Tensor) or _is_tensor_a_torch_image(img)): 149 | raise TypeError('img should be Tensor Image. 
Got {}'.format(type(img))) 150 | 151 | dtype = img.dtype if torch.is_floating_point(img) else torch.float32 152 | kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device) 153 | kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) 154 | 155 | img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, kernel.dtype) 156 | 157 | # padding = (left, right, top, bottom) 158 | padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] 159 | img = torch_pad(img, padding, mode="reflect") 160 | img = conv2d(img, kernel, groups=img.shape[-3]) 161 | 162 | img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) 163 | return img 164 | 165 | def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor: 166 | """Performs Gaussian blurring on the img by given kernel. 167 | The image can be a PIL Image or a Tensor, in which case it is expected 168 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions 169 | 170 | Args: 171 | img (PIL Image or Tensor): Image to be blurred 172 | kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers 173 | like ``(kx, ky)`` or a single integer for square kernels. 174 | In torchscript mode kernel_size as single int is not supported, use a tuple or 175 | list of length 1: ``[ksize, ]``. 176 | sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a 177 | sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the 178 | same sigma in both X/Y directions. If None, then it is computed using 179 | ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``. 180 | Default, None. In torchscript mode sigma as single float is 181 | not supported, use a tuple or list of length 1: ``[sigma, ]``. 182 | 183 | Returns: 184 | PIL Image or Tensor: Gaussian Blurred version of the image. 185 | """ 186 | if not isinstance(kernel_size, (int, list, tuple)): 187 | raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size))) 188 | if isinstance(kernel_size, int): 189 | kernel_size = [kernel_size, kernel_size] 190 | if len(kernel_size) != 2: 191 | raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size))) 192 | for ksize in kernel_size: 193 | if ksize % 2 == 0 or ksize < 0: 194 | raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size)) 195 | 196 | if sigma is None: 197 | sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] 198 | 199 | if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): 200 | raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma))) 201 | if isinstance(sigma, (int, float)): 202 | sigma = [float(sigma), float(sigma)] 203 | if isinstance(sigma, (list, tuple)) and len(sigma) == 1: 204 | sigma = [sigma[0], sigma[0]] 205 | if len(sigma) != 2: 206 | raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma))) 207 | for s in sigma: 208 | if s <= 0.: 209 | raise ValueError('sigma should have positive values. Got {}'.format(sigma)) 210 | 211 | t_img = img 212 | if not isinstance(img, torch.Tensor): 213 | if not _is_pil_image(img): 214 | raise TypeError('img should be PIL Image or Tensor. 
Got {}'.format(type(img))) 215 | 216 | t_img = to_tensor(img) 217 | 218 | output = _gaussian_blur(t_img, kernel_size, sigma) 219 | 220 | if not isinstance(img, torch.Tensor): 221 | output = to_pil_image(output) 222 | return output 223 | 224 | 225 | 226 | 227 | # if __name__ == "__main__": 228 | # gaussian_blur = GaussianBlur(kernel_size=23) 229 | -------------------------------------------------------------------------------- /augmentations/simsiam_aug.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | from PIL import Image 3 | try: 4 | from torchvision.transforms import GaussianBlur 5 | except ImportError: 6 | from .gaussian_blur import GaussianBlur 7 | T.GaussianBlur = GaussianBlur 8 | import torch  # needed by to_pil_image below 9 | import numpy as np  # needed by to_pil_image below 10 | class SimSiamTransform(): 11 | def __init__(self, image_size, mean_std, **aug_kwargs): 12 | p_blur = 0.5 if image_size > 32 else 0 # exclude cifar 13 | # self.not_aug_transform = T.Compose([T.ToTensor(), T.Normalize(*mean_std)]) 14 | self.not_aug_transform = T.Compose([T.ToTensor()]) 15 | 16 | random_crop = T.RandomCrop(image_size, padding=4) if aug_kwargs['cl_default'] else T.RandomResizedCrop(image_size, scale=(0.2, 1.0)) 17 | self.transform = T.Compose([ 18 | random_crop, 19 | T.RandomHorizontalFlip(), 20 | T.RandomApply([T.ColorJitter(0.4,0.4,0.4,0.1)], p=0.8), 21 | T.RandomGrayscale(p=0.2), 22 | T.RandomApply([T.GaussianBlur(kernel_size=image_size//20*2+1, sigma=(0.1, 2.0))], p=p_blur), 23 | T.ToTensor(), 24 | T.Normalize(*mean_std) 25 | ]) 26 | def __call__(self, x): 27 | x1 = self.transform(x) 28 | x2 = self.transform(x) 29 | not_aug_x = self.not_aug_transform(x) 30 | return x1, x2, not_aug_x 31 | 32 | 33 | def to_pil_image(pic, mode=None): 34 | """Convert a tensor or an ndarray to PIL Image. 35 | 36 | See :class:`~torchvision.transforms.ToPILImage` for more details. 37 | 38 | Args: 39 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 40 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 41 | 42 | .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes 43 | 44 | Returns: 45 | PIL Image: Image converted to PIL Image. 46 | """ 47 | if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): 48 | raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) 49 | 50 | elif isinstance(pic, torch.Tensor): 51 | if pic.ndimension() not in {2, 3}: 52 | raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension())) 53 | 54 | elif pic.ndimension() == 2: 55 | # if 2D image, add channel dimension (CHW) 56 | pic = pic.unsqueeze(0) 57 | 58 | elif isinstance(pic, np.ndarray): 59 | if pic.ndim not in {2, 3}: 60 | raise ValueError('pic should be 2/3 dimensional. 
Got {} dimensions.'.format(pic.ndim)) 61 | 62 | elif pic.ndim == 2: 63 | # if 2D image, add channel dimension (HWC) 64 | pic = np.expand_dims(pic, 2) 65 | 66 | npimg = pic 67 | if isinstance(pic, torch.Tensor): 68 | if pic.is_floating_point() and mode != 'F': 69 | pic = pic.mul(255).byte() 70 | npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0)) 71 | 72 | if not isinstance(npimg, np.ndarray): 73 | raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 74 | 'not {}'.format(type(npimg))) 75 | 76 | if npimg.shape[2] == 1: 77 | expected_mode = None 78 | npimg = npimg[:, :, 0] 79 | if npimg.dtype == np.uint8: 80 | expected_mode = 'L' 81 | elif npimg.dtype == np.int16: 82 | expected_mode = 'I;16' 83 | elif npimg.dtype == np.int32: 84 | expected_mode = 'I' 85 | elif npimg.dtype == np.float32: 86 | expected_mode = 'F' 87 | if mode is not None and mode != expected_mode: 88 | raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}" 89 | .format(mode, np.dtype, expected_mode)) 90 | mode = expected_mode 91 | 92 | elif npimg.shape[2] == 2: 93 | permitted_2_channel_modes = ['LA'] 94 | if mode is not None and mode not in permitted_2_channel_modes: 95 | raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes)) 96 | 97 | if mode is None and npimg.dtype == np.uint8: 98 | mode = 'LA' 99 | 100 | elif npimg.shape[2] == 4: 101 | permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX'] 102 | if mode is not None and mode not in permitted_4_channel_modes: 103 | raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) 104 | 105 | if mode is None and npimg.dtype == np.uint8: 106 | mode = 'RGBA' 107 | else: 108 | permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] 109 | if mode is not None and mode not in permitted_3_channel_modes: 110 | raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) 111 | if mode is None and npimg.dtype == np.uint8: 112 | mode = 'RGB' 113 | 114 | if mode is None: 115 | raise TypeError('Input type {} is not supported'.format(npimg.dtype)) 116 | 117 | return Image.fromarray(npimg, mode=mode) 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/configs/__init__.py -------------------------------------------------------------------------------- /configs/cifar10/distil.yaml: -------------------------------------------------------------------------------- 1 | name: c10-experiment-hcl 2 | dataset: 3 | name: seq-cifar10 4 | image_size: 32 5 | num_workers: 4 6 | 7 | model: 8 | name: simsiam 9 | backbone: resnet18 10 | cl_model: distil 11 | proj_layers: 2 12 | buffer_size: 200 13 | 14 | train: 15 | optimizer: 16 | name: sgd 17 | weight_decay: 0.0005 18 | momentum: 0.9 19 | warmup_epochs: 10 20 | warmup_lr: 0 21 | base_lr: 0.03 22 | final_lr: 0 23 | num_epochs: 200 # this parameter influences the lr decay 24 | stop_at_epoch: 200 # must not be larger than num_epochs 25 | batch_size: 32 26 | alpha: 3.0 27 | eval: # linear evaluation; set to False to turn off automatic evaluation after training 28 | optimizer: 29 | name: sgd 30 | weight_decay: 0 31 | momentum: 0.9 32 | warmup_lr: 0 33 | warmup_epochs: 0 34 | base_lr: 30 35 | final_lr: 0 36 | batch_size: 256 37 | num_epochs: 100 38 | 
39 | logger: 40 | csv_log: True 41 | tensorboard: True 42 | matplotlib: True 43 | 44 | seed: null # YAML null maps to Python's None 45 | # two things might lead to stochastic behavior other than seed: 46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 47 | # (keep this in mind if you want 100% determinism) 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /configs/cifar10/distilbuf.yaml: -------------------------------------------------------------------------------- 1 | name: c10-experiment-hcl 2 | dataset: 3 | name: seq-cifar10 4 | image_size: 32 5 | num_workers: 4 6 | 7 | model: 8 | name: simsiam 9 | backbone: resnet18 10 | cl_model: distilbuf 11 | proj_layers: 2 12 | buffer_size: 200 13 | 14 | train: 15 | optimizer: 16 | name: sgd 17 | weight_decay: 0.0005 18 | momentum: 0.9 19 | warmup_epochs: 10 20 | warmup_lr: 0 21 | base_lr: 0.03 22 | final_lr: 0 23 | num_epochs: 200 # this parameter influences the lr decay 24 | stop_at_epoch: 200 # must not be larger than num_epochs 25 | batch_size: 32 26 | alpha: 1.0 27 | eval: # linear evaluation; set to False to turn off automatic evaluation after training 28 | optimizer: 29 | name: sgd 30 | weight_decay: 0 31 | momentum: 0.9 32 | warmup_lr: 0 33 | warmup_epochs: 0 34 | base_lr: 30 35 | final_lr: 0 36 | batch_size: 256 37 | num_epochs: 100 38 | 39 | logger: 40 | csv_log: True 41 | tensorboard: True 42 | matplotlib: True 43 | 44 | seed: null # YAML null maps to Python's None 45 | # two things might lead to stochastic behavior other than seed: 46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 47 | # (keep this in mind if you want 100% determinism) 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /configs/cifar10/qdi.yaml: -------------------------------------------------------------------------------- 1 | name: simsiam-c10-experiment-resnet18 2 | dataset: 3 | name: seq-cifar10 4 | image_size: 32 5 | num_workers: 4 6 | 7 | model: 8 | name: simsiam 9 | backbone: resnet18 10 | cl_model: qdi 11 | proj_layers: 2 12 | buffer_size: 200 13 | 14 | train: 15 | optimizer: 16 | name: sgd 17 | weight_decay: 0.0005 18 | momentum: 0.9 19 | warmup_epochs: 10 20 | warmup_lr: 0 21 | base_lr: 0.03 22 | final_lr: 0 23 | num_epochs: 200 # this parameter influences the lr decay 24 | stop_at_epoch: 200 # must not be larger than num_epochs 25 | batch_size: 32 26 | alpha: 1.0 27 | di_lr: 0.005 28 | di_var: 0.001 29 | di_l2: 0.
30 | di_feature: 0.1 31 | di_itrs: 500 32 | eval: # linear evaluation; set to False to turn off automatic evaluation after training 33 | optimizer: 34 | name: sgd 35 | weight_decay: 0 36 | momentum: 0.9 37 | warmup_lr: 0 38 | warmup_epochs: 0 39 | base_lr: 30 40 | final_lr: 0 41 | batch_size: 256 42 | num_epochs: 100 43 | 44 | logger: 45 | csv_log: True 46 | tensorboard: True 47 | matplotlib: True 48 | 49 | seed: null # YAML null maps to Python's None 50 | # two things might lead to stochastic behavior other than seed: 51 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 52 | # (keep this in mind if you want 100% determinism) 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /configs/cifar100/distil.yaml: -------------------------------------------------------------------------------- 1 | name: c100-experiment-hcl 2 | dataset: 3 | name: seq-cifar100 4 | image_size: 32 5 | num_workers: 4 6 | 7 | model: 8 | name: simsiam 9 | backbone: resnet18 10 | cl_model: distil 11 | proj_layers: 2 12 | buffer_size: 200 13 | 14 | train: 15 | optimizer: 16 | name: sgd 17 | weight_decay: 0.0005 18 | momentum: 0.9 19 | warmup_epochs: 10 20 | warmup_lr: 0 21 | base_lr: 0.03 22 | final_lr: 0 23 | num_epochs: 200 # this parameter influences the lr decay 24 | stop_at_epoch: 200 # must not be larger than num_epochs 25 | batch_size: 32 26 | alpha: 3.0 27 | eval: # linear evaluation; set to False to turn off automatic evaluation after training 28 | optimizer: 29 | name: sgd 30 | weight_decay: 0 31 | momentum: 0.9 32 | warmup_lr: 0 33 | warmup_epochs: 0 34 | base_lr: 30 35 | final_lr: 0 36 | batch_size: 256 37 | num_epochs: 100 38 | 39 | logger: 40 | csv_log: True 41 | tensorboard: True 42 | matplotlib: True 43 | 44 | seed: null # YAML null maps to Python's None 45 | # two things might lead to stochastic behavior other than seed: 46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 47 | # (keep this in mind if you want 100% determinism) 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /configs/cifar100/distilbuf.yaml: -------------------------------------------------------------------------------- 1 | name: c100-experiment-hcl 2 | dataset: 3 | name: seq-cifar100 4 | image_size: 32 5 | num_workers: 4 6 | 7 | model: 8 | name: simsiam 9 | backbone: resnet18 10 | cl_model: distilbuf 11 | proj_layers: 2 12 | buffer_size: 200 13 | 14 | train: 15 | optimizer: 16 | name: sgd 17 | weight_decay: 0.0005 18 | momentum: 0.9 19 | warmup_epochs: 10 20 | warmup_lr: 0 21 | base_lr: 0.03 22 | final_lr: 0 23 | num_epochs: 200 # this parameter influences the lr decay 24 | stop_at_epoch: 200 # must not be larger than num_epochs 25 | batch_size: 32 26 | alpha: 3.0 27 | eval: # linear evaluation; set to False to turn off automatic evaluation after training 28 | optimizer: 29 | name: sgd 30 | weight_decay: 0 31 | momentum: 0.9 32 | warmup_lr: 0 33 | warmup_epochs: 0 34 | base_lr: 30 35 | final_lr: 0 36 | batch_size: 256 37 | num_epochs: 100 38 | 39 | logger: 40 | csv_log: True 41 | tensorboard: True 42 | matplotlib: True 43 | 44 | seed: null # YAML null maps to Python's None 45 | # two things might lead to stochastic behavior other than seed: 46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 47 | # (keep this in mind if you want 100% determinism) 48 | 49 | 50 | 51 | 52 | --------------------------------------------------------------------------------
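As an aside, the nested keys in these YAML files become attribute paths on the parsed command-line args: arguments.py wraps each nested dict in its Namespace class and merges the top-level keys into args. The minimal sketch below re-implements that access pattern under stated assumptions (the config path is just an example, and arguments.py itself uses yaml.load with FullLoader rather than safe_load):

```python
import yaml

class Namespace:
    # Minimal re-implementation of the Namespace wrapper in arguments.py:
    # nested dicts become nested attribute lookups.
    def __init__(self, d):
        for key, value in d.items():
            setattr(self, key, Namespace(value) if isinstance(value, dict) else value)

with open('configs/cifar100/distilbuf.yaml') as f:
    cfg = Namespace(yaml.safe_load(f))

print(cfg.model.cl_model)         # -> 'distilbuf'
print(cfg.dataset.name)           # -> 'seq-cifar100'
print(cfg.train.optimizer.name)   # -> 'sgd'
```

Because get_args() copies each top-level config key onto the argparse namespace, the same values are reachable in the training code as, e.g., args.model.cl_model.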
/configs/cifar100/qdi.yaml: -------------------------------------------------------------------------------- 1 | name: simsiam-c100-experiment-resnet18 2 | dataset: 3 | name: seq-cifar100 4 | image_size: 32 5 | num_workers: 4 6 | 7 | model: 8 | name: simsiam 9 | backbone: resnet18 10 | cl_model: qdi 11 | proj_layers: 2 12 | buffer_size: 200 13 | 14 | train: 15 | optimizer: 16 | name: sgd 17 | weight_decay: 0.0005 18 | momentum: 0.9 19 | warmup_epochs: 10 20 | warmup_lr: 0 21 | base_lr: 0.03 22 | final_lr: 0 23 | num_epochs: 200 # this parameter influences the lr decay 24 | stop_at_epoch: 200 # must not be larger than num_epochs 25 | batch_size: 32 26 | alpha: 3.0 27 | di_var: 0.003 28 | di_l2: 0.003 29 | di_feature: 0.2 30 | di_itrs: 500 31 | di_lr: 0.03 32 | eval: # linear evaluation; set to False to turn off automatic evaluation after training 33 | optimizer: 34 | name: sgd 35 | weight_decay: 0 36 | momentum: 0.9 37 | warmup_lr: 0 38 | warmup_epochs: 0 39 | base_lr: 30 40 | final_lr: 0 41 | batch_size: 256 42 | num_epochs: 100 43 | 44 | logger: 45 | csv_log: True 46 | tensorboard: True 47 | matplotlib: True 48 | 49 | seed: null # YAML null maps to Python's None 50 | # two things might lead to stochastic behavior other than seed: 51 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 52 | # (keep this in mind if you want 100% determinism) 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /configs/tinyimg/distil.yaml: -------------------------------------------------------------------------------- 1 | name: tinyimagenet-experiment-hcl 2 | dataset: 3 | name: seq-tinyimg 4 | image_size: 64 5 | num_workers: 4 6 | 7 | model: 8 | name: simsiam 9 | backbone: resnet18 10 | cl_model: distil 11 | proj_layers: 2 12 | buffer_size: 200 13 | 14 | train: 15 | optimizer: 16 | name: sgd 17 | weight_decay: 0.0005 18 | momentum: 0.9 19 | warmup_epochs: 10 20 | warmup_lr: 0 21 | base_lr: 0.03 22 | final_lr: 0 23 | num_epochs: 200 # this parameter influences the lr decay 24 | stop_at_epoch: 200 # must not be larger than num_epochs 25 | batch_size: 32 26 | alpha: 3.0 27 | eval: # linear evaluation; set to False to turn off automatic evaluation after training 28 | optimizer: 29 | name: sgd 30 | weight_decay: 0 31 | momentum: 0.9 32 | warmup_lr: 0 33 | warmup_epochs: 0 34 | base_lr: 30 35 | final_lr: 0 36 | batch_size: 256 37 | num_epochs: 100 38 | 39 | logger: 40 | csv_log: True 41 | tensorboard: True 42 | matplotlib: True 43 | 44 | seed: null # YAML null maps to Python's None 45 | # two things might lead to stochastic behavior other than seed: 46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 47 | # (keep this in mind if you want 100% determinism) 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /configs/tinyimg/distilbuf.yaml: -------------------------------------------------------------------------------- 1 | name: tinyimagenet-experiment-hcl 2 | dataset: 3 | name: seq-tinyimg 4 | image_size: 64 5 | num_workers: 4 6 | 7 | model: 8 | name: simsiam 9 | backbone: resnet18 10 | cl_model: distilbuf 11 | proj_layers: 2 12 | buffer_size: 200 13 | 14 | train: 15 | optimizer: 16 | name: sgd 17 | weight_decay: 0.0005 18 | momentum: 0.9 19 | warmup_epochs: 10 20 | warmup_lr: 0 21 | base_lr: 0.03 22 | final_lr: 0 23 | num_epochs: 200 # this parameter influences the lr decay 24 | stop_at_epoch: 200 # must not be larger than num_epochs 25 | batch_size: 32 26 | alpha: 1.0 27 
| eval: # linear evaluation; set to False to turn off automatic evaluation after training 28 | optimizer: 29 | name: sgd 30 | weight_decay: 0 31 | momentum: 0.9 32 | warmup_lr: 0 33 | warmup_epochs: 0 34 | base_lr: 30 35 | final_lr: 0 36 | batch_size: 256 37 | num_epochs: 100 38 | 39 | logger: 40 | csv_log: True 41 | tensorboard: True 42 | matplotlib: True 43 | 44 | seed: null # YAML null maps to Python's None 45 | # two things might lead to stochastic behavior other than seed: 46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 47 | # (keep this in mind if you want 100% determinism) 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /configs/tinyimg/qdi.yaml: -------------------------------------------------------------------------------- 1 | name: tinyimg-experiment-resnet18 2 | dataset: 3 | name: seq-tinyimg 4 | image_size: 64 5 | num_workers: 4 6 | 7 | model: 8 | name: simsiam 9 | backbone: resnet18 10 | cl_model: qdi 11 | proj_layers: 2 12 | buffer_size: 200 13 | 14 | train: 15 | optimizer: 16 | name: sgd 17 | weight_decay: 0.0005 18 | momentum: 0.9 19 | warmup_epochs: 10 20 | warmup_lr: 0 21 | base_lr: 0.03 22 | final_lr: 0 23 | num_epochs: 200 # this parameter influences the lr decay 24 | stop_at_epoch: 200 # must not be larger than num_epochs 25 | batch_size: 32 26 | alpha: 1.0 27 | di_var: 0.003 28 | di_l2: 0.003 29 | di_feature: 0.2 30 | di_itrs: 500 31 | di_lr: 0.03 32 | eval: # linear evaluation; set to False to turn off automatic evaluation after training 33 | optimizer: 34 | name: sgd 35 | weight_decay: 0 36 | momentum: 0.9 37 | warmup_lr: 0 38 | warmup_epochs: 0 39 | base_lr: 30 40 | final_lr: 0 41 | batch_size: 256 42 | num_epochs: 100 43 | 44 | logger: 45 | csv_log: True 46 | tensorboard: True 47 | matplotlib: True 48 | 49 | seed: null # YAML null maps to Python's None 50 | # two things might lead to stochastic behavior other than seed: 51 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 52 | # (keep this in mind if you want 100% determinism) 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from datasets.seq_cifar10 import SequentialCIFAR10 7 | from datasets.seq_cifar100 import SequentialCIFAR100 8 | from datasets.seq_tinyimagenet import SequentialTinyImagenet 9 | from datasets.utils.continual_dataset import ContinualDataset 10 | from argparse import Namespace 11 | import torchvision 12 | 13 | NAMES = { 14 | SequentialCIFAR10.NAME: SequentialCIFAR10, 15 | SequentialCIFAR100.NAME: SequentialCIFAR100, 16 | SequentialTinyImagenet.NAME: SequentialTinyImagenet, 17 | } 18 | 19 | N_CLASSES = {'seq-cifar10': 10, 'seq-cifar100': 100, 'seq-tinyimg': 200} 20 | BACKBONES = {'seq-cifar10': ["lenet", "resnet18", "densenet", "senet", "regnet"], 21 | 'seq-cifar100': ["lenet","lenet", "alexnet", "alexnet", "vgg16", "vgg16", "inception", "inception", "resnet18", "resnet18", "resnext", "resnext", "densenet", "densenet", "senet", "senet", "regnet", "regnet", "regnet", "regnet"], 22 | 'seq-tinyimg': ["lenet", "lenet", "resnet18", "resnet18", "resnext", "resnext", "senet", "senet", "regnet", "regnet"], 23 | } 24 | 25 | 26 | def get_dataset(args: Namespace) -> ContinualDataset: 27 | """ 28 | Creates and returns a continual dataset. 29 | :param args: the arguments, which contain the hyperparameters 30 | :return: the continual dataset 31 | """ 32 | assert args.dataset_kwargs['dataset'] in NAMES.keys() 33 | return NAMES[args.dataset_kwargs['dataset']](args) 34 | 35 | 36 | def get_gcl_dataset(args: Namespace): 37 | """ 38 | Creates and returns a GCL dataset. 39 | :param args: the arguments, which contain the hyperparameters 40 | :return: the continual dataset 41 | """ 42 | assert args.dataset in GCL_NAMES.keys()  # NOTE: GCL_NAMES comes from upstream mammoth and is not defined in this repository 43 | return GCL_NAMES[args.dataset](args) 44 | -------------------------------------------------------------------------------- /datasets/datasets_utils.py: -------------------------------------------------------------------------------- 1 | ########################################## 2 | # Code from https://github.com/joansj/hat 3 | ########################################## 4 | 5 | import os,sys 6 | import os.path 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | from torchvision import datasets,transforms 11 | from sklearn.utils import shuffle 12 | import urllib.request 13 | from PIL import Image 14 | import pickle 15 | 16 | 17 | class FashionMNIST(datasets.MNIST): 18 | """`Fashion MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset. 
19 | """ 20 | urls = [ 21 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz', 22 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz', 23 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz', 24 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz', 25 | ] 26 | 27 | -------------------------------------------------------------------------------- /datasets/random_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class RandomDataset(torch.utils.data.Dataset): 4 | def __init__(self, root=None, train=True, transform=None, target_transform=None): 5 | self.transform = transform 6 | self.target_transform = target_transform 7 | 8 | self.size = 1000 9 | def __getitem__(self, idx): 10 | if idx < self.size: 11 | return [torch.randn((3, 224, 224)), torch.randn((3, 224, 224))], [0,0,0] 12 | else: 13 | raise Exception 14 | 15 | def __len__(self): 16 | return self.size 17 | -------------------------------------------------------------------------------- /datasets/seq_cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from torchvision.datasets import CIFAR10 7 | import torchvision.transforms as transforms 8 | import torch.nn.functional as F 9 | from datasets.seq_tinyimagenet import base_path 10 | from PIL import Image 11 | from datasets.utils.validation import get_train_val 12 | from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders 13 | from datasets.utils.continual_dataset import get_previous_train_loader 14 | from typing import Tuple 15 | from datasets.transforms.denormalization import DeNormalize 16 | import torch 17 | from augmentations import get_aug 18 | from PIL import Image 19 | 20 | class MyCIFAR10(CIFAR10): 21 | """ 22 | Overrides the CIFAR10 dataset to change the getitem function. 23 | """ 24 | def __init__(self, root, train=True, transform=None, 25 | target_transform=None, download=False) -> None: 26 | super(MyCIFAR10, self).__init__(root, train, transform, target_transform, download) 27 | 28 | def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]: 29 | """ 30 | Gets the requested element from the dataset. 31 | :param index: index of the element to be returned 32 | :returns: tuple: (image, target) where target is index of the target class. 
33 | """ 34 | img, target = self.data[index], self.targets[index] 35 | img = Image.fromarray(img, mode='RGB') 36 | original_img = img.copy() 37 | 38 | 39 | img, img1, not_aug_img = self.transform(original_img) 40 | 41 | if hasattr(self, 'logits'): 42 | return (img, img1, not_aug_img), target, self.logits[index] 43 | 44 | return (img, img1, not_aug_img), target 45 | 46 | 47 | class SequentialCIFAR10(ContinualDataset): 48 | 49 | NAME = 'seq-cifar10' 50 | SETTING = 'class-il' 51 | N_CLASSES_PER_TASK = 2 52 | N_TASKS = 5 53 | 54 | def get_data_loaders(self, args): 55 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]] 56 | transform = get_aug(train=True, mean_std=cifar_norm, **args.aug_kwargs) 57 | test_transform = get_aug(train=False, train_classifier=False, mean_std=cifar_norm, **args.aug_kwargs) 58 | 59 | if args.server: 60 | train_dataset = MyCIFAR10('/cifar10-pytorch', train=True, 61 | download=False, transform=transform) 62 | memory_dataset = MyCIFAR10('/cifar10-pytorch', train=True, 63 | download=False, transform=test_transform) 64 | else: 65 | train_dataset = MyCIFAR10(base_path() + 'CIFAR10', train=True, 66 | download=True, transform=transform) 67 | memory_dataset = MyCIFAR10(base_path() + 'CIFAR10', train=True, 68 | download=True, transform=test_transform) 69 | 70 | if self.args.validation: 71 | train_dataset, test_dataset = get_train_val(train_dataset, test_transform, self.NAME) 72 | memory_dataset, _ = get_train_val(memory_dataset, test_transform, self.NAME) 73 | else: 74 | if args.server: 75 | test_dataset = CIFAR10('/cifar10-pytorch',train=False, 76 | download=False, transform=test_transform) 77 | else: 78 | test_dataset = CIFAR10(base_path() + 'CIFAR10',train=False, 79 | download=True, transform=test_transform) 80 | 81 | train, memory, test = store_masked_loaders(train_dataset, test_dataset, memory_dataset, self) 82 | return train, memory, test 83 | 84 | 85 | def get_transform(self, args): 86 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]] 87 | if args.cl_default: 88 | transform = transforms.Compose( 89 | [transforms.ToPILImage(), 90 | transforms.RandomCrop(32, padding=4), 91 | transforms.RandomHorizontalFlip(), 92 | transforms.ToTensor(), 93 | transforms.Normalize(*cifar_norm) 94 | ]) 95 | else: 96 | transform = transforms.Compose( 97 | [transforms.ToPILImage(), 98 | transforms.RandomResizedCrop(32, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 99 | transforms.RandomHorizontalFlip(), 100 | transforms.ToTensor(), 101 | transforms.Normalize(*cifar_norm) 102 | ]) 103 | 104 | return transform 105 | 106 | def not_aug_dataloader(self, batch_size): 107 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]] 108 | transform = transforms.Compose([transforms.ToTensor(), 109 | transforms.Normalize(*cifar_norm)]) 110 | 111 | train_dataset = CIFAR10(base_path() + 'CIFAR10', train=True, 112 | download=True, transform=transform) 113 | train_loader = get_previous_train_loader(train_dataset, batch_size, self) 114 | 115 | return train_loader 116 | -------------------------------------------------------------------------------- /datasets/seq_cifar100.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from torchvision.datasets import CIFAR100 7 | import torchvision.transforms as transforms 8 | import torch.nn.functional as F 9 | from datasets.seq_tinyimagenet import base_path 10 | from PIL import Image 11 | from datasets.utils.validation import get_train_val 12 | from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders 13 | from datasets.utils.continual_dataset import get_previous_train_loader 14 | from typing import Tuple 15 | from datasets.transforms.denormalization import DeNormalize 16 | import torch 17 | from augmentations import get_aug 18 | from PIL import Image 19 | 20 | 21 | class MyCIFAR100(CIFAR100): 22 | """ 23 | Overrides the CIFAR100 dataset to change the getitem function. 24 | """ 25 | def __init__(self, root, train=True, transform=None, 26 | target_transform=None, download=False) -> None: 27 | super(MyCIFAR100, self).__init__(root, train, transform, target_transform, download) 28 | 29 | def __getitem__(self, index: int) -> Tuple[type(Image), int, type(Image)]: 30 | """ 31 | Gets the requested element from the dataset. 32 | :param index: index of the element to be returned 33 | :returns: tuple: (image, target) where target is index of the target class. 34 | """ 35 | img, target = self.data[index], self.targets[index] 36 | img = Image.fromarray(img, mode='RGB') 37 | original_img = img.copy() 38 | 39 | img, img1, not_aug_img = self.transform(original_img) 40 | 41 | if hasattr(self, 'logits'): 42 | return (img, img1, not_aug_img), target, self.logits[index] 43 | 44 | return (img, img1, not_aug_img), target 45 | 46 | 47 | class SequentialCIFAR100(ContinualDataset): 48 | 49 | NAME = 'seq-cifar100' 50 | SETTING = 'class-il' 51 | N_CLASSES_PER_TASK = 5 52 | N_TASKS = 20 53 | 54 | def get_data_loaders(self, args): 55 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]] 56 | transform = get_aug(train=True, mean_std=cifar_norm, **args.aug_kwargs) 57 | test_transform = get_aug(train=False, train_classifier=False, mean_std=cifar_norm, **args.aug_kwargs) 58 | 59 | if args.server: 60 | train_dataset = MyCIFAR100('/cifar100_data', train=True, 61 | download=False, transform=transform) 62 | memory_dataset = CIFAR100('/cifar100_data', train=True, 63 | download=False, transform=test_transform) 64 | else: 65 | train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True, 66 | download=True, transform=transform) 67 | memory_dataset = CIFAR100(base_path() + 'CIFAR100', train=True, 68 | download=True, transform=test_transform) 69 | 70 | if self.args.validation: 71 | train_dataset, test_dataset = get_train_val(train_dataset, test_transform, self.NAME) 72 | memory_dataset, _ = get_train_val(memory_dataset, test_transform, self.NAME) 73 | else: 74 | if args.server: 75 | test_dataset = CIFAR100('/cifar100_data', train=False, 76 | download=False, transform=test_transform) 77 | else: 78 | test_dataset = CIFAR100(base_path() + 'CIFAR100', train=False, 79 | download=True, transform=test_transform) 80 | 81 | train, memory, test = store_masked_loaders(train_dataset, test_dataset, memory_dataset, self) 82 | return train, memory, test 83 | 84 | def get_transform(self, args): 85 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]] 86 | if args.cl_default: 87 | transform = transforms.Compose( 88 | [transforms.ToPILImage(), 89 | transforms.RandomCrop(32, padding=4), 90 | transforms.RandomHorizontalFlip(), 91 | transforms.ToTensor(), 92 | transforms.Normalize(*cifar_norm) 93 | ]) 94 | else: 95 | transform = transforms.Compose( 96 | 
[transforms.ToPILImage(), 97 | transforms.RandomResizedCrop(32, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 98 | transforms.RandomHorizontalFlip(), 99 | transforms.ToTensor(), 100 | transforms.Normalize(*cifar_norm) 101 | ]) 102 | 103 | return transform 104 | 105 | 106 | 107 | def not_aug_dataloader(self, batch_size): 108 | cifar_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]] 109 | transform = transforms.Compose([transforms.ToTensor(), 110 | transforms.Normalize(*cifar_norm)]) 111 | 112 | train_dataset = MyCIFAR100(base_path() + 'CIFAR100', train=True, 113 | download=True, transform=transform) 114 | train_loader = get_previous_train_loader(train_dataset, batch_size, self) 115 | 116 | return train_loader 117 | -------------------------------------------------------------------------------- /datasets/seq_tinyimagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torchvision.transforms as transforms 8 | from torch.utils.data import Dataset 9 | import torch.nn.functional as F 10 | from utils.conf import base_path 11 | from PIL import Image 12 | import os 13 | from datasets.utils.validation import get_train_val 14 | from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders 15 | from datasets.utils.continual_dataset import get_previous_train_loader 16 | from datasets.transforms.denormalization import DeNormalize 17 | from augmentations import get_aug 18 | 19 | 20 | class TinyImagenet(Dataset): 21 | """ 22 | Defines the Tiny Imagenet dataset, following the structure of the other PyTorch datasets.
23 | """ 24 | def __init__(self, root: str, train: bool=True, transform: transforms=None, 25 | target_transform: transforms=None, download: bool=False) -> None: 26 | self.not_aug_transform = transforms.Compose([transforms.ToTensor()]) 27 | self.root = root 28 | self.train = train 29 | self.transform = transform 30 | self.target_transform = target_transform 31 | self.download = download 32 | 33 | if download: 34 | if os.path.isdir(root) and len(os.listdir(root)) > 0: 35 | print('Download not needed, files already on disk.') 36 | else: 37 | import gdown 38 | import zipfile 39 | # https://drive.google.com/file/d/1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj/view 40 | url = 'https://drive.google.com/uc?id=1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj' 41 | if not os.path.exists(root): os.makedirs(root) 42 | gdown.download(url, root, quiet=False, fuzzy=True) 43 | with zipfile.ZipFile(os.listdir(root), "r") as f: 44 | f.extractall(path=root) 45 | gdown.extractall(root) 46 | 47 | self.data = [] 48 | for num in range(20): 49 | self.data.append(np.load(os.path.join( 50 | root, 'processed/x_%s_%02d.npy' % 51 | ('train' if self.train else 'val', num+1)))) 52 | self.data = np.concatenate(np.array(self.data)) 53 | 54 | self.targets = [] 55 | for num in range(20): 56 | self.targets.append(np.load(os.path.join( 57 | root, 'processed/y_%s_%02d.npy' % 58 | ('train' if self.train else 'val', num+1)))) 59 | self.targets = np.concatenate(np.array(self.targets)) 60 | 61 | def __len__(self): 62 | return len(self.data) 63 | 64 | def __getitem__(self, index): 65 | img, target = self.data[index], self.targets[index] 66 | 67 | # doing this so that it is consistent with all other datasets 68 | # to return a PIL Image 69 | img = Image.fromarray(np.uint8(255 * img)) 70 | original_img = img.copy() 71 | 72 | if self.transform is not None: 73 | img = self.transform(img) 74 | 75 | if self.target_transform is not None: 76 | target = self.target_transform(target) 77 | 78 | if hasattr(self, 'logits'): 79 | return img, target, original_img, self.logits[index] 80 | 81 | return img, target 82 | 83 | 84 | class SequentialTinyImagenet(ContinualDataset): 85 | 86 | NAME = 'seq-tinyimg' 87 | SETTING = 'class-il' 88 | N_CLASSES_PER_TASK = 20 89 | N_TASKS = 10 90 | TRANSFORM = transforms.Compose( 91 | [transforms.RandomCrop(64, padding=4), 92 | transforms.RandomHorizontalFlip(), 93 | transforms.ToTensor(), 94 | transforms.Normalize((0.4802, 0.4480, 0.3975), 95 | (0.2770, 0.2691, 0.2821))]) 96 | 97 | def get_data_loaders(self, args): 98 | imagenet_norm = [[0.4802, 0.4480, 0.3975], [0.2770, 0.2691, 0.2821]] 99 | transform = get_aug(train=True, mean_std=imagenet_norm, **args.aug_kwargs) 100 | test_transform = get_aug(train=False, train_classifier=False, mean_std=imagenet_norm, **args.aug_kwargs) 101 | 102 | if args.server: 103 | train_dataset = TinyImagenet('/tinyimg_data', train=True, 104 | download=False, transform=transform) 105 | memory_dataset = TinyImagenet('/tinyimg_data', train=True, 106 | download=False, transform=test_transform) 107 | else: 108 | train_dataset = TinyImagenet(base_path() + 'TINYIMG', 109 | train=True, download=True, transform=transform) 110 | 111 | memory_dataset = TinyImagenet(base_path() + 'TINYIMG', 112 | train=True, download=True, transform=test_transform) 113 | if self.args.validation: 114 | train_dataset, test_dataset = get_train_val(train_dataset, 115 | test_transform, self.NAME) 116 | memory_dataset, _ = get_train_val(memory_dataset, test_transform, self.NAME) 117 | else: 118 | if args.server: 119 | test_dataset = 
TinyImagenet('/tinyimg_data', train=False, 120 | download=False, transform=test_transform) 121 | else: 122 | test_dataset = TinyImagenet(base_path() + 'TINYIMG', 123 | train=False, download=True, transform=test_transform) 124 | 125 | train, memory, test = store_masked_loaders(train_dataset, test_dataset, memory_dataset, self) 126 | return train, memory, test 127 | 128 | def get_transform(self, args): 129 | imagenet_norm = [[0.4802, 0.4480, 0.3975], [0.2770, 0.2691, 0.2821]] 130 | if args.cl_default: 131 | transform = transforms.Compose( 132 | [transforms.ToPILImage(), 133 | transforms.RandomCrop(64, padding=4), 134 | transforms.RandomHorizontalFlip(), 135 | transforms.ToTensor(), 136 | transforms.Normalize(*imagenet_norm) 137 | ]) 138 | else: 139 | transform = transforms.Compose( 140 | [transforms.ToPILImage(), 141 | transforms.RandomResizedCrop(64, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 142 | transforms.RandomHorizontalFlip(), 143 | transforms.ToTensor(), 144 | transforms.Normalize(*imagenet_norm) 145 | ]) 146 | 147 | return transform 148 | 149 | def not_aug_dataloader(self, batch_size): 150 | imagenet_norm = [[0.4802, 0.4480, 0.3975], [0.2770, 0.2691, 0.2821]] 151 | transform = transforms.Compose([transforms.ToTensor(), 152 | transforms.Normalize(*imagenet_norm)]) 153 | 154 | train_dataset = TinyImagenet(base_path() + 'TINYIMG', 155 | train=True, download=True, transform=transform) 156 | train_loader = get_previous_train_loader(train_dataset, batch_size, self) 157 | 158 | return train_loader 159 | -------------------------------------------------------------------------------- /datasets/test/seq-cifar10.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/test/seq-cifar10.pt -------------------------------------------------------------------------------- /datasets/test/seq-cifar100.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/test/seq-cifar100.pt -------------------------------------------------------------------------------- /datasets/test/seq-domainnet.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/test/seq-domainnet.pt -------------------------------------------------------------------------------- /datasets/test/seq-tinyimg.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/test/seq-tinyimg.pt -------------------------------------------------------------------------------- /datasets/transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/transforms/__init__.py -------------------------------------------------------------------------------- /datasets/transforms/denormalization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 
3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | class DeNormalize(object): 8 | def __init__(self, mean, std): 9 | self.mean = mean 10 | self.std = std 11 | 12 | def __call__(self, tensor): 13 | """ 14 | Args: 15 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 16 | Returns: 17 | Tensor: Normalized image. 18 | """ 19 | for t, m, s in zip(tensor, self.mean, self.std): 20 | t.mul_(s).add_(m) 21 | return tensor 22 | -------------------------------------------------------------------------------- /datasets/transforms/permutation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | 8 | 9 | class Permutation(object): 10 | """ 11 | Defines a fixed permutation for a numpy array. 12 | """ 13 | def __init__(self) -> None: 14 | """ 15 | Initializes the permutation. 16 | """ 17 | self.perm = None 18 | 19 | def __call__(self, sample: np.ndarray) -> np.ndarray: 20 | """ 21 | Randomly defines the permutation and applies the transformation. 22 | :param sample: image to be permuted 23 | :return: permuted image 24 | """ 25 | old_shape = sample.shape 26 | if self.perm is None: 27 | self.perm = np.random.permutation(len(sample.flatten())) 28 | 29 | return sample.flatten()[self.perm].reshape(old_shape) 30 | 31 | 32 | class FixedPermutation(object): 33 | """ 34 | Defines a fixed permutation (given the seed) for a numpy array. 35 | """ 36 | def __init__(self, seed: int) -> None: 37 | """ 38 | Defines the seed. 39 | :param seed: seed of the permutation 40 | """ 41 | self.perm = None 42 | self.seed = seed 43 | 44 | def __call__(self, sample: np.ndarray) -> np.ndarray: 45 | """ 46 | Defines the permutation and applies the transformation. 47 | :param sample: image to be permuted 48 | :return: permuted image 49 | """ 50 | old_shape = sample.shape 51 | if self.perm is None: 52 | np.random.seed(self.seed) 53 | self.perm = np.random.permutation(len(sample.flatten())) 54 | 55 | return sample.flatten()[self.perm].reshape(old_shape) 56 | -------------------------------------------------------------------------------- /datasets/transforms/rotation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torchvision.transforms.functional as F 8 | 9 | 10 | class Rotation(object): 11 | """ 12 | Defines a fixed rotation for a numpy array. 13 | """ 14 | 15 | def __init__(self, deg_min: int = 0, deg_max: int = 180) -> None: 16 | """ 17 | Initializes the rotation with a random angle. 18 | :param deg_min: lower extreme of the possible random angle 19 | :param deg_max: upper extreme of the possible random angle 20 | """ 21 | self.deg_min = deg_min 22 | self.deg_max = deg_max 23 | self.degrees = np.random.uniform(self.deg_min, self.deg_max) 24 | 25 | def __call__(self, x: np.ndarray) -> np.ndarray: 26 | """ 27 | Applies the rotation. 
28 | :param x: image to be rotated 29 | :return: rotated image 30 | """ 31 | return F.rotate(x, self.degrees) 32 | 33 | 34 | class FixedRotation(object): 35 | """ 36 | Defines a fixed rotation for a numpy array. 37 | """ 38 | 39 | def __init__(self, seed: int, deg_min: int = 0, deg_max: int = 180) -> None: 40 | """ 41 | Initializes the rotation with a random angle. 42 | :param seed: seed of the rotation 43 | :param deg_min: lower extreme of the possible random angle 44 | :param deg_max: upper extreme of the possible random angle 45 | """ 46 | self.seed = seed 47 | self.deg_min = deg_min 48 | self.deg_max = deg_max 49 | 50 | np.random.seed(seed) 51 | self.degrees = np.random.uniform(self.deg_min, self.deg_max) 52 | 53 | def __call__(self, x: np.ndarray) -> np.ndarray: 54 | """ 55 | Applies the rotation. 56 | :param x: image to be rotated 57 | :return: rotated image 58 | """ 59 | return F.rotate(x, self.degrees) 60 | 61 | 62 | class IncrementalRotation(object): 63 | """ 64 | Defines an incremental rotation for a numpy array. 65 | """ 66 | 67 | def __init__(self, init_deg: int = 0, increase_per_iteration: float = 0.006) -> None: 68 | """ 69 | Defines the initial angle as well as the increase for each rotation 70 | :param init_deg: 71 | :param increase_per_iteration: 72 | """ 73 | self.increase_per_iteration = increase_per_iteration 74 | self.iteration = 0 75 | self.degrees = init_deg 76 | 77 | def __call__(self, x: np.ndarray) -> np.ndarray: 78 | """ 79 | Applies the rotation. 80 | :param x: image to be rotated 81 | :return: rotated image 82 | """ 83 | degs = (self.iteration * self.increase_per_iteration + self.degrees) % 360 84 | self.iteration += 1 85 | return F.rotate(x, degs) 86 | 87 | def set_iteration(self, x: int) -> None: 88 | """ 89 | Set the iteration to a given integer 90 | :param x: iteration index 91 | """ 92 | self.iteration = x 93 | -------------------------------------------------------------------------------- /datasets/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/datasets/utils/__init__.py -------------------------------------------------------------------------------- /datasets/utils/continual_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from abc import abstractmethod 7 | from argparse import Namespace 8 | from torch import nn as nn 9 | from torchvision.transforms import transforms 10 | from torch.utils.data import DataLoader 11 | from typing import Tuple 12 | from torchvision import datasets 13 | import numpy as np 14 | 15 | 16 | class ContinualDataset: 17 | """ 18 | Continual learning evaluation setting. 19 | """ 20 | NAME = None 21 | SETTING = None 22 | N_CLASSES_PER_TASK = None 23 | N_TASKS = None 24 | TRANSFORM = None 25 | 26 | def __init__(self, args: Namespace) -> None: 27 | """ 28 | Initializes the train and test lists of dataloaders. 
29 | :param args: the arguments which contain the hyperparameters 30 | """ 31 | self.train_loader = None 32 | self.test_loaders = [] 33 | self.memory_loaders = [] 34 | self.train_loaders = [] 35 | self.i = 0 36 | self.args = args 37 | 38 | @abstractmethod 39 | def get_data_loaders(self) -> Tuple[DataLoader, DataLoader]: 40 | """ 41 | Creates and returns the training and test loaders for the current task. 42 | The current training loader and all test loaders are stored in self. 43 | :return: the current training and test loaders 44 | """ 45 | pass 46 | 47 | @abstractmethod 48 | def not_aug_dataloader(self, batch_size: int) -> DataLoader: 49 | """ 50 | Returns the dataloader of the current task, 51 | not applying data augmentation. 52 | :param batch_size: the batch size of the loader 53 | :return: the current training loader 54 | """ 55 | pass 56 | 57 | @staticmethod 58 | @abstractmethod 59 | def get_backbone() -> nn.Module: 60 | """ 61 | Returns the backbone to be used for the current dataset. 62 | """ 63 | pass 64 | 65 | @staticmethod 66 | @abstractmethod 67 | def get_transform() -> transforms: 68 | """ 69 | Returns the transform to be used for the current dataset. 70 | """ 71 | pass 72 | 73 | @staticmethod 74 | @abstractmethod 75 | def get_loss() -> nn.functional: 76 | """ 77 | Returns the loss to be used for the current dataset. 78 | """ 79 | pass 80 | 81 | @staticmethod 82 | @abstractmethod 83 | def get_normalization_transform() -> transforms: 84 | """ 85 | Returns the transform used for normalizing the current dataset. 86 | """ 87 | pass 88 | 89 | @staticmethod 90 | @abstractmethod 91 | def get_denormalization_transform() -> transforms: 92 | """ 93 | Returns the transform used for denormalizing the current dataset. 94 | """ 95 | pass 96 | 97 | 98 | def store_masked_loaders(train_dataset: datasets, test_dataset: datasets, memory_dataset: datasets, 99 | setting: ContinualDataset) -> Tuple[DataLoader, DataLoader]: 100 | """ 101 | Divides the dataset into tasks.
102 | :param train_dataset: train dataset 103 | :param test_dataset: test dataset 104 | :param setting: continual learning setting 105 | :return: train and test loaders 106 | """ 107 | train_mask = np.logical_and(np.array(train_dataset.targets) >= setting.i, 108 | np.array(train_dataset.targets) < setting.i + setting.N_CLASSES_PER_TASK) 109 | test_mask = np.logical_and(np.array(test_dataset.targets) >= setting.i, 110 | np.array(test_dataset.targets) < setting.i + setting.N_CLASSES_PER_TASK) 111 | 112 | train_dataset.data = train_dataset.data[train_mask] 113 | test_dataset.data = test_dataset.data[test_mask] 114 | 115 | train_dataset.targets = np.array(train_dataset.targets)[train_mask] 116 | test_dataset.targets = np.array(test_dataset.targets)[test_mask] 117 | 118 | memory_dataset.data = memory_dataset.data[train_mask] 119 | memory_dataset.targets = np.array(memory_dataset.targets)[train_mask] 120 | 121 | train_loader = DataLoader(train_dataset, 122 | batch_size=setting.args.train.batch_size, shuffle=True, num_workers=4, pin_memory=True) 123 | test_loader = DataLoader(test_dataset, 124 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4, pin_memory=True) 125 | memory_loader = DataLoader(memory_dataset, 126 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4) 127 | 128 | setting.test_loaders.append(test_loader) 129 | setting.train_loaders.append(train_loader) 130 | setting.memory_loaders.append(memory_loader) 131 | setting.train_loader = train_loader 132 | 133 | setting.i += setting.N_CLASSES_PER_TASK 134 | return train_loader, memory_loader, test_loader 135 | 136 | 137 | def store_masked_label_loaders(train_dataset: datasets, test_dataset: datasets, memory_dataset: datasets, 138 | setting: ContinualDataset) -> Tuple[DataLoader, DataLoader]: 139 | """ 140 | Divides the dataset into tasks. 
141 | :param train_dataset: train dataset 142 | :param test_dataset: test dataset 143 | :param setting: continual learning setting 144 | :return: train and test loaders 145 | """ 146 | train_mask = np.logical_and(np.array(train_dataset.labels) >= setting.i, 147 | np.array(train_dataset.labels) < setting.i + setting.N_CLASSES_PER_TASK) 148 | test_mask = np.logical_and(np.array(test_dataset.labels) >= setting.i, 149 | np.array(test_dataset.labels) < setting.i + setting.N_CLASSES_PER_TASK) 150 | 151 | train_dataset.data = train_dataset.data[train_mask] 152 | test_dataset.data = test_dataset.data[test_mask] 153 | 154 | train_dataset.targets = np.array(train_dataset.labels)[train_mask] 155 | test_dataset.targets = np.array(test_dataset.labels)[test_mask] 156 | 157 | memory_dataset.data = memory_dataset.data[train_mask] 158 | memory_dataset.targets = np.array(memory_dataset.labels)[train_mask] 159 | 160 | train_loader = DataLoader(train_dataset, 161 | batch_size=setting.args.train.batch_size, shuffle=True, num_workers=4) 162 | test_loader = DataLoader(test_dataset, 163 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4) 164 | memory_loader = DataLoader(memory_dataset, 165 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4) 166 | 167 | setting.test_loaders.append(test_loader) 168 | setting.train_loaders.append(train_loader) 169 | setting.memory_loaders.append(memory_loader) 170 | setting.train_loader = train_loader 171 | 172 | setting.i += setting.N_CLASSES_PER_TASK 173 | return train_loader, memory_loader, test_loader 174 | 175 | def store_domain_loaders(train_dataset: datasets, test_dataset: datasets, memory_dataset: datasets, 176 | setting: ContinualDataset) -> Tuple[DataLoader, DataLoader]: 177 | """ 178 | Divides the dataset into tasks. 179 | :param train_dataset: train dataset 180 | :param test_dataset: test dataset 181 | :param setting: continual learning setting 182 | :return: train and test loaders 183 | """ 184 | train_loader = DataLoader(train_dataset, 185 | batch_size=setting.args.train.batch_size, shuffle=True, num_workers=4, pin_memory=True) 186 | test_loader = DataLoader(test_dataset, 187 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4, pin_memory=True) 188 | memory_loader = DataLoader(memory_dataset, 189 | batch_size=setting.args.train.batch_size, shuffle=False, num_workers=4) 190 | 191 | setting.test_loaders.append(test_loader) 192 | setting.train_loaders.append(train_loader) 193 | setting.memory_loaders.append(memory_loader) 194 | setting.train_loader = train_loader 195 | 196 | # setting.i += setting.N_CLASSES_PER_TASK 197 | return train_loader, memory_loader, test_loader 198 | 199 | 200 | 201 | 202 | def get_previous_train_loader(train_dataset: datasets, batch_size: int, 203 | setting: ContinualDataset) -> DataLoader: 204 | """ 205 | Creates a dataloader for the previous task. 
206 | :param train_dataset: the entire training set 207 | :param batch_size: the desired batch size 208 | :param setting: the continual dataset at hand 209 | :return: a dataloader 210 | """ 211 | train_mask = np.logical_and(np.array(train_dataset.targets) >= 212 | setting.i - setting.N_CLASSES_PER_TASK, np.array(train_dataset.targets) 213 | < setting.i - setting.N_CLASSES_PER_TASK + setting.N_CLASSES_PER_TASK) 214 | 215 | train_dataset.data = train_dataset.data[train_mask] 216 | train_dataset.targets = np.array(train_dataset.targets)[train_mask] 217 | 218 | return DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 219 | -------------------------------------------------------------------------------- /datasets/utils/validation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from PIL import Image 8 | import numpy as np 9 | import os 10 | from utils import create_if_not_exists 11 | import torchvision.transforms.transforms as transforms 12 | from torchvision import datasets 13 | 14 | 15 | class ValidationDataset(torch.utils.data.Dataset): 16 | def __init__(self, data: torch.Tensor, targets: np.ndarray, 17 | transform: transforms=None, target_transform: transforms=None) -> None: 18 | self.data = data 19 | self.targets = targets 20 | self.transform = transform 21 | self.target_transform = target_transform 22 | 23 | def __len__(self): 24 | return self.data.shape[0] 25 | 26 | def __getitem__(self, index): 27 | img, target = self.data[index], self.targets[index] 28 | 29 | # doing this so that it is consistent with all other datasets 30 | # to return a PIL Image 31 | if isinstance(img, np.ndarray): 32 | if np.max(img) < 2: 33 | img = Image.fromarray(np.uint8(img * 255)) 34 | else: 35 | img = Image.fromarray(img) 36 | else: 37 | img = Image.fromarray(img.numpy()) 38 | 39 | if self.transform is not None: 40 | img = self.transform(img) 41 | 42 | if self.target_transform is not None: 43 | target = self.target_transform(target) 44 | 45 | return img, target 46 | 47 | def get_train_val(train: datasets, test_transform: transforms, 48 | dataset: str, val_perc: float=0.1): 49 | """ 50 | Extract val_perc% of the training set as the validation set. 
51 | :param train: training dataset 52 | :param test_transform: transformation of the test dataset 53 | :param dataset: dataset name 54 | :param val_perc: percentage of the training set to be extracted 55 | :return: the training set and the validation set 56 | """ 57 | dataset_length = train.data.shape[0] 58 | directory = 'datasets/val_permutations/' 59 | create_if_not_exists(directory) 60 | file_name = dataset + '.pt' 61 | if os.path.exists(directory + file_name): 62 | perm = torch.load(directory + file_name) 63 | else: 64 | perm = torch.randperm(dataset_length) 65 | torch.save(perm, directory + file_name) 66 | train.data = train.data[perm] 67 | train.targets = np.array(train.targets)[perm] 68 | test_dataset = ValidationDataset(train.data[:int(val_perc * dataset_length)], 69 | train.targets[:int(val_perc * dataset_length)], 70 | transform=test_transform) 71 | train.data = train.data[int(val_perc * dataset_length):] 72 | train.targets = train.targets[int(val_perc * dataset_length):] 73 | 74 | return train, test_dataset 75 | -------------------------------------------------------------------------------- /linear_eval_alltasks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation script 3 | 4 | Originated from https://github.com/divyam3897/UCL/blob/main/linear_eval_alltasks.py 5 | 6 | Hacked together by / Copyright 2023 Divyam Madaan (https://github.com/divyam3897) 7 | """ 8 | import os 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torchvision 13 | from tqdm import tqdm 14 | from arguments import get_args 15 | from augmentations import get_aug 16 | from models import get_model, get_backbone 17 | from tools import AverageMeter, knn_monitor 18 | from datasets import get_dataset 19 | from models.optimizers import get_optimizer, LR_Scheduler 20 | from utils.loggers import * 21 | from utils.metrics import forgetting 22 | 23 | 24 | def evaluate_single(model, dataset, test_loader, memory_loader, device, k, last=False) -> Tuple[list, list, list, list]: 25 | accs, accs_mask_classes = [], [] 26 | knn_accs, knn_accs_mask_classes = [], [] 27 | correct = correct_mask_classes = total = 0 28 | knn_acc, knn_acc_mask = knn_monitor(model.net.module.backbone, dataset, memory_loader, test_loader, device, args.cl_default, task_id=k, k=min(args.train.knn_k, len(dataset.memory_loaders[k].dataset))) 29 | 30 | return knn_acc 31 | 32 | 33 | def evaluate(model, dataset, device, classifier=None, last=False) -> Tuple[list, list]: 34 | """ 35 | Evaluates the accuracy of the model for each past task. 
36 | :param model: the model to be evaluated 37 | :param dataset: the continual dataset at hand 38 | :return: a tuple of lists, containing the class-il 39 | and task-il accuracy for each task 40 | """ 41 | status = model.training 42 | model.eval() 43 | accs, accs_mask_classes = [], [] 44 | for k, test_loader in enumerate(dataset.test_loaders): 45 | if last and k < len(dataset.test_loaders) - 1: 46 | continue 47 | correct, correct_mask_classes, total = 0.0, 0.0, 0.0 48 | for data in test_loader: 49 | inputs, labels = data 50 | inputs, labels = inputs.to(device), labels.to(device) 51 | outputs = model(inputs) 52 | if classifier is not None: 53 | outputs = classifier(outputs) 54 | 55 | _, pred = torch.max(outputs.data, 1) 56 | correct += torch.sum(pred == labels).item() 57 | total += labels.shape[0] 58 | 59 | if dataset.SETTING == 'class-il': 60 | mask_classes(outputs, dataset, k) 61 | _, pred = torch.max(outputs.data, 1) 62 | correct_mask_classes += torch.sum(pred == labels).item() 63 | 64 | accs.append(correct / total * 100) 65 | accs_mask_classes.append(correct_mask_classes / total * 100) 66 | 67 | model.train(status) 68 | return accs, accs_mask_classes 69 | 70 | 71 | def main(device, args): 72 | 73 | dataset = get_dataset(args) 74 | 75 | results, results_mask_classes = [], [] 76 | for t in tqdm(range(0, dataset.N_TASKS), desc='Evaluating'): 77 | train_loader, memory_loader, test_loader = dataset.get_data_loaders(args) 78 | model_path = os.path.join(args.ckpt_dir, f"{args.model.cl_model}_{args.name}_{t}.pth") 79 | save_dict = torch.load(model_path, map_location='cpu') 80 | mean_norm = [[0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]] 81 | model = get_model(args, device, len(train_loader), get_aug(train=False, train_classifier=False, mean_std=mean_norm), task_id=t) 82 | 83 | msg = model.net.module.backbone.load_state_dict({k[16:]:v for k, v in save_dict['state_dict'].items() if 'backbone.'
in k}, strict=True) 84 | model = model.to(args.device) 85 | 86 | accs = evaluate(model.net.module.backbone, dataset, device) 87 | results.append(accs[0]) 88 | results_mask_classes.append(accs[1]) 89 | mean_acc = np.mean(accs, axis=1) 90 | print_mean_accuracy(mean_acc, t + 1, dataset.SETTING) 91 | 92 | ci_mean_fgt = forgetting(results) 93 | ti_mean_fgt = forgetting(results_mask_classes) 94 | print(f'CI Forgetting: {ci_mean_fgt} \t TI Forgetting: {ti_mean_fgt}') 95 | 96 | 97 | if __name__ == "__main__": 98 | args = get_args() 99 | main(device=args.device, args=args) 100 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script 3 | 4 | Originated from https://github.com/divyam3897/UCL/blob/main/main.py 5 | 6 | Hacked together by / Copyright 2023 Divyam Madaan (https://github.com/divyam3897) 7 | """ 8 | import os 9 | import copy 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torchvision 14 | import numpy as np 15 | from tqdm import tqdm 16 | from arguments import get_args 17 | from augmentations import get_aug 18 | from models import get_model 19 | from tools import AverageMeter, Logger, file_exist_check 20 | from datasets import get_dataset 21 | from datetime import datetime 22 | from utils.loggers import * 23 | from utils.metrics import mask_classes 24 | from utils.loggers import CsvLogger 25 | from datasets.utils.continual_dataset import ContinualDataset 26 | from models.utils.continual_model import ContinualModel 27 | from utils.tb_logger import TensorboardLogger 28 | from typing import Tuple 29 | from datasets import BACKBONES 30 | import wandb 31 | from pytorch_model_summary import summary 32 | 33 | 34 | def evaluate(model: ContinualModel, dataset: ContinualDataset, device, classifier=None, last=False) -> Tuple[list, list]: 35 | """ 36 | Evaluates the accuracy of the model for each past task. 
37 | :param model: the model to be evaluated 38 | :param dataset: the continual dataset at hand 39 | :return: a tuple of lists, containing the class-il 40 | and task-il accuracy for each task 41 | """ 42 | status = model.training 43 | model.eval() 44 | accs, accs_mask_classes = [], [] 45 | for k, test_loader in enumerate(dataset.test_loaders): 46 | if last and k < len(dataset.test_loaders) - 1: 47 | continue 48 | correct, correct_mask_classes, total = 0.0, 0.0, 0.0 49 | for data in test_loader: 50 | inputs, labels = data 51 | inputs, labels = inputs.to(device), labels.to(device) 52 | outputs = model(inputs) 53 | if classifier is not None: 54 | outputs = classifier(outputs) 55 | 56 | _, pred = torch.max(outputs.data, 1) 57 | correct += torch.sum(pred == labels).item() 58 | total += labels.shape[0] 59 | 60 | if dataset.SETTING == 'class-il': 61 | mask_classes(outputs, dataset, k) 62 | _, pred = torch.max(outputs.data, 1) 63 | correct_mask_classes += torch.sum(pred == labels).item() 64 | 65 | accs.append(correct / total * 100) 66 | accs_mask_classes.append(correct_mask_classes / total * 100) 67 | 68 | model.train(status) 69 | return accs, accs_mask_classes 70 | 71 | 72 | def main(device, args): 73 | 74 | dataset = get_dataset(args) 75 | dataset_copy = get_dataset(args) 76 | train_loader, memory_loader, test_loader = dataset_copy.get_data_loaders(args) 77 | wandb.init(project="poc_lwf", sync_tensorboard=True) 78 | wandb.run.name = f"{args.model.cl_model}_{args.dataset.name}_n_alpha_{args.alpha}" 79 | 80 | # define model 81 | global_model = get_model(args, device, dataset_copy, dataset.get_transform(args), global_model=None) 82 | model = get_model(args, device, dataset_copy, dataset.get_transform(args), global_model=global_model) 83 | 84 | logger = Logger(matplotlib=args.logger.matplotlib, log_dir=args.log_dir) 85 | tb_logger = TensorboardLogger(args, dataset.SETTING) 86 | csv_logger = CsvLogger(dataset.SETTING, dataset.NAME, args.model.backbone) 87 | accuracy = 0 88 | results, results_mask_classes = [], [] 89 | 90 | for t in range(dataset.N_TASKS): 91 | train_loader, memory_loader, test_loader = dataset.get_data_loaders(args) 92 | 93 | global_progress = tqdm(range(0, args.train.stop_at_epoch), desc=f'Training') 94 | prev_mean_acc = 0. 95 | best_epoch = 0. 
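# ---------------------------------------------------------------------------
# [Editor's sketch] In the heterogeneous (HCL) setting, BACKBONES maps each
# dataset name to one backbone identifier per task, and the check below
# re-instantiates the model only when the current task's entry differs from
# the previous task's. The values shown here are illustrative; the real table
# lives in datasets/__init__.py:
#
#     BACKBONES = {
#         'seq-cifar10': ['resnet18', 'resnet18', 'lenet', 'vgg16', 'alexnet'],
#     }
#     # task t swaps architectures iff BACKBONES[name][t] != BACKBONES[name][t - 1]
# ---------------------------------------------------------------------------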
96 | 97 | if args.hcl and BACKBONES[args.dataset.name][t] != BACKBONES[args.dataset.name][t - 1]: 98 | model = get_model(args, device, dataset_copy, dataset.get_transform(args), task_id=t, global_model=global_model) 99 | print(summary(model.net.module.backbone, torch.zeros((1, 3, args.dataset.image_size, args.dataset.image_size)).to(device), show_input=True)) 100 | 101 | if hasattr(model, 'begin_task'): 102 | model.begin_task(t, dataset) 103 | 104 | if t: 105 | accs = evaluate(model, dataset, device, last=True) 106 | results[t-1] = results[t-1] + accs[0] 107 | results_mask_classes[t-1] = results_mask_classes[t-1] + accs[1] 108 | 109 | for epoch in global_progress: 110 | model.train() 111 | 112 | local_progress=tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}', disable=args.hide_progress) 113 | for idx, data in enumerate(local_progress): 114 | (images1, images2, notaug_images), labels = data 115 | data_dict = model.observe(images1, labels, images2, notaug_images, t) 116 | 117 | logger.update_scalers(data_dict) 118 | tb_logger.log_loss(data_dict['loss'], args, epoch, t, idx) 119 | tb_logger.log_penalty(data_dict['penalty'], args, epoch, t, idx) 120 | tb_logger.log_lr(data_dict['lr'], args, epoch, t, idx) 121 | 122 | global_progress.set_postfix(data_dict) 123 | 124 | accs = evaluate(model.net.module.backbone, dataset, device) 125 | mean_acc = np.mean(accs, axis=1) 126 | 127 | epoch_dict = {"epoch":epoch, "accuracy": mean_acc} 128 | global_progress.set_postfix(epoch_dict) 129 | logger.update_scalers(epoch_dict) 130 | tb_logger.log_accuracy(accs, mean_acc, args, t) 131 | 132 | if (sum(mean_acc)/2.) - prev_mean_acc < -0.2: 133 | continue 134 | if args.cl_default: 135 | best_model = copy.deepcopy(model.net.module.backbone) 136 | else: 137 | best_model = copy.deepcopy(model.net.module) 138 | prev_mean_acc = sum(mean_acc)/2. 
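# ---------------------------------------------------------------------------
# [Editor's note] mean_acc holds the class-IL and task-IL mean accuracies, so
# sum(mean_acc)/2. is the model-selection score. An epoch's snapshot is skipped
# (the `continue` above) only when this score falls more than 0.2 below the
# last accepted score; otherwise best_model, prev_mean_acc and best_epoch are
# all refreshed here.
# ---------------------------------------------------------------------------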
139 | best_epoch = epoch 140 | 141 | accs = evaluate(best_model, dataset, device) 142 | results.append(accs[0]) 143 | results_mask_classes.append(accs[1]) 144 | mean_acc = np.mean(accs, axis=1) 145 | print_mean_accuracy(mean_acc, t + 1, dataset.SETTING) 146 | 147 | if args.cl_default: 148 | model.global_model.net.module.backbone = copy.deepcopy(best_model) 149 | else: 150 | model.global_model.net.module = copy.deepcopy(best_model) 151 | print(f"Updated global model at epoch {best_epoch} with accuracy {prev_mean_acc}.") 152 | 153 | model_path = os.path.join(args.ckpt_dir, f"{args.model.cl_model}_{args.name}_{t}.pth") 154 | torch.save({ 155 | 'epoch': best_epoch+1, 156 | 'state_dict': model.global_model.net.state_dict(), 157 | }, model_path) 158 | print(f"Task Model saved to {model_path}") 159 | with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f: 160 | f.write(f'{model_path}') 161 | 162 | if hasattr(model, 'end_task'): 163 | model.end_task(dataset) 164 | 165 | csv_logger.add_bwt(results, results_mask_classes) 166 | csv_logger.add_forgetting(results, results_mask_classes) 167 | csv_logger.write(args.ckpt_dir, vars(args)) 168 | tb_logger.close() 169 | if args.eval is not False and args.cl_default is False: 170 | args.eval_from = model_path 171 | 172 | if __name__ == "__main__": 173 | args = get_args() 174 | main(device=args.device, args=args) 175 | completed_log_dir = args.log_dir.replace('in-progress', 'debug' if args.debug else 'completed') 176 | os.rename(args.log_dir, completed_log_dir) 177 | print(f'Log file has been saved to {completed_log_dir}') 178 | 179 | 180 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | from .simsiam import SimSiam 4 | import torch 5 | from .backbones import resnet18, lenet, vgg16, alexnet, densenet, senet, regnet, inception, swin, resnext 6 | from datasets import N_CLASSES, BACKBONES 7 | from utils.losses import LabelSmoothing, KL_div_Loss 8 | 9 | 10 | def get_backbone(args, task_id=0): 11 | if args.hcl: 12 | backbone = BACKBONES[args.dataset.name][task_id] 13 | else: 14 | backbone = args.model.backbone 15 | 16 | net = eval(f"{backbone}(num_classes=N_CLASSES[args.dataset.name], args=args)") 17 | print("Backbone changed to ", backbone) 18 | 19 | net.n_classes = N_CLASSES[args.dataset.name] 20 | net.output_dim = net.fc.in_features 21 | if not args.cl_default: 22 | net.fc = torch.nn.Identity() 23 | 24 | return net 25 | 26 | 27 | def get_all_models(): 28 | return [model.split('.')[0] for model in os.listdir('models') 29 | if not model.find('__') > -1 and 'py' in model] 30 | 31 | def get_model(args, device, dataset, transform, global_model=None, task_id=0): 32 | allowed_models = ["distil", "qdi", "distilbuf"] 33 | if args.model.cl_model in allowed_models: 34 | loss = LabelSmoothing(smoothing=0.1) 35 | else: 36 | loss = torch.nn.CrossEntropyLoss() 37 | if args.model.name == 'simsiam': 38 | backbone = SimSiam(get_backbone(args, task_id=task_id)).to(device) 39 | if args.model.proj_layers is not None: 40 | backbone.projector.set_layers(args.model.proj_layers) 41 | 42 | names = {} 43 | for model in get_all_models(): 44 | mod = importlib.import_module('models.' 
+ model) 45 | class_name = {x.lower():x for x in mod.__dir__()}[model.replace('_', '')] 46 | names[model] = getattr(mod, class_name) 47 | 48 | return names[args.model.cl_model](backbone, loss, args, dataset, transform, global_model) 49 | 50 | -------------------------------------------------------------------------------- /models/backbones/Alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class AlexNet(nn.Module): 4 | def __init__(self, num_classes, args): 5 | super(AlexNet, self).__init__() 6 | self.features = nn.Sequential( 7 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 8 | nn.ReLU(inplace=True), 9 | nn.MaxPool2d(kernel_size=2), 10 | nn.Conv2d(64, 192, kernel_size=3, padding=1), 11 | nn.ReLU(inplace=True), 12 | nn.MaxPool2d(kernel_size=2), 13 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 18 | nn.ReLU(inplace=True), 19 | nn.MaxPool2d(kernel_size=2), 20 | ) 21 | self.classifier = nn.Sequential( 22 | nn.Dropout(), 23 | nn.Linear(256 * 2 * 2, 4096), 24 | nn.ReLU(inplace=True), 25 | nn.Dropout(), 26 | nn.Linear(4096, 4096), 27 | nn.ReLU(inplace=True), 28 | ) 29 | self.fc = nn.Linear(4096, num_classes) 30 | 31 | def forward(self, x, return_features=False): 32 | x = self.features(x) 33 | x = x.view(x.size(0), 256 * 2 * 2) 34 | x = self.classifier(x) 35 | if return_features: 36 | return x 37 | x = self.fc(x) 38 | return x 39 | 40 | 41 | def alexnet(num_classes, args): 42 | return AlexNet(num_classes, args) 43 | -------------------------------------------------------------------------------- /models/backbones/Densenet.py: -------------------------------------------------------------------------------- 1 | # Originated from from https://github.com/kuangliu/pytorch-cifar/blob/master/models/densenet.py 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, in_planes, growth_rate): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 16 | 17 | def forward(self, x): 18 | out = self.conv1(F.relu(self.bn1(x))) 19 | out = self.conv2(F.relu(self.bn2(out))) 20 | out = torch.cat([out,x], 1) 21 | return out 22 | 23 | 24 | class Transition(nn.Module): 25 | def __init__(self, in_planes, out_planes): 26 | super(Transition, self).__init__() 27 | self.bn = nn.BatchNorm2d(in_planes) 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 29 | 30 | def forward(self, x): 31 | out = self.conv(F.relu(self.bn(x))) 32 | out = F.avg_pool2d(out, 2) 33 | return out 34 | 35 | 36 | class DenseNet(nn.Module): 37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10, args=None): 38 | super(DenseNet, self).__init__() 39 | self.growth_rate = growth_rate 40 | 41 | num_planes = 2*growth_rate 42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 43 | 44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 45 | num_planes += nblocks[0]*growth_rate 46 | out_planes = int(math.floor(num_planes*reduction)) 47 | self.trans1 = 
Transition(num_planes, out_planes) 48 | num_planes = out_planes 49 | 50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 51 | num_planes += nblocks[1]*growth_rate 52 | out_planes = int(math.floor(num_planes*reduction)) 53 | self.trans2 = Transition(num_planes, out_planes) 54 | num_planes = out_planes 55 | 56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 57 | num_planes += nblocks[2]*growth_rate 58 | out_planes = int(math.floor(num_planes*reduction)) 59 | self.trans3 = Transition(num_planes, out_planes) 60 | num_planes = out_planes 61 | 62 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 63 | num_planes += nblocks[3]*growth_rate 64 | 65 | self.bn = nn.BatchNorm2d(num_planes) 66 | self.fc = nn.Linear(num_planes, num_classes) 67 | 68 | def _make_dense_layers(self, block, in_planes, nblock): 69 | layers = [] 70 | for i in range(nblock): 71 | layers.append(block(in_planes, self.growth_rate)) 72 | in_planes += self.growth_rate 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x, return_features=False): 76 | out = self.conv1(x) 77 | out = self.trans1(self.dense1(out)) 78 | out = self.trans2(self.dense2(out)) 79 | out = self.trans3(self.dense3(out)) 80 | out = self.dense4(out) 81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 82 | out = out.view(out.size(0), -1) 83 | if return_features: 84 | return out 85 | out = self.fc(out) 86 | return out 87 | 88 | def DenseNet121(num_classes, args): 89 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32, num_classes=num_classes, args=args) 90 | 91 | def DenseNet169(num_classes, args): 92 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32, num_classes=num_classes, args=args) 93 | 94 | def DenseNet201(num_classes, args): 95 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32, num_classes=num_classes, args=args) 96 | 97 | def DenseNet161(num_classes, args): 98 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48, num_classes=num_classes, args=args) 99 | 100 | def densenet_cifar(num_classes, args): 101 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12, num_classes=num_classes, args=args) 102 | 103 | def densenet(num_classes, args): 104 | return DenseNet121(num_classes, args) 105 | -------------------------------------------------------------------------------- /models/backbones/Inception.py: -------------------------------------------------------------------------------- 1 | # Originated from https://github.com/kuangliu/pytorch-cifar/blob/master/models/googlenet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class Inception(nn.Module): 8 | def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes): 9 | super(Inception, self).__init__() 10 | # 1x1 conv branch 11 | self.b1 = nn.Sequential( 12 | nn.Conv2d(in_planes, kernel_1_x, kernel_size=1), 13 | nn.BatchNorm2d(kernel_1_x), 14 | nn.ReLU(True), 15 | ) 16 | 17 | # 1x1 conv -> 3x3 conv branch 18 | self.b2 = nn.Sequential( 19 | nn.Conv2d(in_planes, kernel_3_in, kernel_size=1), 20 | nn.BatchNorm2d(kernel_3_in), 21 | nn.ReLU(True), 22 | nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(kernel_3_x), 24 | nn.ReLU(True), 25 | ) 26 | 27 | # 1x1 conv -> 5x5 conv branch 28 | self.b3 = nn.Sequential( 29 | nn.Conv2d(in_planes, kernel_5_in, kernel_size=1), 30 | nn.BatchNorm2d(kernel_5_in), 31 | nn.ReLU(True), 32 | nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(kernel_5_x), 34 | 
nn.ReLU(True), 35 | nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(kernel_5_x), 37 | nn.ReLU(True), 38 | ) 39 | 40 | # 3x3 pool -> 1x1 conv branch 41 | self.b4 = nn.Sequential( 42 | nn.MaxPool2d(3, stride=1, padding=1), 43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 44 | nn.BatchNorm2d(pool_planes), 45 | nn.ReLU(True), 46 | ) 47 | 48 | def forward(self, x): 49 | y1 = self.b1(x) 50 | y2 = self.b2(x) 51 | y3 = self.b3(x) 52 | y4 = self.b4(x) 53 | return torch.cat([y1,y2,y3,y4], 1) 54 | 55 | 56 | class GoogleNet(nn.Module): 57 | def __init__(self, num_classes, args): 58 | super(GoogleNet, self).__init__() 59 | self.pre_layers = nn.Sequential( 60 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(192), 62 | nn.ReLU(True), 63 | ) 64 | 65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 67 | 68 | self.max_pool = nn.MaxPool2d(3, stride=2, padding=1) 69 | 70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 75 | 76 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 78 | 79 | self.avgpool = nn.AvgPool2d(8, stride=1) 80 | self.fc = nn.Linear(1024, num_classes) 81 | 82 | def forward(self, x, return_features=False): 83 | x = self.pre_layers(x) 84 | x = self.a3(x) 85 | x = self.b3(x) 86 | x = self.max_pool(x) 87 | x = self.a4(x) 88 | x = self.b4(x) 89 | x = self.c4(x) 90 | x = self.d4(x) 91 | x = self.e4(x) 92 | x = self.max_pool(x) 93 | x = self.a5(x) 94 | x = self.b5(x) 95 | x = self.avgpool(x) 96 | x = x.view(x.size(0), -1) 97 | if return_features: 98 | return x 99 | x = self.fc(x) 100 | return x 101 | 102 | def inception(num_classes, args): 103 | return GoogleNet(num_classes, args) 104 | -------------------------------------------------------------------------------- /models/backbones/Lenet.py: -------------------------------------------------------------------------------- 1 | ## Originated from https://github.com/kuangliu/pytorch-cifar/blob/master/models/lenet.py 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as func 5 | 6 | 7 | class LeNet(nn.Module): 8 | def __init__(self, num_classes, args): 9 | super(LeNet, self).__init__() 10 | self.conv1 = nn.Conv2d(3, 6, kernel_size=5) 11 | self.conv2 = nn.Conv2d(6, 16, kernel_size=5) 12 | if args.dataset.image_size == 32: 13 | self.fc1 = nn.Linear(16*5*5, 120) 14 | else: 15 | self.fc1 = nn.Linear(2704, 120) 16 | self.fc2 = nn.Linear(120, 84) 17 | self.fc = nn.Linear(84, num_classes) 18 | 19 | def forward(self, x, return_features=False): 20 | x = func.relu(self.conv1(x)) 21 | x = func.max_pool2d(x, 2) 22 | x = func.relu(self.conv2(x)) 23 | x = func.max_pool2d(x, 2) 24 | x = x.view(x.size(0), -1) 25 | x = func.relu(self.fc1(x)) 26 | x = func.relu(self.fc2(x)) 27 | if return_features: 28 | return x 29 | x = self.fc(x) 30 | return x 31 | 32 | def lenet(num_classes, args): 33 | return LeNet(num_classes, args=args) 34 | -------------------------------------------------------------------------------- /models/backbones/Regnet.py: -------------------------------------------------------------------------------- 1 | '''RegNet in PyTorch. 2 | 3 | Paper: "Designing Network Design Spaces". 
4 | 5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py 6 | 7 | Originated from https://github.com/kuangliu/pytorch-cifar/blob/master/models/regnet.py 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SE(nn.Module): 15 | '''Squeeze-and-Excitation block.''' 16 | 17 | def __init__(self, in_planes, se_planes): 18 | super(SE, self).__init__() 19 | self.se1 = nn.Conv2d(in_planes, se_planes, kernel_size=1, bias=True) 20 | self.se2 = nn.Conv2d(se_planes, in_planes, kernel_size=1, bias=True) 21 | 22 | def forward(self, x): 23 | out = F.adaptive_avg_pool2d(x, (1, 1)) 24 | out = F.relu(self.se1(out)) 25 | out = self.se2(out).sigmoid() 26 | out = x * out 27 | return out 28 | 29 | 30 | class Block(nn.Module): 31 | def __init__(self, w_in, w_out, stride, group_width, bottleneck_ratio, se_ratio): 32 | super(Block, self).__init__() 33 | # 1x1 34 | w_b = int(round(w_out * bottleneck_ratio)) 35 | self.conv1 = nn.Conv2d(w_in, w_b, kernel_size=1, bias=False) 36 | self.bn1 = nn.BatchNorm2d(w_b) 37 | # 3x3 38 | num_groups = w_b // group_width 39 | self.conv2 = nn.Conv2d(w_b, w_b, kernel_size=3, 40 | stride=stride, padding=1, groups=num_groups, bias=False) 41 | self.bn2 = nn.BatchNorm2d(w_b) 42 | # se 43 | self.with_se = se_ratio > 0 44 | if self.with_se: 45 | w_se = int(round(w_in * se_ratio)) 46 | self.se = SE(w_b, w_se) 47 | # 1x1 48 | self.conv3 = nn.Conv2d(w_b, w_out, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(w_out) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or w_in != w_out: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(w_in, w_out, 55 | kernel_size=1, stride=stride, bias=False), 56 | nn.BatchNorm2d(w_out) 57 | ) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = F.relu(self.bn2(self.conv2(out))) 62 | if self.with_se: 63 | out = self.se(out) 64 | out = self.bn3(self.conv3(out)) 65 | out += self.shortcut(x) 66 | out = F.relu(out) 67 | return out 68 | 69 | 70 | class RegNet(nn.Module): 71 | def __init__(self, cfg, num_classes=10): 72 | super(RegNet, self).__init__() 73 | self.cfg = cfg 74 | self.in_planes = 64 75 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 76 | stride=1, padding=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(64) 78 | self.layer1 = self._make_layer(0) 79 | self.layer2 = self._make_layer(1) 80 | self.layer3 = self._make_layer(2) 81 | self.layer4 = self._make_layer(3) 82 | self.fc = nn.Linear(self.cfg['widths'][-1], num_classes) 83 | 84 | def _make_layer(self, idx): 85 | depth = self.cfg['depths'][idx] 86 | width = self.cfg['widths'][idx] 87 | stride = self.cfg['strides'][idx] 88 | group_width = self.cfg['group_width'] 89 | bottleneck_ratio = self.cfg['bottleneck_ratio'] 90 | se_ratio = self.cfg['se_ratio'] 91 | 92 | layers = [] 93 | for i in range(depth): 94 | s = stride if i == 0 else 1 95 | layers.append(Block(self.in_planes, width, 96 | s, group_width, bottleneck_ratio, se_ratio)) 97 | self.in_planes = width 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x, return_features=False): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = F.adaptive_avg_pool2d(out, (1, 1)) 107 | out = out.view(out.size(0), -1) 108 | if return_features: 109 | return out 110 | out = self.fc(out) 111 | return out 112 | 113 | 114 | def RegNetX_200MF(num_classes, args): 115 | cfg = { 116 | 
'depths': [1, 1, 4, 7], 117 | 'widths': [24, 56, 152, 368], 118 | 'strides': [1, 1, 2, 2], 119 | 'group_width': 8, 120 | 'bottleneck_ratio': 1, 121 | 'se_ratio': 0, 122 | } 123 | 124 | return RegNet(cfg, num_classes) 125 | 126 | 127 | def RegNetX_400MF(num_classes, args): 128 | cfg = { 129 | 'depths': [1, 2, 7, 12], 130 | 'widths': [32, 64, 160, 384], 131 | 'strides': [1, 1, 2, 2], 132 | 'group_width': 16, 133 | 'bottleneck_ratio': 1, 134 | 'se_ratio': 0, 135 | } 136 | return RegNet(cfg, num_classes) 137 | 138 | 139 | def regnet(num_classes, args): 140 | return RegNetX_200MF(num_classes, args) 141 | 142 | -------------------------------------------------------------------------------- /models/backbones/ResNext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 4 | 5 | Originated from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnext.py 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class Block(nn.Module): 13 | '''Grouped convolution block.''' 14 | expansion = 2 15 | 16 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 17 | super(Block, self).__init__() 18 | group_width = cardinality * bottleneck_width 19 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(group_width) 21 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 22 | self.bn2 = nn.BatchNorm2d(group_width) 23 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 24 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*group_width: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(self.expansion*group_width) 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = F.relu(self.bn2(self.conv2(out))) 36 | out = self.bn3(self.conv3(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class ResNeXt(nn.Module): 43 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes, args): 44 | super(ResNeXt, self).__init__() 45 | self.cardinality = cardinality 46 | self.bottleneck_width = bottleneck_width 47 | self.in_planes = 64 48 | layer1_stride = 2 if args.dataset.image_size == 64 else 1 49 | 50 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(64) 52 | self.layer1 = self._make_layer(num_blocks[0], layer1_stride) 53 | self.layer2 = self._make_layer(num_blocks[1], 2) 54 | self.layer3 = self._make_layer(num_blocks[2], 2) 55 | # self.layer4 = self._make_layer(num_blocks[3], 2) 56 | self.fc = nn.Linear(cardinality*bottleneck_width*8, num_classes) 57 | 58 | def _make_layer(self, num_blocks, stride): 59 | strides = [stride] + [1]*(num_blocks-1) 60 | layers = [] 61 | for stride in strides: 62 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 63 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 64 | # Increase bottleneck_width by 2 after each stage. 
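# [Editor's note] Concretely, for the default ResNeXt29_2x64d configuration
# (cardinality=2, bottleneck_width=64) the grouped-conv width per stage is
# cardinality * bottleneck_width = 128, 256, 512, and each Block doubles it
# again on output (expansion = 2), which is why the final linear layer expects
# cardinality * bottleneck_width * 8 input features.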
65 | self.bottleneck_width *= 2 66 | return nn.Sequential(*layers) 67 | 68 | def forward(self, x, return_features=False): 69 | out = F.relu(self.bn1(self.conv1(x))) 70 | out = self.layer1(out) 71 | out = self.layer2(out) 72 | out = self.layer3(out) 73 | # out = self.layer4(out) 74 | out = F.avg_pool2d(out, 8) 75 | out = out.view(out.size(0), -1) 76 | if return_features: 77 | return out 78 | out = self.fc(out) 79 | return out 80 | 81 | 82 | def ResNeXt29_2x64d(num_classes, args): 83 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64, num_classes=num_classes, args=args) 84 | 85 | def ResNeXt29_4x64d(num_classes, args): 86 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64, num_classes=num_classes, args=args) 87 | 88 | def ResNeXt29_8x64d(num_classes, args): 89 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64, num_classes=num_classes, args=args) 90 | 91 | def ResNeXt29_32x4d(num_classes, args): 92 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4, num_classes=num_classes, args=args) 93 | 94 | def resnext(num_classes, args): 95 | return ResNeXt29_2x64d(num_classes, args=args) 96 | -------------------------------------------------------------------------------- /models/backbones/Senet.py: -------------------------------------------------------------------------------- 1 | '''SENet in PyTorch. 2 | 3 | SENet ("Squeeze-and-Excitation Networks", Hu et al., arXiv:1709.01507) won the ILSVRC 2017 classification task. 4 | 5 | Originated from https://github.com/kuangliu/pytorch-cifar/blob/master/models/senet.py 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | def __init__(self, in_planes, planes, stride=1): 14 | super(BasicBlock, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride != 1 or in_planes != planes: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(planes) 25 | ) 26 | 27 | # SE layers 28 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear 29 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | 35 | # Squeeze 36 | w = F.avg_pool2d(out, out.size(2)) 37 | w = F.relu(self.fc1(w)) 38 | w = F.sigmoid(self.fc2(w)) 39 | # Excitation 40 | out = out * w # channel-wise rescaling (w broadcasts over H and W)
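# [Editor's note] Squeeze-and-Excitation recap: avg_pool2d squeezes each
# channel map to a single scalar, the fc1/fc2 1x1 convolutions form a
# bottleneck (reduction factor 16) whose sigmoid output w holds per-channel
# gates in (0, 1), and the rescaled features then enter the residual addition
# below.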
41 | 42 | out += self.shortcut(x) 43 | out = F.relu(out) 44 | return out 45 | 46 | 47 | class PreActBlock(nn.Module): 48 | def __init__(self, in_planes, planes, stride=1): 49 | super(PreActBlock, self).__init__() 50 | self.bn1 = nn.BatchNorm2d(in_planes) 51 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 54 | 55 | if stride != 1 or in_planes != planes: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 58 | ) 59 | 60 | # SE layers 61 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 62 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(x)) 66 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 67 | out = self.conv1(out) 68 | out = self.conv2(F.relu(self.bn2(out))) 69 | 70 | # Squeeze 71 | w = F.avg_pool2d(out, out.size(2)) 72 | w = F.relu(self.fc1(w)) 73 | w = F.sigmoid(self.fc2(w)) 74 | # Excitation 75 | out = out * w 76 | 77 | out += shortcut 78 | return out 79 | 80 | 81 | class SENet(nn.Module): 82 | def __init__(self, block, num_blocks, num_classes, args): 83 | super(SENet, self).__init__() 84 | self.in_planes = 64 85 | layer1_stride = 2 if args.dataset.image_size == 64 else 1 86 | 87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 88 | self.bn1 = nn.BatchNorm2d(64) 89 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=layer1_stride) 90 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 91 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 92 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 93 | self.fc = nn.Linear(512, num_classes) 94 | 95 | def _make_layer(self, block, planes, num_blocks, stride): 96 | strides = [stride] + [1]*(num_blocks-1) 97 | layers = [] 98 | for stride in strides: 99 | layers.append(block(self.in_planes, planes, stride)) 100 | self.in_planes = planes 101 | return nn.Sequential(*layers) 102 | 103 | def forward(self, x, return_features=False): 104 | out = F.relu(self.bn1(self.conv1(x))) 105 | out = self.layer1(out) 106 | out = self.layer2(out) 107 | out = self.layer3(out) 108 | out = self.layer4(out) 109 | out = F.avg_pool2d(out, 4) 110 | out = out.view(out.size(0), -1) 111 | if return_features: 112 | return out 113 | out = self.fc(out) 114 | return out 115 | 116 | 117 | def senet18(num_classes, args): 118 | return SENet(PreActBlock, [2,2,2,2], num_classes, args) 119 | -------------------------------------------------------------------------------- /models/backbones/Swin.py: -------------------------------------------------------------------------------- 1 | # https://github.com/berniwal/swin-transformer-pytorch 2 | 3 | import torch 4 | from torch import nn, einsum 5 | import numpy as np 6 | from einops import rearrange, repeat 7 | 8 | 9 | class CyclicShift(nn.Module): 10 | def __init__(self, displacement): 11 | super().__init__() 12 | self.displacement = displacement 13 | 14 | def forward(self, x): 15 | return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2)) 16 | 17 | 18 | class Residual(nn.Module): 19 | def __init__(self, fn): 20 | super().__init__() 21 | self.fn = fn 22 | 23 | def forward(self, x, **kwargs): 24 | return self.fn(x, **kwargs) + x 25 | 26 | 27 | class PreNorm(nn.Module): 28 | def __init__(self, 
dim, fn): 29 | super().__init__() 30 | self.norm = nn.LayerNorm(dim) 31 | self.fn = fn 32 | 33 | def forward(self, x, **kwargs): 34 | return self.fn(self.norm(x), **kwargs) 35 | 36 | 37 | class FeedForward(nn.Module): 38 | def __init__(self, dim, hidden_dim): 39 | super().__init__() 40 | self.net = nn.Sequential( 41 | nn.Linear(dim, hidden_dim), 42 | nn.GELU(), 43 | nn.Linear(hidden_dim, dim), 44 | ) 45 | 46 | def forward(self, x): 47 | return self.net(x) 48 | 49 | 50 | def create_mask(window_size, displacement, upper_lower, left_right): 51 | mask = torch.zeros(window_size ** 2, window_size ** 2) 52 | 53 | if upper_lower: 54 | mask[-displacement * window_size:, :-displacement * window_size] = float('-inf') 55 | mask[:-displacement * window_size, -displacement * window_size:] = float('-inf') 56 | 57 | if left_right: 58 | mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size) 59 | mask[:, -displacement:, :, :-displacement] = float('-inf') 60 | mask[:, :-displacement, :, -displacement:] = float('-inf') 61 | mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)') 62 | 63 | return mask 64 | 65 | 66 | def get_relative_distances(window_size): 67 | indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)])) 68 | distances = indices[None, :, :] - indices[:, None, :] 69 | return distances 70 | 71 | 72 | class WindowAttention(nn.Module): 73 | def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding): 74 | super().__init__() 75 | inner_dim = head_dim * heads 76 | 77 | self.heads = heads 78 | self.scale = head_dim ** -0.5 79 | self.window_size = window_size 80 | self.relative_pos_embedding = relative_pos_embedding 81 | self.shifted = shifted 82 | 83 | if self.shifted: 84 | displacement = window_size // 2 85 | self.cyclic_shift = CyclicShift(-displacement) 86 | self.cyclic_back_shift = CyclicShift(displacement) 87 | self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement, 88 | upper_lower=True, left_right=False), requires_grad=False) 89 | self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement, 90 | upper_lower=False, left_right=True), requires_grad=False) 91 | 92 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 93 | 94 | if self.relative_pos_embedding: 95 | self.relative_indices = get_relative_distances(window_size) + window_size - 1 96 | self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1)) 97 | else: 98 | self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2)) 99 | 100 | self.to_out = nn.Linear(inner_dim, dim) 101 | 102 | def forward(self, x): 103 | if self.shifted: 104 | x = self.cyclic_shift(x) 105 | 106 | b, n_h, n_w, _, h = *x.shape, self.heads 107 | 108 | qkv = self.to_qkv(x).chunk(3, dim=-1) 109 | nw_h = n_h // self.window_size 110 | nw_w = n_w // self.window_size 111 | 112 | q, k, v = map( 113 | lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d', 114 | h=h, w_h=self.window_size, w_w=self.window_size), qkv) 115 | 116 | dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale 117 | 118 | if self.relative_pos_embedding: 119 | dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]] 120 | else: 121 | dots += self.pos_embedding 122 | 123 | if self.shifted: 124 | dots[:, :, -nw_w:] += self.upper_lower_mask 125 | dots[:, :, nw_w - 1::nw_w] += self.left_right_mask 
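        # The additive masks above put -inf on logits between positions that are
        # only adjacent because of the cyclic shift, so the softmax below gives
        # them zero weight and windows never attend across the image boundary.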
126 | 127 | attn = dots.softmax(dim=-1) 128 | 129 | out = einsum('b h w i j, b h w j d -> b h w i d', attn, v) 130 | out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)', 131 | h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w) 132 | out = self.to_out(out) 133 | 134 | if self.shifted: 135 | out = self.cyclic_back_shift(out) 136 | return out 137 | 138 | 139 | class SwinBlock(nn.Module): 140 | def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding): 141 | super().__init__() 142 | self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim, 143 | heads=heads, 144 | head_dim=head_dim, 145 | shifted=shifted, 146 | window_size=window_size, 147 | relative_pos_embedding=relative_pos_embedding))) 148 | self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim))) 149 | 150 | def forward(self, x): 151 | x = self.attention_block(x) 152 | x = self.mlp_block(x) 153 | return x 154 | 155 | 156 | class PatchMerging(nn.Module): 157 | def __init__(self, in_channels, out_channels, downscaling_factor): 158 | super().__init__() 159 | self.downscaling_factor = downscaling_factor 160 | self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0) 161 | self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels) 162 | 163 | def forward(self, x): 164 | b, c, h, w = x.shape 165 | new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor 166 | x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1) 167 | x = self.linear(x) 168 | return x 169 | 170 | 171 | class StageModule(nn.Module): 172 | def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size, 173 | relative_pos_embedding): 174 | super().__init__() 175 | assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.' 
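        # Blocks are appended in pairs below: one regular-window block followed by
        # one shifted-window block, which is why the assert above requires an even
        # number of layers per stage.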
176 | 177 | self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension, 178 | downscaling_factor=downscaling_factor) 179 | 180 | self.layers = nn.ModuleList([]) 181 | for _ in range(layers // 2): 182 | self.layers.append(nn.ModuleList([ 183 | SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4, 184 | shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding), 185 | SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4, 186 | shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding), 187 | ])) 188 | 189 | def forward(self, x): 190 | x = self.patch_partition(x) 191 | for regular_block, shifted_block in self.layers: 192 | x = regular_block(x) 193 | x = shifted_block(x) 194 | return x.permute(0, 3, 1, 2) 195 | 196 | 197 | class SwinTransformer(nn.Module): 198 | def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=10, head_dim=32, window_size=4, 199 | downscaling_factors=(2, 2, 2, 1), relative_pos_embedding=True, args=None): 200 | super().__init__() 201 | 202 | self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim, layers=layers[0], 203 | downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim, 204 | window_size=window_size, relative_pos_embedding=relative_pos_embedding) 205 | self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim * 2, layers=layers[1], 206 | downscaling_factor=downscaling_factors[1], num_heads=heads[1], head_dim=head_dim, 207 | window_size=window_size, relative_pos_embedding=relative_pos_embedding) 208 | self.stage3 = StageModule(in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4, layers=layers[2], 209 | downscaling_factor=downscaling_factors[2], num_heads=heads[2], head_dim=head_dim, 210 | window_size=window_size, relative_pos_embedding=relative_pos_embedding) 211 | self.stage4 = StageModule(in_channels=hidden_dim * 4, hidden_dimension=hidden_dim * 8, layers=layers[3], 212 | downscaling_factor=downscaling_factors[3], num_heads=heads[3], head_dim=head_dim, 213 | window_size=window_size, relative_pos_embedding=relative_pos_embedding) 214 | 215 | self.final_layer_norm = nn.LayerNorm(hidden_dim * 8) 216 | self.fc = nn.Linear(hidden_dim * 8, num_classes) 217 | 218 | def forward(self, img, return_features=False): 219 | x = self.stage1(img) 220 | x = self.stage2(x) 221 | x = self.stage3(x) 222 | x = self.stage4(x) 223 | x = x.mean(dim=[2, 3]) 224 | x = self.final_layer_norm(x) 225 | if return_features: 226 | return x 227 | return self.fc(x) 228 | 229 | 230 | def swin_t(num_classes, args, hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs): 231 | return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, num_classes=num_classes, args=args, **kwargs) 232 | 233 | 234 | def swin_s(num_classes, hidden_dim=96, layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs): 235 | return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, num_classes=num_classes, **kwargs) 236 | 237 | 238 | def swin_b(num_classes, hidden_dim=128, layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs): 239 | return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, num_classes=num_classes, **kwargs) 240 | 241 | 242 | def swin_l(num_classes, hidden_dim=192, layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs): 243 | return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, num_classes=num_classes, **kwargs) 244 | 
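# A minimal shape check (an illustrative sketch, not part of the original file):
# with the default window_size=4 and downscaling_factors=(2, 2, 2, 1), a 32x32
# CIFAR-sized input reaches stage 4 as a 4x4 token grid, so swin_t is
# shape-consistent end to end.
if __name__ == "__main__":
    model = swin_t(num_classes=10, args=None)
    x = torch.randn(2, 3, 32, 32)
    print(model(x).shape)                        # torch.Size([2, 10])
    print(model(x, return_features=True).shape)  # torch.Size([2, 768]), i.e. hidden_dim * 8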
-------------------------------------------------------------------------------- /models/backbones/Vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://github.com/pytorch/vision.git 3 | ''' 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | class VGG(nn.Module): 16 | ''' 17 | VGG model 18 | ''' 19 | def __init__(self, features, num_classes, args): 20 | super(VGG, self).__init__() 21 | self.features = features 22 | self.classifier = nn.Sequential( 23 | nn.Dropout(), 24 | nn.Linear(512, 512), 25 | nn.ReLU(True), 26 | nn.Dropout(), 27 | nn.Linear(512, 512), 28 | nn.ReLU(True), 29 | ) 30 | self.fc = nn.Linear(512, num_classes) 31 | # Initialize weights 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 35 | m.weight.data.normal_(0, math.sqrt(2. / n)) 36 | m.bias.data.zero_() 37 | 38 | 39 | def forward(self, x, return_features=False): 40 | x = self.features(x) 41 | x = x.view(x.size(0), -1) 42 | x = self.classifier(x) 43 | if return_features: 44 | return x 45 | x = self.fc(x) 46 | return x 47 | 48 | 49 | def make_layers(cfg, batch_norm=False): 50 | layers = [] 51 | in_channels = 3 52 | for v in cfg: 53 | if v == 'M': 54 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 55 | else: 56 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 57 | if batch_norm: 58 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 59 | else: 60 | layers += [conv2d, nn.ReLU(inplace=True)] 61 | in_channels = v 62 | return nn.Sequential(*layers) 63 | 64 | 65 | cfg = { 66 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 67 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 68 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 69 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 70 | 512, 512, 512, 512, 'M'], 71 | } 72 | 73 | 74 | def vgg11(num_classes, args): 75 | """VGG 11-layer model (configuration "A")""" 76 | return VGG(make_layers(cfg['A']), num_classes, args) 77 | 78 | 79 | def vgg11_bn(num_classes, args): 80 | """VGG 11-layer model (configuration "A") with batch normalization""" 81 | return VGG(make_layers(cfg['A'], batch_norm=True), num_classes, args) 82 | 83 | 84 | def vgg13(num_classes, args): 85 | """VGG 13-layer model (configuration "B")""" 86 | return VGG(make_layers(cfg['B']), num_classes, args) 87 | 88 | 89 | def vgg13_bn(num_classes, args): 90 | """VGG 13-layer model (configuration "B") with batch normalization""" 91 | return VGG(make_layers(cfg['B'], batch_norm=True), num_classes, args) 92 | 93 | 94 | def vgg16(num_classes, args): 95 | """VGG 16-layer model (configuration "D")""" 96 | return VGG(make_layers(cfg['D']), num_classes, args) 97 | 98 | 99 | def vgg16_bn(num_classes, args): 100 | """VGG 16-layer model (configuration "D") with batch normalization""" 101 | return VGG(make_layers(cfg['D'], batch_norm=True), num_classes, args) 102 | 103 | 104 | def vgg19(num_classes, args): 105 | """VGG 19-layer model (configuration "E")""" 106 | return VGG(make_layers(cfg['E']), num_classes, args) 107 | 108 | 109 | def vgg19_bn(num_classes, args): 110 | """VGG 19-layer model (configuration 'E') with batch normalization""" 111 | return 
VGG(make_layers(cfg['E'], batch_norm=True), num_classes, args) 112 | -------------------------------------------------------------------------------- /models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .ResNet18 import resnet18 as resnet18 2 | from .Lenet import lenet as lenet 3 | from .Vgg import vgg16_bn as vgg16 4 | from .Alexnet import alexnet as alexnet 5 | from .Densenet import densenet as densenet 6 | from .Senet import senet18 as senet 7 | from .Regnet import regnet as regnet 8 | from .Inception import inception as inception 9 | from .Swin import swin_t as swin 10 | from .ResNext import resnext as resnext 11 | 12 | -------------------------------------------------------------------------------- /models/backbones/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/models/backbones/utils/__init__.py -------------------------------------------------------------------------------- /models/backbones/utils/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.parameter import Parameter 9 | 10 | 11 | class AlphaModule(nn.Module): 12 | def __init__(self, shape): 13 | super(AlphaModule, self).__init__() 14 | if not isinstance(shape, tuple): 15 | shape = (shape,) 16 | self.alpha = Parameter(torch.rand(tuple([1] + list(shape))) * 0.1, 17 | requires_grad=True) 18 | 19 | def forward(self, x): 20 | return x * self.alpha 21 | 22 | def parameters(self, recurse: bool = True): 23 | yield self.alpha 24 | 25 | 26 | class ListModule(nn.Module): 27 | def __init__(self, *args): 28 | super(ListModule, self).__init__() 29 | self.idx = 0 30 | for module in args: 31 | self.add_module(str(self.idx), module) 32 | self.idx += 1 33 | 34 | def append(self, module): 35 | self.add_module(str(self.idx), module) 36 | self.idx += 1 37 | 38 | def __getitem__(self, idx): 39 | if idx < 0: 40 | idx += self.idx 41 | if idx >= len(self._modules): 42 | raise IndexError('index {} is out of range'.format(idx)) 43 | it = iter(self._modules.values()) 44 | for i in range(idx): 45 | next(it) 46 | return next(it) 47 | 48 | def __iter__(self): 49 | return iter(self._modules.values()) 50 | 51 | def __len__(self): 52 | return len(self._modules) 53 | -------------------------------------------------------------------------------- /models/distil.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
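# Distil trains the backbone with cross-entropy on the current task and, from the
# second task onward (task_id > 0), adds a KL penalty weighted by args.train.alpha
# that keeps the student's logits close to those of the frozen global teacher.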
10 | 11 | from utils.buffer import Buffer 12 | from torch.nn import functional as F 13 | from models.utils.continual_model import ContinualModel 14 | from augmentations import get_aug 15 | import torch 16 | from utils.losses import LabelSmoothing, KL_div_Loss 17 | from datasets import get_dataset 18 | 19 | def smooth(logits, temp, dim): 20 | log = logits ** (1 / temp) 21 | return log / torch.sum(log, dim).unsqueeze(1) 22 | 23 | 24 | def modified_kl_div(old, new): 25 | return -torch.mean(torch.sum(old * torch.log(new), 1)) 26 | 27 | 28 | 29 | class Distil(ContinualModel): 30 | NAME = 'distil' 31 | COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual'] 32 | 33 | def __init__(self, backbone, loss, args, len_train_loader, transform, global_model): 34 | super(Distil, self).__init__(backbone, loss, args, len_train_loader, transform) 35 | self.global_model = global_model 36 | self.buffer = Buffer(self.args.model.buffer_size, self.device) 37 | self.global_model = global_model 38 | self.criterion_kl = KL_div_Loss(temperature=1.0).cuda() 39 | self.soft = torch.nn.Softmax(dim=1) 40 | 41 | def observe(self, inputs1, labels, inputs2, notaug_inputs, task_id): 42 | 43 | self.opt.zero_grad() 44 | inputs1, labels = inputs1.to(self.device), labels.to(self.device) 45 | inputs2 = inputs2.to(self.device) 46 | notaug_inputs = notaug_inputs.to(self.device) 47 | real_batch_size = inputs1.shape[0] 48 | 49 | if task_id: 50 | self.global_model.eval() 51 | outputs = self.net.module.backbone(inputs1) 52 | with torch.no_grad(): 53 | outputs_teacher = self.global_model.net.module.backbone(inputs1) 54 | 55 | penalty = self.args.train.alpha * self.criterion_kl(outputs, outputs_teacher) 56 | loss = self.loss(outputs, labels) + penalty 57 | else: 58 | outputs = self.net.module.backbone(inputs1) 59 | loss = self.loss(outputs, labels) 60 | 61 | if task_id: 62 | data_dict = {'loss': loss, 'penalty': penalty} 63 | else: 64 | data_dict = {'loss': loss, 'penalty': 0.} 65 | 66 | loss.backward() 67 | self.opt.step() 68 | data_dict.update({'lr': self.args.train.base_lr}) 69 | 70 | return data_dict 71 | -------------------------------------------------------------------------------- /models/distilbuf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
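# DistilBuf uses the same teacher-distillation objective as Distil, plus rehearsal:
# each step replays a batch from the global model's buffer with a 0.3-weighted
# cross-entropy term, then stores the current un-augmented inputs in the buffer.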
10 | 11 | from utils.buffer import Buffer 12 | from torch.nn import functional as F 13 | from models.utils.continual_model import ContinualModel 14 | from augmentations import get_aug 15 | import torch 16 | from utils.losses import LabelSmoothing, KL_div_Loss 17 | from datasets import get_dataset 18 | 19 | def smooth(logits, temp, dim): 20 | log = logits ** (1 / temp) 21 | return log / torch.sum(log, dim).unsqueeze(1) 22 | 23 | 24 | def modified_kl_div(old, new): 25 | return -torch.mean(torch.sum(old * torch.log(new), 1)) 26 | 27 | 28 | 29 | class DistilBuf(ContinualModel): 30 | NAME = 'distilbuf' 31 | COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual'] 32 | 33 | def __init__(self, backbone, loss, args, len_train_loader, transform, global_model): 34 | super(DistilBuf, self).__init__(backbone, loss, args, len_train_loader, transform) 35 | self.global_model = global_model 36 | self.buffer = Buffer(self.args.model.buffer_size, self.device) 37 | self.global_model = global_model 38 | self.criterion_kl = KL_div_Loss(temperature=1.0).cuda() 39 | self.soft = torch.nn.Softmax(dim=1) 40 | 41 | def observe(self, inputs1, labels, inputs2, notaug_inputs, task_id): 42 | 43 | self.opt.zero_grad() 44 | inputs1, labels = inputs1.to(self.device), labels.to(self.device) 45 | inputs2 = inputs2.to(self.device) 46 | notaug_inputs = notaug_inputs.to(self.device) 47 | real_batch_size = inputs1.shape[0] 48 | 49 | if task_id: 50 | self.global_model.eval() 51 | outputs = self.net.module.backbone(inputs1) 52 | with torch.no_grad(): 53 | outputs_teacher = self.global_model.net.module.backbone(inputs1) 54 | 55 | penalty = self.args.train.alpha * self.criterion_kl(outputs, outputs_teacher) 56 | loss = self.loss(outputs, labels) + penalty 57 | else: 58 | outputs = self.net.module.backbone(inputs1) 59 | loss = self.loss(outputs, labels) 60 | 61 | if not self.global_model.buffer.is_empty(): 62 | buf_inputs, buf_logits = self.global_model.buffer.get_data( 63 | self.args.train.batch_size, transform=self.transform) 64 | buf_outputs = self.net.module.backbone(buf_inputs) 65 | penalty = 0.3 * self.loss(buf_outputs, buf_logits.long()) 66 | loss += penalty 67 | 68 | if task_id: 69 | data_dict = {'loss': loss, 'penalty': penalty} 70 | else: 71 | data_dict = {'loss': loss, 'penalty': 0.} 72 | 73 | loss.backward() 74 | self.opt.step() 75 | data_dict.update({'lr': self.args.train.base_lr}) 76 | self.global_model.buffer.add_data(examples=notaug_inputs, labels=labels[:real_batch_size]) 77 | 78 | return data_dict 79 | -------------------------------------------------------------------------------- /models/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .lars import LARS 2 | import torch 3 | from .lr_scheduler import LR_Scheduler 4 | 5 | 6 | def get_optimizer(name, model, lr, momentum, weight_decay, cl_default): 7 | 8 | predictor_prefix = ('module.predictor', 'predictor') 9 | parameters = [{ 10 | 'name': 'base', 11 | 'params': [param for name, param in model.named_parameters() if not name.startswith(predictor_prefix)], 12 | 'lr': lr 13 | },{ 14 | 'name': 'predictor', 15 | 'params': [param for name, param in model.named_parameters() if name.startswith(predictor_prefix)], 16 | 'lr': lr 17 | }] 18 | if name == 'lars': 19 | optimizer = LARS(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay) 20 | elif name == 'sgd': 21 | if cl_default: 22 | optimizer = torch.optim.SGD(parameters, lr=lr) 23 | else: 24 | optimizer = torch.optim.SGD(parameters, 
lr=lr, momentum=momentum, weight_decay=weight_decay) 25 | elif name == 'adam': 26 | optimizer = torch.optim.Adam(parameters, lr=lr) 27 | else: 28 | raise NotImplementedError 29 | return optimizer 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /models/optimizers/lars.py: -------------------------------------------------------------------------------- 1 | """ Layer-wise adaptive rate scaling for SGD in PyTorch! """ 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class LARS(Optimizer): 6 | r"""Implements layer-wise adaptive rate scaling for SGD. 7 | 8 | Args: 9 | params (iterable): iterable of parameters to optimize or dicts defining 10 | parameter groups 11 | lr (float): base learning rate (\gamma_0) 12 | momentum (float, optional): momentum factor (default: 0) ("m") 13 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 14 | ("\beta") 15 | eta (float, optional): LARS coefficient 16 | max_epoch: maximum training epoch to determine polynomial LR decay. 17 | 18 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 19 | Large Batch Training of Convolutional Networks: 20 | https://arxiv.org/abs/1708.03888 21 | 22 | Example: 23 | >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3) 24 | >>> optimizer.zero_grad() 25 | >>> loss_fn(model(input), target).backward() 26 | >>> optimizer.step() 27 | """ 28 | def __init__(self, params, lr=required, momentum=.9, 29 | weight_decay=.0005, eta=0.001, max_epoch=200): 30 | if lr is not required and lr < 0.0: 31 | raise ValueError("Invalid learning rate: {}".format(lr)) 32 | if momentum < 0.0: 33 | raise ValueError("Invalid momentum value: {}".format(momentum)) 34 | if weight_decay < 0.0: 35 | raise ValueError("Invalid weight_decay value: {}" 36 | .format(weight_decay)) 37 | if eta < 0.0: 38 | raise ValueError("Invalid LARS coefficient value: {}".format(eta)) 39 | 40 | self.epoch = 0 41 | defaults = dict(lr=lr, momentum=momentum, 42 | weight_decay=weight_decay, 43 | eta=eta, max_epoch=max_epoch) 44 | super(LARS, self).__init__(params, defaults) 45 | 46 | def step(self, epoch=None, closure=None): 47 | """Performs a single optimization step. 48 | 49 | Arguments: 50 | closure (callable, optional): A closure that reevaluates the model 51 | and returns the loss. 52 | epoch: current epoch to calculate polynomial LR decay schedule. 53 | if None, uses self.epoch and increments it. 
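        Update rule, as implemented below:
        local_lr = eta * ||w|| / (||g|| + weight_decay * ||w||),
        global_lr = lr * (1 - epoch / max_epoch) ** 2, and the momentum
        buffer accumulates (g + weight_decay * w) scaled by
        local_lr * global_lr before being subtracted from the weights.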
54 | """ 55 | loss = None 56 | if closure is not None: 57 | loss = closure() 58 | 59 | if epoch is None: 60 | epoch = self.epoch 61 | self.epoch += 1 62 | 63 | for group in self.param_groups: 64 | weight_decay = group['weight_decay'] 65 | momentum = group['momentum'] 66 | eta = group['eta'] 67 | lr = group['lr'] 68 | max_epoch = group['max_epoch'] 69 | 70 | for p in group['params']: 71 | if p.grad is None: 72 | continue 73 | 74 | param_state = self.state[p] 75 | d_p = p.grad.data 76 | 77 | weight_norm = torch.norm(p.data) 78 | grad_norm = torch.norm(d_p) 79 | 80 | # Global LR computed on polynomial decay schedule 81 | decay = (1 - float(epoch) / max_epoch) ** 2 82 | global_lr = lr * decay 83 | 84 | # Compute local learning rate for this layer 85 | local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm) 86 | 87 | # Update the momentum term 88 | actual_lr = local_lr * global_lr 89 | 90 | if 'momentum_buffer' not in param_state: 91 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 92 | else: 93 | buf = param_state['momentum_buffer'] 94 | buf.mul_(momentum).add_(d_p + weight_decay * p.data, alpha=actual_lr) 95 | p.data.add_(-buf) 96 | 97 | return loss -------------------------------------------------------------------------------- /models/optimizers/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class LR_Scheduler(object): 6 | def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False): 7 | self.base_lr = base_lr 8 | self.constant_predictor_lr = constant_predictor_lr 9 | warmup_iter = iter_per_epoch * warmup_epochs 10 | warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter) 11 | decay_iter = iter_per_epoch * (num_epochs - warmup_epochs) 12 | cosine_lr_schedule = final_lr+0.5*(base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter)) 13 | 14 | self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) 15 | self.optimizer = optimizer 16 | self.iter = 0 17 | self.current_lr = 0 18 | def step(self): 19 | for param_group in self.optimizer.param_groups: 20 | 21 | if self.constant_predictor_lr and param_group['name'] == 'predictor': 22 | param_group['lr'] = self.base_lr 23 | else: 24 | lr = param_group['lr'] = self.lr_schedule[self.iter] 25 | 26 | self.iter += 1 27 | self.current_lr = lr 28 | return lr 29 | 30 | def reset(self): 31 | self.iter = 0 32 | self.current_lr = 0 33 | 34 | def get_lr(self): 35 | return self.current_lr 36 | 37 | if __name__ == "__main__": 38 | import torchvision 39 | model = torchvision.models.resnet50() 40 | optimizer = torch.optim.SGD(model.parameters(), lr=999) 41 | epochs = 100 42 | n_iter = 1000 43 | scheduler = LR_Scheduler(optimizer, 10, 1, epochs, 3, 0, n_iter) 44 | import matplotlib.pyplot as plt 45 | lrs = [] 46 | for epoch in range(epochs): 47 | for it in range(n_iter): 48 | lr = scheduler.step() 49 | lrs.append(lr) 50 | plt.plot(lrs) 51 | plt.show() 52 | -------------------------------------------------------------------------------- /models/qdi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. 
Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | import torch 12 | import numpy as np 13 | import random 14 | from utils.buffer import Buffer 15 | from torch.nn import functional as F 16 | from models.utils.continual_model import ContinualModel 17 | from augmentations import get_aug 18 | from utils.deep_inversion import DeepInversionFeatureHook 19 | from utils.losses import LabelSmoothing, KL_div_Loss 20 | import torchvision.utils as vutils 21 | from datasets import get_dataset 22 | 23 | 24 | def lr_policy(lr_fn): 25 | def _alr(optimizer, epoch): 26 | lr = lr_fn(epoch) 27 | for param_group in optimizer.param_groups: 28 | param_group['lr'] = lr 29 | return _alr 30 | 31 | def lr_cosine_policy(base_lr, warmup_length, epochs): 32 | def _lr_fn(epoch): 33 | if epoch < warmup_length: 34 | lr = base_lr * (epoch + 1) / warmup_length 35 | else: 36 | e = epoch - warmup_length 37 | es = epochs - warmup_length 38 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 39 | print(lr) 40 | return lr 41 | return lr_policy(_lr_fn) 42 | 43 | 44 | class Qdi(ContinualModel): 45 | NAME = 'qdi' 46 | COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual'] 47 | 48 | def __init__(self, backbone, loss, args, len_train_loader, transform, global_model=None): 49 | super(Qdi, self).__init__(backbone, loss, args, len_train_loader, transform) 50 | self.num_classes = 10 51 | im_size = (32, 32) if args.dataset.name == "seq-cifar10" or args.dataset.name == "seq-cifar100" else (64, 64) 52 | images_per_class = 20 53 | self.buffer = Buffer(self.args.model.buffer_size, self.device) 54 | self.global_model = global_model 55 | self.criterion_kl = KL_div_Loss(temperature=1.0).cuda() 56 | self.lr_scheduler = lr_cosine_policy(args.train.di_lr, 100, args.train.di_itrs) 57 | self.args = args 58 | self.cpt = get_dataset(args).N_CLASSES_PER_TASK 59 | self.current_step = 0 60 | 61 | def begin_task(self, task_id, dataset=None): 62 | if task_id: 63 | self.sample_inputs = [] 64 | if dataset is not None: 65 | for i in range(0, dataset.train_loader.dataset.data.shape[0], self.args.train.batch_size): 66 | inputs = torch.stack([dataset.train_loader.dataset.__getitem__(j)[0][0] 67 | for j in range(i, min(i + self.args.train.batch_size, len(dataset.train_loader.dataset)))]) 68 | self.sample_inputs.append(inputs) 69 | 70 | self.sample_inputs = torch.cat(self.sample_inputs) 71 | 72 | rand_idx = torch.randperm(self.sample_inputs.shape[0]) 73 | sample_inputs = self.sample_inputs[rand_idx].to(self.device) 74 | sample_batch = sample_inputs[:self.args.model.buffer_size * 4].to(self.device) 75 | statistics = [] 76 | 77 | batchnorm_flag = [True if isinstance(module, torch.nn.BatchNorm2d) else False for module in self.global_model.net.module.backbone.modules()] 78 | 79 | if True in batchnorm_flag: 80 | for module in self.global_model.net.module.backbone.modules(): 81 | if isinstance(module, torch.nn.BatchNorm2d): 82 | statistics.append(DeepInversionFeatureHook(module)) 83 | 84 | for item in statistics: 85 | item.capture_bn_stats = False 86 | item.use_stored_stats = False 87 | else: 88 | for module in self.global_model.net.module.backbone.modules(): 89 | if isinstance(module, torch.nn.Conv2d): 90 | statistics.append(DeepInversionFeatureHook(module)) 91 | 92 | for item in statistics: 93 | item.capture_bn_stats = True 94 | item.use_stored_stats = True 95 | 96 | _ = 
self.global_model.net.module.backbone(sample_batch) 97 | print('Finished capturing post conv2d stats. Freezing the stats.') 98 | 99 | for item in statistics: 100 | item.capture_bn_stats = False 101 | item.use_stored_stats = True 102 | 103 | rand_idx = torch.randperm(self.sample_inputs.shape[0]) 104 | sample_inputs = self.sample_inputs[rand_idx].to(self.device) 105 | sample_batch = sample_inputs[:self.args.model.buffer_size].to(self.device) 106 | vutils.save_image(sample_batch.data.clone(), 107 | f'./di_images_{self.args.dataset.name}/sample_batch_{task_id}.png', 108 | normalize=True, scale_each=True, nrow=5) 109 | sample_batch_size, im_size = sample_batch.shape[0], sample_batch.shape[2] 110 | cls_per_task = task_id * self.cpt 111 | self.label_syn = torch.tensor([np.ones(sample_batch_size//cls_per_task) * i for i in range(cls_per_task)], dtype=torch.long, requires_grad=False, device=self.device).view(-1) 112 | rand_idx = torch.randperm(len(self.label_syn)) 113 | label_syn = self.label_syn[rand_idx] 114 | image_syn = torch.randn(size=(self.label_syn.shape[0], 3, im_size, im_size), dtype=torch.float, requires_grad=True, device=self.device) 115 | sample_batch = sample_batch[:self.label_syn.shape[0]] 116 | image_syn.data = sample_batch.data.clone() 117 | image_opt = torch.optim.Adam([image_syn], lr=self.args.train.di_lr, betas=[0.5, 0.9], eps = 1e-8) 118 | 119 | 120 | self.global_model.eval() 121 | self.net.eval() 122 | 123 | for step in range(self.args.train.di_itrs +1): 124 | self.lr_scheduler(image_opt, step) 125 | image_opt.zero_grad() 126 | self.global_model.zero_grad() 127 | outputs = self.global_model.net.module.backbone(image_syn) 128 | loss_ce = self.loss(outputs, label_syn.long()) 129 | 130 | diff1 = image_syn[:,:,:,:-1] - image_syn[:,:,:,1:] 131 | diff2 = image_syn[:,:,:-1,:] - image_syn[:,:,1:,:] 132 | diff3 = image_syn[:,:,1:,:-1] - image_syn[:,:,:-1,1:] 133 | diff4 = image_syn[:,:,:-1,:-1] - image_syn[:,:,1:,1:] 134 | loss_var = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4) 135 | 136 | loss_distr = self.args.train.di_feature * sum([mod.r_feature for mod in statistics]) 137 | loss_var = self.args.train.di_var * loss_var 138 | loss_l2 = self.args.train.di_l2 * torch.norm(image_syn, 2) 139 | loss = loss_ce + loss_distr + loss_l2 + loss_var 140 | 141 | if step % 5 == 0: 142 | print('\t step', step, '\t ce', loss_ce.item(), '\t r feature', loss_distr.item(), '\tr var', loss_var.item(), '\tr l2', loss_l2.item(), '\t total', loss.item()) 143 | 144 | loss.backward() 145 | image_opt.step() 146 | if step % 5 == 0: 147 | vutils.save_image(image_syn.data.clone(), 148 | f'./di_images_{self.args.dataset.name}/di_generated_{task_id}_{step//5}.png', 149 | normalize=True, scale_each=True, nrow=5) 150 | 151 | self.global_model.buffer.add_data(examples=image_syn, labels=label_syn) 152 | self.image_syn = image_syn.detach().clone() 153 | self.label_syn = label_syn.detach().clone() 154 | self.net.train() 155 | 156 | def observe(self, inputs1, labels, inputs2, notaug_inputs, task_id): 157 | inputs1, labels = inputs1.to(self.device), labels.to(self.device) 158 | real_batch_size = inputs1.shape[0] 159 | 160 | if task_id: 161 | outputs_clean = self.net.module.backbone(inputs1) 162 | 163 | outputs = self.net.module.backbone(self.image_syn) 164 | outputs_teacher = self.global_model.net.module.backbone(self.image_syn) 165 | outputs_teacher_clean = self.global_model.net.module.backbone(inputs1) 166 | 167 | penalty = self.criterion_kl(outputs_clean, outputs_teacher_clean) + 
self.criterion_kl(outputs, outputs_teacher) 168 | loss = self.loss(outputs_clean, labels) + self.args.train.alpha * penalty 169 | else: 170 | outputs = self.net.module.backbone(inputs1) 171 | loss = self.loss(outputs, labels) 172 | 173 | if task_id: 174 | data_dict = {'loss': loss, 'penalty': penalty} 175 | else: 176 | data_dict = {'loss': loss, 'penalty': 0.} 177 | 178 | self.opt.zero_grad() 179 | loss.backward() 180 | self.opt.step() 181 | data_dict.update({'lr': self.args.train.base_lr}) 182 | self.current_step += 1 183 | 184 | return data_dict 185 | -------------------------------------------------------------------------------- /models/simsiam.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/divyam3897/UCL/blob/main/models/simsiam.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision.models import resnet50 7 | 8 | 9 | def D(p, z, version='simplified'): # negative cosine similarity 10 | if version == 'original': 11 | z = z.detach() # stop gradient 12 | p = F.normalize(p, dim=1) # l2-normalize 13 | z = F.normalize(z, dim=1) # l2-normalize 14 | return -(p*z).sum(dim=1).mean() 15 | 16 | elif version == 'simplified': # same thing, much faster. Scroll down, speed test in __main__ 17 | return - F.cosine_similarity(p, z.detach(), dim=-1).mean() 18 | else: 19 | raise Exception 20 | 21 | 22 | 23 | class projection_MLP(nn.Module): 24 | def __init__(self, in_dim, hidden_dim=2048, out_dim=2048): 25 | super().__init__() 26 | ''' page 3 baseline setting 27 | Projection MLP. The projection MLP (in f) has BN ap- 28 | plied to each fully-connected (fc) layer, including its out- 29 | put fc. Its output fc has no ReLU. The hidden fc is 2048-d. 30 | This MLP has 3 layers. 31 | ''' 32 | self.layer1 = nn.Sequential( 33 | nn.Linear(in_dim, hidden_dim), 34 | nn.BatchNorm1d(hidden_dim), 35 | nn.ReLU(inplace=True) 36 | ) 37 | self.layer2 = nn.Sequential( 38 | nn.Linear(hidden_dim, hidden_dim), 39 | nn.BatchNorm1d(hidden_dim), 40 | nn.ReLU(inplace=True) 41 | ) 42 | self.layer3 = nn.Sequential( 43 | nn.Linear(hidden_dim, out_dim), 44 | nn.BatchNorm1d(out_dim) 45 | ) 46 | self.num_layers = 3 47 | def set_layers(self, num_layers): 48 | self.num_layers = num_layers 49 | 50 | def forward(self, x): 51 | if self.num_layers == 3: 52 | x = self.layer1(x) 53 | x = self.layer2(x) 54 | x = self.layer3(x) 55 | elif self.num_layers == 2: 56 | x = self.layer1(x) 57 | x = self.layer3(x) 58 | else: 59 | raise Exception 60 | return x 61 | 62 | 63 | class prediction_MLP(nn.Module): 64 | def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048): # bottleneck structure 65 | super().__init__() 66 | ''' page 3 baseline setting 67 | Prediction MLP. The prediction MLP (h) has BN applied 68 | to its hidden fc layers. Its output fc does not have BN 69 | (ablation in Sec. 4.4) or ReLU. This MLP has 2 layers. 70 | The dimension of h’s input and output (z and p) is d = 2048, 71 | and h’s hidden layer’s dimension is 512, making h a 72 | bottleneck structure (ablation in supplement). 73 | ''' 74 | self.layer1 = nn.Sequential( 75 | nn.Linear(in_dim, hidden_dim), 76 | nn.BatchNorm1d(hidden_dim), 77 | nn.ReLU(inplace=True) 78 | ) 79 | self.layer2 = nn.Linear(hidden_dim, out_dim) 80 | """ 81 | Adding BN to the output of the prediction MLP h does not work 82 | well (Table 3d). We find that this is not about collapsing. 83 | The training is unstable and the loss oscillates.
84 | """ 85 | 86 | def forward(self, x): 87 | x = self.layer1(x) 88 | x = self.layer2(x) 89 | return x 90 | 91 | class SimSiam(nn.Module): 92 | def __init__(self, backbone=resnet50()): 93 | super().__init__() 94 | 95 | self.backbone = backbone 96 | self.projector = projection_MLP(backbone.output_dim) 97 | 98 | self.encoder = nn.Sequential( # f encoder 99 | self.backbone, 100 | self.projector 101 | ) 102 | self.predictor = prediction_MLP() 103 | self.distil_predictor = prediction_MLP() 104 | 105 | def forward(self, x1, x2): 106 | 107 | f, h = self.encoder, self.predictor 108 | z1, z2 = f(x1), f(x2) 109 | p1, p2 = h(z1), h(z2) 110 | L = D(p1, z2) / 2 + D(p2, z1) / 2 111 | return {'loss': L, 'z1': z1, 'z2': z2} 112 | 113 | if __name__ == "__main__": 114 | model = SimSiam() 115 | model = torch.nn.DataParallel(model).cuda() 116 | x1 = torch.randn((128, 3, 32, 32)) 117 | x2 = torch.randn_like(x1) 118 | 119 | for i in range(50): 120 | model.forward(x1, x2).backward() 121 | print("forward backwork check") 122 | 123 | z1 = torch.randn((200, 2560)) 124 | z2 = torch.randn_like(z1) 125 | import time 126 | tic = time.time() 127 | print(D(z1, z2, version='original')) 128 | toc = time.time() 129 | print(toc - tic) 130 | tic = time.time() 131 | print(D(z1, z2, version='simplified')) 132 | toc = time.time() 133 | print(toc - tic) 134 | 135 | # Output: 136 | # tensor(-0.0010) 137 | # 0.005159854888916016 138 | # tensor(-0.0010) 139 | # 0.0014872550964355469 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/HCL/4aa98387995946061ade7b3eede3715e7c29c1ba/models/utils/__init__.py -------------------------------------------------------------------------------- /models/utils/continual_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn as nn 7 | from torch.optim import SGD 8 | import torch 9 | import torchvision 10 | from argparse import Namespace 11 | from utils.conf import get_device 12 | import numpy as np 13 | from ..optimizers import get_optimizer, LR_Scheduler 14 | 15 | 16 | class ContinualModel(nn.Module): 17 | """ 18 | Continual learning model. 
19 | """ 20 | NAME = None 21 | COMPATIBILITY = [] 22 | 23 | def __init__(self, backbone: nn.Module, loss: nn.Module, 24 | args: Namespace, dataset, transform: torchvision.transforms) -> None: 25 | super(ContinualModel, self).__init__() 26 | 27 | self.net = backbone 28 | self.net = nn.DataParallel(self.net) 29 | self.loss = loss 30 | self.args = args 31 | self.transform = transform 32 | self.dataset = dataset 33 | 34 | if args.cl_default: 35 | self.opt = get_optimizer( 36 | args.train.optimizer.name, self.net, 37 | lr=args.train.base_lr, 38 | momentum=args.train.optimizer.momentum, 39 | weight_decay=args.train.optimizer.weight_decay, 40 | cl_default=args.cl_default) 41 | else: 42 | self.opt = get_optimizer( 43 | args.train.optimizer.name, self.net, 44 | lr=args.train.base_lr*args.train.batch_size/256, 45 | momentum=args.train.optimizer.momentum, 46 | weight_decay=args.train.optimizer.weight_decay, 47 | cl_default=args.cl_default) 48 | 49 | # self.lr_scheduler = LR_Scheduler( 50 | # self.opt, 51 | # args.train.warmup_epochs, args.train.warmup_lr*args.train.batch_size/256, 52 | # args.train.num_epochs, args.train.base_lr*args.train.batch_size/256, args.train.final_lr*args.train.batch_size/256, 53 | # len_train_lodaer, 54 | # constant_predictor_lr=True # see the end of section 4.2 predictor 55 | # ) 56 | self.device = get_device() 57 | 58 | def forward(self, x: torch.Tensor) -> torch.Tensor: 59 | """ 60 | Computes a forward pass. 61 | :param x: batch of inputs 62 | :param task_label: some models require the task label 63 | :return: the result of the computation 64 | """ 65 | return self.net.module.backbone.forward(x) 66 | 67 | def observe(self, inputs: torch.Tensor, labels: torch.Tensor, 68 | not_aug_inputs: torch.Tensor) -> float: 69 | """ 70 | Compute a training step over a given batch of examples. 
71 | :param inputs: batch of examples 72 | :param labels: ground-truth labels 73 | :param not_aug_inputs: batch of inputs without data augmentation 74 | :return: the value of the loss function 75 | """ 76 | pass 77 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | kiwisolver==1.3.2 3 | matplotlib==3.4.3 4 | numpy==1.21.2 5 | Pillow==8.3.2 6 | pyparsing==2.4.7 7 | python-dateutil==2.8.2 8 | PyYAML==5.4.1 9 | quadprog==0.1.10 10 | six==1.16.0 11 | torch==1.9.1 12 | torchvision==0.10.1 13 | tqdm==4.62.3 14 | typing-extensions==3.10.0.2 15 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .average_meter import AverageMeter 2 | from .accuracy import accuracy 3 | from .knn_monitor import knn_monitor 4 | from .logger import Logger 5 | from .file_exist_fn import file_exist_check 6 | -------------------------------------------------------------------------------- /tools/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def accuracy(output, target, topk=(1,)): 5 | """Computes the accuracy over the k top predictions for the specified values of k""" 6 | with torch.no_grad(): 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res 19 | -------------------------------------------------------------------------------- /tools/average_meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(): 2 | """Computes and stores the average and current value""" 3 | def __init__(self, name, fmt=':f'): 4 | self.name = name 5 | self.fmt = fmt 6 | self.log = [] 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | 12 | def reset(self): 13 | self.log.append(self.avg) 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | def __str__(self): 26 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 27 | return fmtstr.format(**self.__dict__) 28 | 29 | if __name__ == "__main__": 30 | meter = AverageMeter('sldk') 31 | print(meter.log) 32 | 33 | -------------------------------------------------------------------------------- /tools/file_exist_fn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | 5 | def file_exist_check(file_dir): 6 | 7 | if os.path.isdir(file_dir): 8 | for i in range(2, 1000): 9 | if not os.path.isdir(file_dir + f'({i})'): 10 | file_dir += f'({i})' 11 | break 12 | return file_dir 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tools/knn_monitor.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch.nn.functional as F 3 | import torch 4 | import numpy as np 5 | import copy 6 | from utils.metrics import mask_classes 7 | 8 | #
code copied from https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb#scrollTo=RI1Y8bSImD7N 9 | # test using a knn monitor 10 | def knn_monitor(net, dataset, memory_data_loader, test_data_loader, device, cl_default, task_id, k=200, t=0.1, hide_progress=False): 11 | net.eval() 12 | try: 13 | classes = len(memory_data_loader.dataset.classes) 14 | except: 15 | classes = 200 16 | total_top1 = total_top1_mask = total_top5 = total_num = 0.0 17 | feature_bank = [] 18 | with torch.no_grad(): 19 | # generate feature bank 20 | for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=True): 21 | if cl_default: 22 | feature = net(data.cuda(non_blocking=True), return_features=True) 23 | else: 24 | feature = net(data.cuda(non_blocking=True)) 25 | feature = F.normalize(feature, dim=1) 26 | feature_bank.append(feature) 27 | # [D, N] 28 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() 29 | # [N] 30 | # feature_labels = torch.tensor(memory_data_loader.dataset.targets - np.amin(memory_data_loader.dataset.targets), device=feature_bank.device) 31 | feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device) 32 | # loop test data to predict the label by weighted knn search 33 | test_bar = tqdm(test_data_loader, desc='kNN', disable=True) 34 | for data, target in test_bar: 35 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 36 | if cl_default: 37 | feature = net(data, return_features=True) 38 | else: 39 | feature = net(data) 40 | feature = F.normalize(feature, dim=1) 41 | pred_scores = knn_predict(feature, feature_bank, feature_labels, classes, k, t) 42 | 43 | total_num += data.shape[0] 44 | _, preds = torch.max(pred_scores.data, 1) 45 | total_top1 += torch.sum(preds == target).item() 46 | 47 | pred_scores_mask = mask_classes(copy.deepcopy(pred_scores), dataset, task_id) 48 | _, preds_mask = torch.max(pred_scores_mask.data, 1) 49 | total_top1_mask += torch.sum(preds_mask == target).item() 50 | 51 | return total_top1 / total_num * 100, total_top1_mask / total_num * 100 52 | 53 | 54 | # knn monitor as in InstDisc https://arxiv.org/abs/1805.01978 55 | # implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR 56 | def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): 57 | # compute cos similarity between each feature vector and feature bank ---> [B, N] 58 | sim_matrix = torch.mm(feature, feature_bank) 59 | # [B, K] 60 | sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) 61 | # [B, K] 62 | sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices) 63 | sim_weight = (sim_weight / knn_t).exp() 64 | 65 | # counts for each class 66 | one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device) 67 | # [B*K, C] 68 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0) 69 | # weighted score ---> [B, C] 70 | pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1) 71 | 72 | return pred_scores 73 | -------------------------------------------------------------------------------- /tools/logger.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from collections import OrderedDict 3 | import os 4 | from .plotter import Plotter 5 | 6 | 7 | class 
Logger(object): 8 | def __init__(self, log_dir, matplotlib=True): 9 | 10 | self.reset(log_dir, matplotlib) 11 | 12 | def reset(self, log_dir=None, tensorboard=True, matplotlib=True): 13 | 14 | if log_dir is not None: self.log_dir=log_dir 15 | self.plotter = Plotter() if matplotlib else None 16 | self.counter = OrderedDict() 17 | 18 | def update_scalers(self, ordered_dict): 19 | 20 | for key, value in ordered_dict.items(): 21 | if isinstance(value, Tensor): 22 | try: 23 | ordered_dict[key] = value.item() 24 | except: 25 | pass 26 | if self.counter.get(key) is None: 27 | self.counter[key] = 1 28 | else: 29 | self.counter[key] += 1 30 | 31 | # if self.plotter: 32 | # self.plotter.update(ordered_dict) 33 | # self.plotter.save(os.path.join(self.log_dir, 'plotter.svg')) 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /tools/plotter.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') #https://stackoverflow.com/questions/49921721/runtimeerror-main-thread-is-not-in-main-loop-with-matplotlib-and-flask 3 | import matplotlib.pyplot as plt 4 | from collections import OrderedDict 5 | from torch import Tensor 6 | 7 | class Plotter(object): 8 | def __init__(self): 9 | self.logger = OrderedDict() 10 | def update(self, ordered_dict): 11 | for key, value in ordered_dict.items(): 12 | if isinstance(value, Tensor): 13 | try: 14 | ordered_dict[key] = value.item() 15 | except: 16 | pass 17 | if self.logger.get(key) is None: 18 | self.logger[key] = [value] 19 | else: 20 | self.logger[key].append(value) 21 | 22 | def save(self, file, **kwargs): 23 | fig, axes = plt.subplots(nrows=len(self.logger), ncols=1, figsize=(8,2*len(self.logger))) 24 | fig.tight_layout() 25 | for ax, (key, value) in zip(axes, self.logger.items()): 26 | ax.plot(value) 27 | ax.set_title(key) 28 | 29 | plt.savefig(file, **kwargs) 30 | plt.close() 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | 8 | 9 | def create_if_not_exists(path: str) -> None: 10 | """ 11 | Creates the specified folder if it does not exist. 12 | :param path: the complete path of the folder to be created 13 | """ 14 | if not os.path.exists(path): 15 | os.makedirs(path) 16 | -------------------------------------------------------------------------------- /utils/args.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from argparse import ArgumentParser 7 | from datasets import NAMES as DATASET_NAMES 8 | from models import get_all_models 9 | 10 | 11 | def add_experiment_args(parser: ArgumentParser) -> None: 12 | """ 13 | Adds the arguments used by all the models. 
/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
 2 | # All rights reserved.
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import os
 7 | 
 8 | 
 9 | def create_if_not_exists(path: str) -> None:
10 |     """
11 |     Creates the specified folder if it does not exist.
12 |     :param path: the complete path of the folder to be created
13 |     """
14 |     if not os.path.exists(path):
15 |         os.makedirs(path)
16 | 
--------------------------------------------------------------------------------
/utils/args.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
 2 | # All rights reserved.
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | from argparse import ArgumentParser
 7 | from datasets import NAMES as DATASET_NAMES
 8 | from models import get_all_models
 9 | 
10 | 
11 | def add_experiment_args(parser: ArgumentParser) -> None:
12 |     """
13 |     Adds the arguments used by all the models.
14 |     :param parser: the parser instance
15 |     """
16 |     parser.add_argument('--dataset', type=str, required=True,
17 |                         choices=DATASET_NAMES,
18 |                         help='Which dataset to perform experiments on.')
19 |     parser.add_argument('--model', type=str, required=True,
20 |                         help='Model name.', choices=get_all_models())
21 | 
22 |     parser.add_argument('--lr', type=float, required=True,
23 |                         help='Learning rate.')
24 |     parser.add_argument('--warmup_lr', default=0.0, type=float,
25 |                         help='Warmup learning rate.')
26 |     parser.add_argument('--warmup_epochs', default=0, type=int,
27 |                         help='Warmup epochs.')
28 |     parser.add_argument('--final_lr', default=0.0, type=float,
29 |                         help='Final learning rate.')
30 |     parser.add_argument('--batch_size', type=int, required=True,
31 |                         help='Batch size.')
32 |     parser.add_argument('--n_epochs', type=int, required=True,
33 |                         help='The number of epochs for each task.')
34 |     parser.add_argument('--sim_siam', action='store_true',
35 |                         help='Use SimSiam.')
36 | 
37 | 
38 | def add_management_args(parser: ArgumentParser) -> None:
39 |     parser.add_argument('--seed', type=int, default=None,
40 |                         help='The random seed.')
41 |     parser.add_argument('--notes', type=str, default=None,
42 |                         help='Notes for this run.')
43 | 
44 |     parser.add_argument('--csv_log', action='store_true',
45 |                         help='Enable csv logging.')
46 |     parser.add_argument('--tensorboard', action='store_true',
47 |                         help='Enable tensorboard logging.')
48 |     parser.add_argument('--validation', action='store_true',
49 |                         help='Test on the validation set.')
50 | 
51 | 
52 | def add_rehearsal_args(parser: ArgumentParser) -> None:
53 |     """
54 |     Adds the arguments used by all the rehearsal-based methods.
55 |     :param parser: the parser instance
56 |     """
57 |     parser.add_argument('--buffer_size', type=int, required=True,
58 |                         help='The size of the memory buffer.')
59 |     parser.add_argument('--minibatch_size', type=int,
60 |                         help='The batch size of the memory buffer.')
--------------------------------------------------------------------------------
/utils/batch_norm.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
 2 | # All rights reserved.
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import torch
 7 | import torch.nn as nn
 8 | 
 9 | class bn_track_stats:
10 |     def __init__(self, module: nn.Module, condition=True):
11 |         self.module = module
12 |         self.enable = condition
13 | 
14 |     def __enter__(self):
15 |         if not self.enable:
16 |             for m in self.module.modules():
17 |                 if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
18 |                     m.track_running_stats = False
19 | 
20 |     def __exit__(self, type, value, traceback):
21 |         if not self.enable:
22 |             for m in self.module.modules():
23 |                 if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
24 |                     m.track_running_stats = True
25 | 
--------------------------------------------------------------------------------
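Note: `bn_track_stats` above is a context manager that temporarily freezes BatchNorm running statistics. A hedged sketch of its intended use (the toy network is illustrative):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
x = torch.randn(4, 3, 32, 32)

with bn_track_stats(net, condition=False):   # condition=False => freeze running stats
    _ = net(x)                               # running_mean / running_var stay untouched
_ = net(x)                                   # outside the context they are tracked again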
/utils/buffer.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
 2 | # All rights reserved.
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | from copy import deepcopy
 6 | import torch
 7 | import numpy as np
 8 | from typing import Tuple
 9 | from torchvision import transforms
10 | 
11 | def icarl_replay(self, dataset, val_set_split=0):
12 |     """
13 |     Merge the replay buffer with the current task data.
14 |     Optionally split the replay buffer into a validation set.
15 |     :param self: the model instance
16 |     :param dataset: the dataset
17 |     :param val_set_split: the fraction of the replay buffer to be used as validation set
18 |     """
19 |     if self.task > 0:
20 |         buff_val_mask = torch.rand(len(self.buffer)) < val_set_split
21 |         val_train_mask = torch.zeros(len(dataset.train_loader.dataset.data)).bool()
22 |         val_train_mask[torch.randperm(len(dataset.train_loader.dataset.data))[:buff_val_mask.sum()]] = True
23 | 
24 |         if val_set_split > 0:
25 |             self.val_loader = deepcopy(dataset.train_loader)
26 |         data_concatenate = torch.cat if type(dataset.train_loader.dataset.data) == torch.Tensor else np.concatenate
27 |         need_aug = hasattr(dataset.train_loader.dataset, 'not_aug_transform')
28 |         if not need_aug:
29 |             refold_transform = lambda x: x.cpu()
30 |         else:
31 |             data_shape = len(dataset.train_loader.dataset.data[0].shape)
32 |             if data_shape == 3:
33 |                 refold_transform = lambda x: (x.cpu()*255).permute([0, 2, 3, 1]).numpy().astype(np.uint8)
34 |             elif data_shape == 2:
35 |                 refold_transform = lambda x: (x.cpu()*255).squeeze(1).type(torch.uint8)
36 | 
37 |         # REDUCE AND MERGE TRAINING SET
38 |         dataset.train_loader.dataset.targets = np.concatenate([
39 |             dataset.train_loader.dataset.targets[~val_train_mask],
40 |             self.buffer.labels.cpu().numpy()[:len(self.buffer)][~buff_val_mask]
41 |         ])
42 |         dataset.train_loader.dataset.data = data_concatenate([
43 |             dataset.train_loader.dataset.data[~val_train_mask],
44 |             refold_transform((self.buffer.examples)[:len(self.buffer)][~buff_val_mask])
45 |         ])
46 | 
47 |         if val_set_split > 0:
48 |             # REDUCE AND MERGE VALIDATION SET
49 |             self.val_loader.dataset.targets = np.concatenate([
50 |                 self.val_loader.dataset.targets[val_train_mask],
51 |                 self.buffer.labels.cpu().numpy()[:len(self.buffer)][buff_val_mask]
52 |             ])
53 |             self.val_loader.dataset.data = data_concatenate([
54 |                 self.val_loader.dataset.data[val_train_mask],
55 |                 refold_transform((self.buffer.examples)[:len(self.buffer)][buff_val_mask])
56 |             ])
57 | 
58 | def reservoir(num_seen_examples: int, buffer_size: int) -> int:
59 |     """
60 |     Reservoir sampling algorithm.
61 |     :param num_seen_examples: the number of seen examples
62 |     :param buffer_size: the maximum buffer size
63 |     :return: the target index if the current image is sampled, else -1
64 |     """
65 |     if num_seen_examples < buffer_size:
66 |         return num_seen_examples
67 | 
68 |     rand = np.random.randint(0, num_seen_examples + 1)
69 |     if rand < buffer_size:
70 |         return rand
71 |     else:
72 |         return -1
73 | 
74 | 
75 | def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int:
76 |     return num_seen_examples % buffer_portion_size + task * buffer_portion_size
77 | 
78 | 
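Note: `reservoir` guarantees that, after `n` examples have streamed by, each of them occupies a buffer slot with probability `buffer_size / n`. A tiny simulation (all sizes illustrative) that checks this property empirically:

import numpy as np

np.random.seed(0)
buffer_size, n_stream, n_runs = 50, 1000, 200
survivors = 0
for _ in range(n_runs):
    buf = [None] * buffer_size
    for idx in range(n_stream):
        slot = reservoir(idx, buffer_size)   # idx == number of examples seen so far
        if slot >= 0:
            buf[slot] = idx
    survivors += sum(e is not None and e < 100 for e in buf)
# each of the first 100 items survives w.p. 50/1000 = 0.05, so they should
# fill about 100 * 0.05 / 50 = 10% of the buffer:
print(survivors / (n_runs * buffer_size))    # ~0.1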
82 | """ 83 | def __init__(self, buffer_size, device, n_tasks=None, mode='reservoir'): 84 | assert mode in ['ring', 'reservoir'] 85 | self.buffer_size = buffer_size 86 | self.device = device 87 | self.num_seen_examples = 0 88 | self.functional_index = eval(mode) 89 | if mode == 'ring': 90 | assert n_tasks is not None 91 | self.task_number = n_tasks 92 | self.buffer_portion_size = buffer_size // n_tasks 93 | self.attributes = ['examples', 'labels', 'logits', 'task_labels'] 94 | 95 | def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor, 96 | logits: torch.Tensor, task_labels: torch.Tensor) -> None: 97 | """ 98 | Initializes just the required tensors. 99 | :param examples: tensor containing the images 100 | :param labels: tensor containing the labels 101 | :param logits: tensor containing the outputs of the network 102 | :param task_labels: tensor containing the task labels 103 | """ 104 | for attr_str in self.attributes: 105 | attr = eval(attr_str) 106 | if attr is not None and not hasattr(self, attr_str): 107 | typ = torch.int64 if attr_str.endswith('els') else torch.float32 108 | setattr(self, attr_str, torch.zeros((self.buffer_size, 109 | *attr.shape[1:]), dtype=typ, device=self.device)) 110 | 111 | def add_data(self, examples, labels=None, logits=None, task_labels=None): 112 | """ 113 | Adds the data to the memory buffer according to the reservoir strategy. 114 | :param examples: tensor containing the images 115 | :param labels: tensor containing the labels 116 | :param logits: tensor containing the outputs of the network 117 | :param task_labels: tensor containing the task labels 118 | :return: 119 | """ 120 | if not hasattr(self, 'examples'): 121 | self.init_tensors(examples, labels, logits, task_labels) 122 | 123 | for i in range(examples.shape[0]): 124 | index = reservoir(self.num_seen_examples, self.buffer_size) 125 | self.num_seen_examples += 1 126 | if index >= 0: 127 | self.examples[index] = examples[i].to(self.device) 128 | if labels is not None: 129 | self.labels[index] = labels[i].to(self.device) 130 | if logits is not None: 131 | self.logits[index] = logits[i].to(self.device) 132 | if task_labels is not None: 133 | self.task_labels[index] = task_labels[i].to(self.device) 134 | 135 | def get_data(self, size: int, transform: transforms=None) -> Tuple: 136 | """ 137 | Random samples a batch of size items. 138 | :param size: the number of requested items 139 | :param transform: the transformation to be applied (data augmentation) 140 | :return: 141 | """ 142 | if size > min(self.num_seen_examples, self.examples.shape[0]): 143 | size = min(self.num_seen_examples, self.examples.shape[0]) 144 | 145 | choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]), 146 | size=size, replace=False) 147 | if transform is None: transform = lambda x: x 148 | # import pdb 149 | # pdb.set_trace() 150 | ret_tuple = (torch.stack([transform(ee.cpu()) 151 | for ee in self.examples[choice]]).to(self.device),) 152 | for attr_str in self.attributes[1:]: 153 | if hasattr(self, attr_str): 154 | attr = getattr(self, attr_str) 155 | ret_tuple += (attr[choice],) 156 | 157 | return ret_tuple 158 | 159 | def is_empty(self) -> bool: 160 | """ 161 | Returns true if the buffer is empty, false otherwise. 162 | """ 163 | if self.num_seen_examples == 0: 164 | return True 165 | else: 166 | return False 167 | 168 | def get_all_data(self, transform: transforms=None) -> Tuple: 169 | """ 170 | Return all the items in the memory buffer. 
/utils/conf.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
 2 | # All rights reserved.
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import random
 7 | import torch
 8 | import numpy as np
 9 | 
10 | def get_device() -> torch.device:
11 |     """
12 |     Returns the GPU device if available else CPU.
13 |     """
14 |     return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15 | 
16 | 
17 | def base_path() -> str:
18 |     """
19 |     Returns the base path where to log accuracies and tensorboard data.
20 |     """
21 |     return './data/'
22 | 
23 | 
24 | def set_random_seed(seed: int) -> None:
25 |     """
26 |     Sets the seeds at a certain value.
27 |     :param seed: the value to be set
28 |     """
29 |     random.seed(seed)
30 |     np.random.seed(seed)
31 |     torch.manual_seed(seed)
32 |     torch.cuda.manual_seed_all(seed)
33 | 
--------------------------------------------------------------------------------
/utils/continual_training.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
 2 | # All rights reserved.
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import torch
 7 | from datasets import get_gcl_dataset
 8 | from models import get_model
 9 | from utils.status import progress_bar
10 | from utils.tb_logger import *
11 | from utils.status import create_fake_stash
12 | from models.utils.continual_model import ContinualModel
13 | from argparse import Namespace
14 | 
15 | 
16 | def evaluate(model: ContinualModel, dataset) -> float:
17 |     """
18 |     Evaluates the final accuracy of the model.
19 |     :param model: the model to be evaluated
20 |     :param dataset: the GCL dataset at hand
21 |     :return: a float value that indicates the accuracy
22 |     """
23 |     model.net.eval()
24 |     correct, total = 0, 0
25 |     while not dataset.test_over:
26 |         inputs, labels = dataset.get_test_data()
27 |         inputs, labels = inputs.to(model.device), labels.to(model.device)
28 |         outputs = model(inputs)
29 |         _, predicted = torch.max(outputs.data, 1)
30 |         correct += torch.sum(predicted == labels).item()
31 |         total += labels.shape[0]
32 | 
33 |     acc = correct / total * 100
34 |     return acc
35 | 
36 | 
37 | def train(args: Namespace):
38 |     """
39 |     The training process, including evaluations and loggers.
40 |     The GCL dataset, backbone and loss are all built here from
41 |     the given arguments (see `get_gcl_dataset` and `get_model`).
42 |     :param args: the arguments of the current execution
43 |     """
44 |     if args.csv_log:
45 |         from utils.loggers import CsvLogger
46 | 
47 |     dataset = get_gcl_dataset(args)
48 |     backbone = dataset.get_backbone()
49 |     criterion = dataset.get_loss()
50 |     model = get_model(args, backbone, criterion, dataset.get_transform())
51 |     model.net.to(model.device)
52 | 
53 |     model_stash = create_fake_stash(model, args)
54 | 
55 |     if args.csv_log:
56 |         csv_logger = CsvLogger(dataset.SETTING, dataset.NAME, model.NAME)
57 | 
58 |     model.net.train()
59 |     epoch, i = 0, 0
60 |     while not dataset.train_over:
61 |         inputs, labels, not_aug_inputs = dataset.get_train_data()
62 |         inputs, labels = inputs.to(model.device), labels.to(model.device)
63 |         not_aug_inputs = not_aug_inputs.to(model.device)
64 |         loss = model.observe(inputs, labels, not_aug_inputs)
65 |         progress_bar(i, dataset.LENGTH // args.batch_size, epoch, 'C', loss)
66 |         i += 1
67 | 
68 |     if model.NAME == 'joint_gcl':
69 |         model.end_task(dataset)
70 | 
71 |     acc = evaluate(model, dataset)
72 |     print('Accuracy:', acc)
73 | 
74 |     if args.csv_log:
75 |         csv_logger.log(acc)
76 |         csv_logger.write(vars(args))
77 | 
--------------------------------------------------------------------------------
/utils/deep_inversion.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto. Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
 8 | 
 9 | import torch
10 | 
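Note: before the class itself, a hedged sketch of how a hook like `DeepInversionFeatureHook` (defined next) is typically attached to every BatchNorm layer, so that the per-layer `r_feature` terms can be summed into a single regularizer; the network and shapes are illustrative, not the repo's configuration.

import torch
import torch.nn as nn
import torchvision

net = torchvision.models.resnet18(num_classes=10).eval()
hooks = [DeepInversionFeatureHook(m) for m in net.modules()
         if isinstance(m, nn.BatchNorm2d)]

x = torch.randn(8, 3, 32, 32, requires_grad=True)
_ = net(x)                                    # hook_fn fires on every BN layer
r_loss = sum(h.r_feature for h in hooks)      # BN-statistics matching loss
r_loss.backward()                             # gradients flow back into x

for h in hooks:
    h.close()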
11 | class DeepInversionFeatureHook():
12 |     '''
13 |     Implementation of the forward hook to track feature statistics and compute a loss on them.
14 |     It computes the per-channel mean and variance and uses an L2 loss against the stored statistics.
15 |     '''
16 | 
17 |     def __init__(self, module):
18 |         self.hook = module.register_forward_hook(self.hook_fn)
19 |         self.mean = None
20 |         self.var = None
21 |         self.use_stored_stats = False
22 |         self.capture_bn_stats = False
23 | 
24 | 
25 |     def hook_fn(self, module, input, output):
26 |         # hook to compute deepinversion's feature distribution regularization
27 |         nch_in = input[0].shape[1]
28 |         nch_out = output.shape[1]
29 | 
30 |         # per-channel statistics of the layer input
31 |         in_mean = input[0].mean([0, 2, 3])
32 |         in_var = input[0].permute(1, 0, 2, 3).contiguous().view([nch_in, -1]).var(1, unbiased=False)
33 | 
34 |         # per-channel statistics of the layer output
35 |         out_mean = output.mean([0, 2, 3])
36 |         out_var = output.permute(1, 0, 2, 3).contiguous().view([nch_out, -1]).var(1, unbiased=False)
37 | 
38 |         if self.capture_bn_stats:
39 |             self.out_mean = out_mean.clone().detach()
40 |             self.out_var = out_var.clone().detach()
41 | 
42 |         if not self.use_stored_stats:
43 |             # match the BatchNorm running statistics of the (frozen) model
44 |             r_feature = torch.norm(module.running_var.data.type(in_var.type()) - in_var, 2) + torch.norm(
45 |                 module.running_mean.data.type(in_mean.type()) - in_mean, 2)
46 |         else:
47 |             # match the statistics captured earlier via capture_bn_stats
48 |             r_feature = torch.norm(self.out_var - out_var, 2) + torch.norm(self.out_mean - out_mean, 2)
49 | 
50 |         self.r_feature = r_feature
51 |         # must have no output
52 | 
53 |     def close(self):
54 |         self.hook.remove()
55 | 
--------------------------------------------------------------------------------
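Note: putting the hook to work, a hedged sketch of a DeepInversion-style synthesis loop that optimizes random inputs against a frozen teacher; every hyperparameter below is illustrative, not taken from this repo's configs.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

teacher = torchvision.models.resnet18(num_classes=10).eval()
hooks = [DeepInversionFeatureHook(m) for m in teacher.modules()
         if isinstance(m, nn.BatchNorm2d)]

inputs = torch.randn(16, 3, 32, 32, requires_grad=True)
targets = torch.randint(0, 10, (16,))
opt = torch.optim.Adam([inputs], lr=0.05)

for _ in range(100):
    opt.zero_grad()
    logits = teacher(inputs)
    loss = F.cross_entropy(logits, targets)              # class-conditional term
    loss = loss + 0.1 * sum(h.r_feature for h in hooks)  # BN feature regularizer
    loss.backward()
    opt.step()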
/utils/loggers.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
 2 | # All rights reserved.
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import csv
 7 | import os
 8 | import sys
 9 | from typing import Dict, Any
10 | from utils.metrics import *
11 | 
12 | from utils import create_if_not_exists
13 | from utils.conf import base_path
14 | import numpy as np
15 | 
16 | useless_args = ['dataset', 'tensorboard', 'validation', 'model',
17 |                 'csv_log', 'notes', 'load_best_args']
18 | 
19 | 
20 | def print_mean_accuracy(mean_acc: np.ndarray, task_number: int,
21 |                         setting: str) -> None:
22 |     """
23 |     Prints the mean accuracy on stderr.
24 |     :param mean_acc: mean accuracy value
25 |     :param task_number: task index
26 |     :param setting: the setting of the benchmark
27 |     """
28 |     if setting == 'domain-il':
29 |         mean_acc, _ = mean_acc
30 |         print('\nAccuracy for {} task(s): {} %'.format(
31 |             task_number, round(mean_acc, 2)), file=sys.stderr)
32 |     else:
33 |         mean_acc_class_il, mean_acc_task_il = mean_acc
34 |         print('\nAccuracy for {} task(s): \t [Class-IL]: {} %'
35 |               ' \t [Task-IL]: {} %\n'.format(task_number, round(
36 |             mean_acc_class_il, 2), round(mean_acc_task_il, 2)), file=sys.stderr)
37 | 
38 | 
39 | class CsvLogger:
40 |     def __init__(self, setting_str: str, dataset_str: str,
41 |                  model_str: str) -> None:
42 |         self.accs = []
43 |         if setting_str == 'class-il':
44 |             self.accs_mask_classes = []
45 |         self.setting = setting_str
46 |         self.dataset = dataset_str
47 |         self.model = model_str
48 |         self.fwt = None
49 |         self.fwt_mask_classes = None
50 |         self.bwt = None
51 |         self.bwt_mask_classes = None
52 |         self.forgetting = None
53 |         self.forgetting_mask_classes = None
54 | 
55 |     def add_fwt(self, results, accs, results_mask_classes, accs_mask_classes):
56 |         self.fwt = forward_transfer(results, accs)
57 |         if self.setting == 'class-il':
58 |             self.fwt_mask_classes = forward_transfer(results_mask_classes, accs_mask_classes)
59 | 
60 |     def add_bwt(self, results, results_mask_classes):
61 |         self.bwt = backward_transfer(results)
62 |         self.bwt_mask_classes = backward_transfer(results_mask_classes)
63 | 
64 |     def add_forgetting(self, results, results_mask_classes):
65 |         self.forgetting = forgetting(results)
66 |         self.forgetting_mask_classes = forgetting(results_mask_classes)
67 | 
68 |     def log(self, mean_acc: np.ndarray) -> None:
69 |         """
70 |         Logs a mean accuracy value.
71 |         :param mean_acc: mean accuracy value
72 |         """
73 |         if self.setting == 'general-continual':
74 |             self.accs.append(mean_acc)
75 |         elif self.setting == 'domain-il':
76 |             mean_acc, _ = mean_acc
77 |             self.accs.append(mean_acc)
78 |         else:
79 |             mean_acc_class_il, mean_acc_task_il = mean_acc
80 |             self.accs.append(mean_acc_class_il)
81 |             self.accs_mask_classes.append(mean_acc_task_il)
82 | 
83 |     def write(self, ckpt_dir, args: Dict[str, Any]) -> None:
84 |         """
85 |         Writes out the logged values, along with the experiment arguments, to a csv file under ckpt_dir.
86 |         :param args: the namespace of the current experiment
87 |         """
88 |         for cc in useless_args:
89 |             if cc in args:
90 |                 del args[cc]
91 | 
92 |         columns = list(args.keys())
93 | 
94 |         new_cols = []
95 |         for i, acc in enumerate(self.accs):
96 |             args['task' + str(i + 1)] = acc
97 |             new_cols.append('task' + str(i + 1))
98 | 
99 |         args['forward_transfer'] = self.fwt
100 |         new_cols.append('forward_transfer')
101 | 
102 |         args['backward_transfer'] = self.bwt
103 |         new_cols.append('backward_transfer')
104 | 
105 |         args['forgetting'] = self.forgetting
106 |         new_cols.append('forgetting')
107 | 
108 |         columns = new_cols + columns
109 | 
110 |         create_if_not_exists(ckpt_dir + "results/" + self.setting)
111 |         create_if_not_exists(ckpt_dir + "results/" + self.setting +
112 |                              "/" + self.dataset)
113 |         create_if_not_exists(ckpt_dir + "results/" + self.setting +
114 |                              "/" + self.dataset + "/" + self.model)
115 | 
116 |         write_headers = False
117 |         path = ckpt_dir + "results/" + self.setting + "/" + self.dataset\
118 |             + "/" + self.model + "/mean_accs.csv"
119 |         if not os.path.exists(path):
120 |             write_headers = True
121 |         with open(path, 'a') as tmp:
122 |             writer = csv.DictWriter(tmp, fieldnames=columns)
123 |             if write_headers:
124 |                 writer.writeheader()
125 |             writer.writerow(args)
126 | 
127 |         if self.setting == 'class-il':
128 |             create_if_not_exists(ckpt_dir + "results/task-il/"
129 |                                  + self.dataset)
130 |             create_if_not_exists(ckpt_dir + "results/task-il/"
131 |                                  + self.dataset + "/" + self.model)
132 | 
133 |             for i, acc in enumerate(self.accs_mask_classes):
134 |                 args['task' + str(i + 1)] = acc
135 | 
136 |             args['forward_transfer'] = self.fwt_mask_classes
137 |             args['backward_transfer'] = self.bwt_mask_classes
138 |             args['forgetting'] = self.forgetting_mask_classes
139 | 
140 |             write_headers = False
141 |             path = ckpt_dir + "results/task-il" + "/" + self.dataset + "/"\
142 |                 + self.model + "/mean_accs.csv"
143 |             if not os.path.exists(path):
144 |                 write_headers = True
145 |             with open(path, 'a') as tmp:
146 |                 writer = csv.DictWriter(tmp, fieldnames=columns)
147 |                 if write_headers:
148 |                     writer.writeheader()
149 |                 writer.writerow(args)
150 | 
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
 1 | # Originated from https://github.com/sutd-visual-computing-group/LS-KD-compatibility/blob/master/src/image_classification/imagenet/utils.py
 2 | 
 3 | import torch
 4 | import torch.nn as nn
 5 | import torchvision.models as models
 6 | import torch.nn.functional as F
 7 | 
 8 | 
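Note: the `LabelSmoothing` module defined just below implements loss = (1 - eps) * NLL(target) + eps * mean_c(-log p_c). A short numeric check of that formula (values illustrative):

import torch
import torch.nn.functional as F

eps = 0.1
logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))

logp = F.log_softmax(logits, dim=-1)
nll = -logp.gather(1, target.unsqueeze(1)).squeeze(1)
smooth = -logp.mean(dim=-1)
loss = ((1 - eps) * nll + eps * smooth).mean()
print(torch.allclose(LabelSmoothing(smoothing=eps)(logits, target), loss))  # True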
 9 | # Define Smooth Loss
10 | class LabelSmoothing(nn.Module):
11 |     """
12 |     NLL loss with label smoothing.
13 |     https://github.com/NVIDIA/DeepLearningExamples/blob/8d8b21a933fff3defb692e0527fca15532da5dc6/PyTorch/Classification/ConvNets/image_classification/smoothing.py
14 |     """
15 | 
16 |     def __init__(self, smoothing=0.0):
17 |         """
18 |         Constructor for the LabelSmoothing module.
19 |         :param smoothing: label smoothing factor
20 |         """
21 |         super(LabelSmoothing, self).__init__()
22 |         self.confidence = 1.0 - smoothing
23 |         self.smoothing = smoothing
24 | 
25 |     def forward(self, x, target):
26 |         logprobs = torch.nn.functional.log_softmax(x, dim=-1)
27 | 
28 |         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
29 |         nll_loss = nll_loss.squeeze(1)
30 |         smooth_loss = -logprobs.mean(dim=-1)
31 |         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
32 |         return loss.mean()
33 | 
34 | 
35 | # Define KL divergence loss
36 | class KL_div_Loss(nn.Module):
37 |     """
38 |     We use the formulation of Hinton et al. for the KD loss.
39 |     $T^2$ scaling is implemented to avoid gradient rescaling when using T != 1.
40 |     """
41 | 
42 |     def __init__(self, temperature):
43 |         """
44 |         Constructor for the KL_div_Loss module.
45 |         :param temperature: the softmax temperature used for distillation
46 |         """
47 |         super(KL_div_Loss, self).__init__()
48 |         self.temperature = temperature
49 |         #print( "Setting temperature = {} for KD (Only Teacher)".format(self.temperature) )
50 |         print( "Setting temperature = {} for KD".format(self.temperature) )
51 | 
52 | 
53 |     def forward(self, y, teacher_scores):
54 |         p = F.log_softmax(y / self.temperature, dim=1)  # Hinton formulation
55 | 
56 |         #p = F.log_softmax(y, dim=1)  # Muller et al. used this.
57 | 
58 |         q = F.softmax(teacher_scores / self.temperature, dim=1)
59 |         l_kl = F.kl_div(p, q, reduction='batchmean')
60 |         return l_kl*(self.temperature**2)  # $T^2$ scaling is important
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
 2 | # All rights reserved.
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import torch
 7 | import numpy as np
 8 | from datasets.utils.continual_dataset import ContinualDataset
 9 | from typing import Tuple
10 | 
11 | 
12 | def backward_transfer(results):
13 |     n_tasks = len(results)
14 |     li = list()
15 |     for i in range(n_tasks - 1):
16 |         li.append(results[-1][i] - results[i][i])
17 | 
18 |     return np.mean(li)
19 | 
20 | 
21 | def forward_transfer(results, random_results):
22 |     n_tasks = len(results)
23 |     li = list()
24 |     for i in range(1, n_tasks):
25 |         li.append(results[i-1][i] - random_results[i])
26 | 
27 |     return np.mean(li)
28 | 
29 | 
30 | def forgetting(results):
31 |     n_tasks = len(results)
32 |     li = list()
33 |     for i in range(n_tasks - 1):
34 |         results[i] += [0.0] * (n_tasks - len(results[i]))
35 |     np_res = np.array(results)
36 |     maxx = np.max(np_res, axis=0)
37 |     for i in range(n_tasks - 1):
38 |         li.append(maxx[i] - results[-1][i])
39 | 
40 |     return np.mean(li)
41 | 
42 | 
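Note: a worked example for the metrics above (numbers invented for illustration). Rows of `results` hold the per-task accuracies measured after training each task; `forgetting` pads rows in place, hence the defensive copies.

results = [[90.0],
           [85.0, 88.0],
           [80.0, 84.0, 86.0]]
print(backward_transfer([r[:] for r in results]))  # mean of (80-90, 84-88) = -7.0
print(forgetting([r[:] for r in results]))         # mean of (90-80, 88-84) =  7.0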
43 | def mask_classes(outputs: torch.Tensor, dataset: ContinualDataset, k: int):
44 |     """
45 |     Given the output tensor, the dataset at hand and the current task,
46 |     masks the former by setting the responses for the other tasks at -inf.
47 |     It is used to obtain the results for the task-il setting.
48 |     :param outputs: the output tensor
49 |     :param dataset: the continual dataset
50 |     :param k: the task index
51 |     """
52 |     outputs[:, 0:k * dataset.N_CLASSES_PER_TASK] = -float('inf')
53 |     outputs[:, (k + 1) * dataset.N_CLASSES_PER_TASK:
54 |                dataset.N_TASKS * dataset.N_CLASSES_PER_TASK] = -float('inf')
55 | 
56 |     return outputs
57 | 
58 | 
59 | 
--------------------------------------------------------------------------------
/utils/status.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
 2 | # All rights reserved.
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | from datetime import datetime
 7 | import sys
 8 | import os
 9 | from utils.conf import base_path
10 | from typing import Any, Dict, Union
11 | from torch import nn
12 | from argparse import Namespace
13 | from datasets.utils.continual_dataset import ContinualDataset
14 | 
15 | 
16 | def create_stash(model: nn.Module, args: Namespace,
17 |                  dataset: ContinualDataset) -> Dict[Any, str]:
18 |     """
19 |     Creates the dictionary where to save the model status.
20 |     :param model: the model
21 |     :param args: the current arguments
22 |     :param dataset: the dataset at hand
23 |     """
24 |     now = datetime.now()
25 |     model_stash = {'task_idx': 0, 'epoch_idx': 0, 'batch_idx': 0}
26 |     name_parts = [args.dataset, model.NAME]
27 |     if 'buffer_size' in vars(args).keys():
28 |         name_parts.append('buf_' + str(args.buffer_size))
29 |     name_parts.append(now.strftime("%Y%m%d_%H%M%S_%f"))
30 |     model_stash['model_name'] = '/'.join(name_parts)
31 |     model_stash['mean_accs'] = []
32 |     model_stash['args'] = args
33 |     model_stash['backup_folder'] = os.path.join(base_path(), 'backups',
34 |                                                 dataset.SETTING,
35 |                                                 model_stash['model_name'])
36 |     return model_stash
37 | 
38 | 
39 | def create_fake_stash(model: nn.Module, args: Namespace) -> Dict[Any, str]:
40 |     """
41 |     Create a fake stash, containing just the model name.
42 |     This is used in general continual, as it is useless to backup
43 |     a lightweight MNIST-360 training.
44 |     :param model: the model
45 |     :param args: the arguments of the call
46 |     :return: a dict containing a fake stash
47 |     """
48 |     now = datetime.now()
49 |     model_stash = {'task_idx': 0, 'epoch_idx': 0}
50 |     name_parts = [args.dataset, model.NAME]
51 |     if 'buffer_size' in vars(args).keys():
52 |         name_parts.append('buf_' + str(args.buffer_size))
53 |     name_parts.append(now.strftime("%Y%m%d_%H%M%S_%f"))
54 |     model_stash['model_name'] = '/'.join(name_parts)
55 | 
56 |     return model_stash
57 | 
58 | 
59 | def progress_bar(i: int, max_iter: int, epoch: Union[int, str],
60 |                  task_number: int, loss: float) -> None:
61 |     """
62 |     Prints out the progress bar on the stderr file.
63 |     :param i: the current iteration
64 |     :param max_iter: the maximum number of iterations
65 |     :param epoch: the epoch
66 |     :param task_number: the task index
67 |     :param loss: the current value of the loss function
68 |     """
69 |     if not (i + 1) % 10 or (i + 1) == max_iter:
70 |         progress = min(float((i + 1) / max_iter), 1)
71 |         progress_bar = ('█' * int(50 * progress)) + ('┈' * (50 - int(50 * progress)))
72 |         print('\r[ {} ] Task {} | epoch {}: |{}| loss: {}'.format(
73 |             datetime.now().strftime("%m-%d | %H:%M"),
74 |             task_number + 1 if isinstance(task_number, int) else task_number,
75 |             epoch,
76 |             progress_bar,
77 |             round(loss, 8)
78 |         ), file=sys.stderr, end='', flush=True)
79 | 
--------------------------------------------------------------------------------
/utils/tb_logger.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
 2 | # All rights reserved.
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | from utils.conf import base_path
 7 | import os
 8 | from argparse import Namespace
 9 | from typing import Dict, Any
10 | import numpy as np
11 | import torchvision
12 | import matplotlib.pyplot as plt
13 | import torch
14 | import torch.nn.functional as F
15 | 
16 | def img_denormalize(img):
17 |     """Scale and shift back a batch of normalized images (NCHW).
18 |     """
19 |     mean = [0.4914, 0.4822, 0.4465]
20 |     std = [0.2470, 0.2435, 0.2615]
21 |     nch = img.shape[1]
22 | 
23 |     mean = torch.tensor(mean, device=img.device).reshape(1, nch, 1, 1)
24 |     std = torch.tensor(std, device=img.device).reshape(1, nch, 1, 1)
25 | 
26 |     return img * std + mean
27 | 
28 | 
29 | def save_img(img, unnormalize=True, max_num=5, size=32, nrow=5, dataname='imagenet'):
30 |     img = img[:max_num].detach()
31 |     if unnormalize:
32 |         img = img_denormalize(img)
33 |     images = torch.clamp(img, min=0., max=1.)
34 |     images = torchvision.utils.make_grid(images, nrow=nrow, padding=2)
35 |     # print(images.shape)
36 |     # if img.shape[-1] > size:
37 |     #     img = F.interpolate(img, size)
38 | 
39 |     return images
40 | 
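Note: a round-trip check for `img_denormalize` above; the CIFAR-10-style mean/std are hard-coded in the helper, and the input tensor here is illustrative.

import torch
import torchvision.transforms as T

mean, std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2615]
img = torch.rand(2, 3, 32, 32)                     # values in [0, 1]
restored = img_denormalize(T.Normalize(mean, std)(img))
print(torch.allclose(restored, img, atol=1e-6))    # True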
41 | class TensorboardLogger:
42 |     def __init__(self, args: Namespace, setting: str,
43 |                  stash: Dict[Any, str]=None) -> None:
44 |         from torch.utils.tensorboard import SummaryWriter
45 | 
46 |         self.settings = [setting]
47 |         if setting == 'class-il':
48 |             self.settings.append('task-il')
49 |         self.loggers = {}
50 |         self.name = args.model.backbone
51 |         for a_setting in self.settings:
52 |             self.loggers[a_setting] = SummaryWriter(
53 |                 os.path.join(args.ckpt_dir, 'tensorboard_runs'))
54 |         config_text = ', '.join(
55 |             ["%s=%s" % (name, getattr(args, name)) for name in args.__dir__()
56 |              if not name.startswith('_')])
57 |         for a_logger in self.loggers.values():
58 |             a_logger.add_text('config', config_text)
59 | 
60 |     def get_name(self) -> str:
61 |         """
62 |         :return: the name of the model
63 |         """
64 |         return self.name
65 | 
66 |     def log_accuracy(self, all_accs: np.ndarray, all_mean_accs: np.ndarray,
67 |                      args: Namespace, task_number: int) -> None:
68 |         """
69 |         Logs the current accuracy value for each task.
70 |         :param all_accs: the accuracies (class-il, task-il) for each task
71 |         :param all_mean_accs: the mean accuracies for (class-il, task-il)
72 |         :param args: the arguments of the run
73 |         :param task_number: the task index
74 |         """
75 |         mean_acc_common, mean_acc_task_il = all_mean_accs
76 |         for setting, a_logger in self.loggers.items():
77 |             mean_acc = mean_acc_task_il\
78 |                 if setting == 'task-il' else mean_acc_common
79 |             index = 1 if setting == 'task-il' else 0
80 |             accs = [all_accs[index][kk] for kk in range(len(all_accs[0]))]
81 |             for kk, acc in enumerate(accs):
82 |                 a_logger.add_scalar('acc_task%02d' % (kk + 1), acc,
83 |                                     task_number * args.train.num_epochs)
84 |             a_logger.add_scalar('acc_mean', mean_acc, task_number * args.train.num_epochs)
85 | 
86 |     def log_loss(self, loss: float, args: Namespace, epoch: int,
87 |                  task_number: int, iteration: int) -> None:
88 |         """
89 |         Logs the loss value at each iteration.
90 |         :param loss: the loss value
91 |         :param args: the arguments of the run
92 |         :param epoch: the epoch index
93 |         :param task_number: the task index
94 |         :param iteration: the current iteration
95 |         """
96 |         for a_logger in self.loggers.values():
97 |             a_logger.add_scalar('loss', loss, task_number * args.train.num_epochs + epoch)
98 | 
99 | 
100 |     def log_penalty(self, penalty: float, args: Namespace, epoch: int,
101 |                     task_number: int, iteration: int) -> None:
102 |         """
103 |         Logs the loss penalty value at each iteration.
104 |         :param penalty: the loss penalty value
105 |         :param args: the arguments of the run
106 |         :param epoch: the epoch index
107 |         :param task_number: the task index
108 |         :param iteration: the current iteration
109 |         """
110 |         for a_logger in self.loggers.values():
111 |             a_logger.add_scalar('penalty', penalty, task_number * args.train.num_epochs + epoch)
112 | 
113 | 
114 |     def log_lr(self, lr: float, args: Namespace, epoch: int,
115 |                task_number: int, iteration: int) -> None:
116 |         """
117 |         Logs the learning rate once per epoch.
118 |         :param lr: the lr value
119 |         :param args: the arguments of the run
120 |         :param epoch: the epoch index
121 |         """
122 |         for a_logger in self.loggers.values():
123 |             a_logger.add_scalar('lr', lr, task_number * args.train.num_epochs + epoch)
124 | 
125 |     def log_images(self, images, args: Namespace, epoch: int,
126 |                    task_number: int, iteration: int) -> None:
127 |         """
128 |         Logs a grid of (synthesized) images once per epoch.
129 |         :param images: the batch of images to be logged
130 |         :param args: the arguments of the run
131 |         """
132 |         # img_grid = torchvision.utils.make_grid(images)
133 |         # matplotlib_imshow(img_grid)
134 |         images = save_img(images)
135 |         for a_logger in self.loggers.values():
136 |             a_logger.add_image('syn_images', images, task_number * args.train.num_epochs + epoch)
137 | 
138 |     def log_loss_gcl(self, loss: float, iteration: int) -> None:
139 |         """
140 |         Logs the loss value at each iteration.
141 |         :param loss: the loss value
142 |         :param iteration: the current iteration
143 |         """
144 |         for a_logger in self.loggers.values():
145 |             a_logger.add_scalar('loss', loss, iteration)
146 | 
147 |     def close(self) -> None:
148 |         """
149 |         At the end of the execution, closes the logger.
150 |         """
151 |         for a_logger in self.loggers.values():
152 |             a_logger.close()
153 | 
--------------------------------------------------------------------------------
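Note: finally, a hypothetical smoke test for `TensorboardLogger`; the nested namespaces only mimic what `main.py` would normally assemble from the yaml configs, and all names and numbers below are assumptions.

from argparse import Namespace
import numpy as np

args = Namespace(model=Namespace(backbone='resnet18'),
                 ckpt_dir='./checkpoints/demo',
                 train=Namespace(num_epochs=2))
tb = TensorboardLogger(args, setting='class-il')
all_accs = np.array([[91.0, 87.5],          # class-il accuracy per task
                     [95.0, 92.0]])         # task-il accuracy per task
tb.log_accuracy(all_accs, np.array([89.25, 93.5]), args, task_number=0)
tb.log_loss(0.42, args, epoch=1, task_number=0, iteration=10)
tb.close()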