├── models
│   ├── __init__.py
│   ├── cvae.py
│   ├── resnet_imagenet.py
│   └── resnet_aug.py
├── appendix.pdf
├── evaluations
│   ├── __init__.py
│   └── extract_featrure.py
├── appendix_new.pdf
├── utils
│   ├── HyperparamterDisplay.py
│   ├── __init__.py
│   ├── logging.py
│   ├── utils.py
│   └── osutils.py
├── scripts
│   └── run.sh
├── README.md
├── opts_eTag.py
├── eTag_eval.py
├── ImageFolder.py
├── CIFAR100.py
├── LICENSE
└── eTag_train.py

/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
--------------------------------------------------------------------------------
/appendix.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/libo-huang/eTag/HEAD/appendix.pdf
--------------------------------------------------------------------------------
/evaluations/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | from .extract_featrure import *
--------------------------------------------------------------------------------
/appendix_new.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/libo-huang/eTag/HEAD/appendix_new.pdf
--------------------------------------------------------------------------------
/utils/HyperparamterDisplay.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | 
3 | 
4 | def display(args):
5 |     # Display the configuration of the current training run
6 |     # print('Learn Rate \t%.1e' % args.lr)
7 |     print('Epochs \t%05d' % args.epochs)
8 |     print('Log Path \t%s' % args.log_dir)
9 |     print('GPU \t%s' % args.gpu)
10 |     # print('Data Set \t%s' % args.data)
11 |     # print('Batch Size \t%d' % args.BatchSize)
12 |     # print('Embedded Dimension \t%d' % args.num_class)
13 |     print('Begin to train the network')
14 |     print(50 * '-')
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | 
5 | from .HyperparamterDisplay import display
6 | from .osutils import mkdir_if_missing
7 | from .osutils import get_vector
8 | from .osutils import truncated_z_sample
9 | 
10 | 
11 | def to_numpy(tensor):
12 |     if torch.is_tensor(tensor):
13 |         return tensor.cpu().numpy()
14 |     elif type(tensor).__module__ != 'numpy':
15 |         raise ValueError("Cannot convert {} to numpy array"
16 |                          .format(type(tensor)))
17 |     return tensor
18 | 
19 | 
20 | def to_torch(ndarray):
21 |     if type(ndarray).__module__ == 'numpy':
22 |         return torch.from_numpy(ndarray)
23 |     elif not torch.is_tensor(ndarray):
24 |         raise ValueError("Cannot convert {} to torch tensor"
25 |                          .format(type(ndarray)))
26 |     return ndarray
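27 | 
28 | 
29 | # Example: to_numpy / to_torch round-trip
30 | #   t = torch.ones(2, 3)
31 | #   assert to_torch(to_numpy(t)).equal(t)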
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 | 
5 | from .osutils import mkdir_if_missing
6 | 
7 | 
8 | class Logger(object):
9 |     def __init__(self, fpath=None):
10 |         self.console = sys.stdout
11 |         self.file = None
12 |         if fpath is not None:
13 |             mkdir_if_missing(os.path.dirname(fpath))
14 |             self.file = open(fpath, 'w')
15 | 
16 |     def __del__(self):
17 |         self.close()
18 | 
19 |     def __enter__(self):
20 |         pass
21 | 
22 |     def __exit__(self, *args):
23 |         self.close()
24 | 
25 |     def write(self, msg):
26 |         self.console.write(msg)
27 |         if self.file is not None:
28 |             self.file.write(msg)
29 | 
30 |     def flush(self):
31 |         self.console.flush()
32 |         if self.file is not None:
33 |             self.file.flush()
34 |             os.fsync(self.file.fileno())
35 | 
36 |     def close(self):
37 |         # leave self.console (sys.stdout) open; only close the log file
38 |         if self.file is not None:
39 |             self.file.close()
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # bash -i run.sh
3 | # dataset: imagenet
4 | cd ..
5 | eval $(conda shell.bash hook)
6 | conda activate HedTog
7 | conda info | egrep "conda version|active environment"
8 | 
9 | if [ "$1" != "" ]; then
10 |     echo "Running on dataset: $1"
11 | else
12 |     echo "No dataset has been assigned."
13 | fi
14 | 
15 | if [ "$2" != "" ]; then
16 |     echo "Running on gpu: $2"
17 | else
18 |     echo "No gpu has been assigned."
19 | fi
20 | 
21 | if [ "$3" != "" ]; then
22 |     echo "Running with # tasks: $3"
23 | else
24 |     echo "No # task has been assigned."
25 | fi
26 | 
27 | for SEED in 0 1 2 3 4
28 | do
29 |     if [ "$1" = "imagenet" ]; then
30 |         python eTag_train.py -data imagenet_sub -log_dir ./checkpoints/imagenet -num_task $3 -nb_cl_fg 50 -gpu $2 -epochs 70 -epochs_vae 100 -tau 3 -lr_decay_step 30 -seed $SEED;
31 |         python eTag_eval.py -data imagenet_sub -log_dir ./checkpoints/imagenet -num_task $3 -nb_cl_fg 50 -gpu $2 -epochs 70 -seed $SEED;
32 |     elif [ "$1" = "cifar" ]; then
33 |         python eTag_train.py -data cifar100 -log_dir ./checkpoints/cifar -num_task $3 -nb_cl_fg 50 -gpu $2 -epochs 100 -epochs_vae 100 -tau 3 -lr_decay_step 30 -seed $SEED;
34 |         python eTag_eval.py -data cifar100 -log_dir ./checkpoints/cifar -num_task $3 -nb_cl_fg 50 -gpu $2 -epochs 100 -seed $SEED;
35 |     else
36 |         echo "No dataset has been assigned."
37 |     fi
38 | done
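39 | 
40 | # Usage: bash -i run.sh <dataset: imagenet|cifar> <gpu-id> <num-incremental-tasks>
41 | # e.g. "bash -i run.sh cifar 0 5" reproduces the 5-task CIFAR-100 runs above for seeds 0-4.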
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | import os, sys, errno, torch
4 | 
5 | 
6 | class Logger(object):
7 |     def __init__(self, fpath=None):
8 |         self.console = sys.stdout
9 |         self.file = None
10 |         if fpath is not None:
11 |             mkdir_if_missing(os.path.dirname(fpath))
12 |             self.file = open(fpath, 'w')
13 | 
14 |     def __del__(self):
15 |         self.close()
16 | 
17 |     def __enter__(self):
18 |         pass
19 | 
20 |     def __exit__(self, *args):
21 |         self.close()
22 | 
23 |     def write(self, msg):
24 |         self.console.write(msg)
25 |         if self.file is not None:
26 |             self.file.write(msg)
27 | 
28 |     def flush(self):
29 |         self.console.flush()
30 |         if self.file is not None:
31 |             self.file.flush()
32 |             os.fsync(self.file.fileno())
33 | 
34 |     def close(self):
35 |         # leave self.console (sys.stdout) open; only close the log file
36 |         if self.file is not None:
37 |             self.file.close()
38 | 
39 | def mkdir_if_missing(dir_path):
40 |     try:
41 |         os.makedirs(dir_path)
42 |     except OSError as e:
43 |         if e.errno != errno.EEXIST:
44 |             raise
45 | 
46 | 
47 | def idx2onehot(idx, n):
48 |     assert torch.max(idx).item() < n
49 | 
50 |     if idx.dim() == 1:
51 |         idx = idx.unsqueeze(1)
52 |     onehot = torch.zeros(idx.size(0), n).to(idx.device)
53 |     onehot.scatter_(1, idx, 1)
54 | 
55 |     return onehot
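56 | 
57 | # Example: idx2onehot(torch.tensor([0, 2, 1]), n=4) returns
58 | #   tensor([[1., 0., 0., 0.],
59 | #           [0., 0., 1., 0.],
60 | #           [0., 1., 0., 0.]])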
--------------------------------------------------------------------------------
/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import errno
4 | import numpy as np
5 | import torch
6 | from scipy.stats import truncnorm
7 | 
8 | 
9 | def mkdir_if_missing(dir_path):
10 |     try:
11 |         os.makedirs(dir_path)
12 |     except OSError as e:
13 |         if e.errno != errno.EEXIST:
14 |             raise
15 | 
16 | 
17 | def truncated_z_sample(batch_size, z_dim, truncation=0.5, seed=None):
18 |     # Sample z from a truncated normal (values clipped to [-2, 2]), then
19 |     # rescale by the truncation factor.
20 |     state = None if seed is None else np.random.RandomState(seed)
21 |     values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state)
22 |     return truncation * values
23 | 
24 | 
25 | def get_vector(inputs, model, layer_num):
26 |     # Capture the output of the second BatchNorm in the chosen residual stage
27 |     # via a forward hook; the tensor sizes assume a batch of 128 CIFAR images.
28 |     if layer_num == 0:
29 |         tensor_size = [128, 64, 32, 32]
30 |         layer = model._modules.get('layer1')[1]._modules.get('bn2')
31 |     elif layer_num == 1:
32 |         tensor_size = [128, 128, 16, 16]
33 |         layer = model._modules.get('layer2')[1]._modules.get('bn2')
34 |     elif layer_num == 2:
35 |         tensor_size = [128, 256, 8, 8]
36 |         layer = model._modules.get('layer3')[1]._modules.get('bn2')
37 |     elif layer_num == 3:
38 |         tensor_size = [128, 512, 4, 4]
39 |         layer = model._modules.get('layer4')[1]._modules.get('bn2')
40 |     else:
41 |         raise ValueError('layer_num must be one of 0, 1, 2, 3')
42 | 
43 |     my_embedding = torch.zeros(tensor_size)
44 | 
45 |     def copy_data(m, i, o):
46 |         # Copy the hooked layer's output into the pre-allocated buffer.
47 |         my_embedding.copy_(o.data)
48 | 
49 |     h = layer.register_forward_hook(copy_data)
50 |     model(inputs)  # one forward pass fills my_embedding
51 |     h.remove()
52 |     return my_embedding
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # eTag: Class-Incremental Learning via Hierarchical Embedding Distillation and Task-Oriented Generation
2 | 
3 | ## Introduction
4 | 
5 | This repository contains the key training and evaluation code for the AAAI-2024 paper **"eTag: Class-Incremental Learning via Hierarchical Embedding Distillation and Task-Oriented Generation"**.
6 | 
7 | ## Requirements
8 | 
9 | To run the code, ensure the following dependencies are installed:
10 | 
11 | - Python 3.8.5
12 | - PyTorch 1.7.1
13 | - torchvision 0.8.2
14 | 
15 | ## How to Run
16 | 
17 | ### Dataset Preparation
18 | Before running the code, make sure the dataset is downloaded to, or symlinked into, the `./dataset` directory.
19 | 
20 | ### Execution
21 | You can test our method by running the following commands from the `./scripts` directory:
22 | 
23 | #### CIFAR-100 Dataset
24 | ```sh
25 | # 5 tasks
26 | bash -i run.sh cifar 0 5
27 | # 10 tasks
28 | bash -i run.sh cifar 0 10
29 | # 25 tasks
30 | bash -i run.sh cifar 0 25
31 | ```
32 | 
33 | #### ImageNet Subset Dataset
34 | ```sh
35 | # 5 tasks
36 | bash -i run.sh imagenet 0 5
37 | # 10 tasks
38 | bash -i run.sh imagenet 0 10
39 | # 25 tasks
40 | bash -i run.sh imagenet 0 25
41 | ```
42 | 
43 | ### Arguments
44 | - `-data`: Dataset name. Choose from `cifar100` or `imagenet_sub`.
45 | - `-log_dir`: Directory to save models, logs, and events.
46 | - `-num_task`: Number of tasks after the initial task.
47 | - `-nb_cl_fg`: Number of classes in the first task.
48 | 
49 | For additional tunable arguments, refer to the `opts_eTag.py` file.
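50 | 
51 | ### Expanded Example
52 | 
53 | For reference, the 5-task CIFAR-100 setting in `scripts/run.sh` expands (for seed 0 on GPU 0) into the following two direct calls:
54 | 
55 | ```sh
56 | python eTag_train.py -data cifar100 -log_dir ./checkpoints/cifar -num_task 5 -nb_cl_fg 50 -gpu 0 -epochs 100 -epochs_vae 100 -tau 3 -lr_decay_step 30 -seed 0
57 | python eTag_eval.py -data cifar100 -log_dir ./checkpoints/cifar -num_task 5 -nb_cl_fg 50 -gpu 0 -epochs 100 -seed 0
58 | ```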
59 | 
60 | ## License
61 | 
62 | This project is licensed under the **Apache License 2.0**, a permissive license that requires preservation of copyright and license notices. Contributors provide an express grant of patent rights. Licensed works, modifications, and larger works may be distributed under different terms and without source code.
63 | 
64 | | Permissions | Conditions | Limitations |
65 | | ------------------- | ------------------------------- | ---------------- |
66 | | :white_check_mark: Commercial use | ⓘ License and copyright notice | :x: Trademark use |
67 | | :white_check_mark: Modification | ⓘ State changes | :x: Liability |
68 | | :white_check_mark: Distribution | | :x: Warranty |
69 | | :white_check_mark: Patent use | | |
70 | | :white_check_mark: Private use | | |
71 | 
--------------------------------------------------------------------------------
/opts_eTag.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | 
4 | import argparse
5 | import os
6 | 
7 | def parse_common_args(parser):
8 |     parser.add_argument('-data', default='cifar100', help='dataset name',
9 |                         choices=['imagenet_sub', 'cifar100', 'imagenet_full'])
10 |     parser.add_argument('-num_class', default=100, type=int, metavar='n', help='total number of classes')
11 |     parser.add_argument('-nb_cl_fg', type=int, default=50, help="number of classes in the first task")  # TODO
12 |     parser.add_argument('-num_task', type=int, default=5, help="number of tasks after the initial task")  # TODO
13 |     parser.add_argument('-log_dir', type=str, default='./checkpoints/debug', metavar='PATH',  # TODO
14 |                         help='directory to save models, logs, and events')
15 |     parser.add_argument('-dir', default='../a_Data', help='data dir')
16 |     parser.add_argument("-gpu", type=str, default='3', help='which gpu to choose')
17 |     parser.add_argument('-nThreads', '-j', default=4, type=int, metavar='N', help='number of data loading threads')
18 |     parser.add_argument('-BatchSize', '-b', default=128, type=int, metavar='N', help='mini-batch size')
19 |     parser.add_argument('-epochs', default=1, type=int, metavar='N', help='epochs for the training process')  # TODO
20 |     parser.add_argument('-epochs_vae', default=1, type=int, metavar='N', help='epochs for training the VAE')  # TODO
21 |     parser.add_argument('-seed', default=1, type=int, metavar='N', help='random seed for the training process')
22 |     return parser
23 | 
24 | def parse_train_args(parser):
25 |     parser.add_argument('-tradeoff', type=float, default=1.0, help="tradeoff parameter between feature extractor and classifier")
26 |     parser.add_argument('-lr', type=float, default=1e-3, help="learning rate of the backbone network")
27 |     parser.add_argument('-lr_decay', type=float, default=0.1, help='decay factor for the backbone learning rate')
28 |     parser.add_argument('-lr_decay_step', type=float, default=200, help='decay the backbone learning rate every x steps')
29 |     parser.add_argument('-weight_decay', type=float, default=2e-4, help='weight decay of the backbone network')
30 |     parser.add_argument('-vae_tradeoff', type=float, default=1e-3, help='tradeoff parameter for lifelong training of the VAE model')
31 |     parser.add_argument('-vae_lr', type=float, default=0.001, help="learning rate of the VAE")
32 |     parser.add_argument('-latent_dim', type=int, default=200, help="dimension of the latent variable")
33 |     parser.add_argument('-feat_dim', type=int, default=512, help="dimension of the feature")
34 |     parser.add_argument('-hidden_dim', type=int, default=512, help="dimension of the hidden linear layer")
35 |     parser.add_argument('-start', default=0, type=int, help='start training from this task index')
36 | 
37 |     parser.add_argument('-tau', default=3, type=int, help='KD temperature')
38 |     return parser
39 | 
40 | def parse_test_args(parser):
41 |     parser.add_argument('-top5', action='store_true', help='output top5 accuracy')
42 |     return parser
43 | 
44 | def get_train_args():
45 |     parser = argparse.ArgumentParser(description='eTag')
46 |     parser = parse_common_args(parser)
47 |     parser = parse_train_args(parser)
48 |     args = parser.parse_args()
49 | 
50 |     args.log_dir = os.path.join(args.log_dir, args.data+'_{}tasks_s{}_{}'.format(args.num_task, args.nb_cl_fg, args.seed))
51 | 
52 |     return args
53 | 
54 | def get_test_args():
55 |     parser = argparse.ArgumentParser(description='PyTorch Testing')
56 |     parser = parse_common_args(parser)
57 |     parser = parse_test_args(parser)
58 |     args = parser.parse_args()
59 | 
60 |     args.log_dir = os.path.join(args.log_dir, args.data + '_{}tasks_s{}_{}'.format(args.num_task, args.nb_cl_fg, args.seed))
61 |     return args
62 | 
63 | 
64 | if __name__ == '__main__':
65 |     train_args = get_train_args()
66 |     test_args = get_test_args()
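67 | 
68 | # Note: with the defaults above (data='cifar100', num_task=5, nb_cl_fg=50,
69 | # seed=1), get_train_args() resolves args.log_dir to
70 | # ./checkpoints/debug/cifar100_5tasks_s50_1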
--------------------------------------------------------------------------------
/eTag_eval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import, print_function
3 | from torch.backends import cudnn
4 | from evaluations import extract_features_classification_aug
5 | import torchvision.transforms as transforms
6 | from ImageFolder import *
7 | from utils import *
8 | from CIFAR100 import CIFAR100
9 | from utils import mkdir_if_missing
10 | from opts_eTag import get_test_args
11 | 
12 | cudnn.benchmark = True
13 | args = get_test_args()
14 | 
15 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
16 | models = []
17 | for i in os.listdir(args.log_dir):
18 |     if i.endswith("_%d_model.pkl" % (args.epochs - 1)):  # e.g. 500_model.pkl
19 |         models.append(os.path.join(args.log_dir, i))
20 | 
21 | models.sort()
22 | 
23 | if 'cifar' in args.data:
24 |     transform_test = transforms.Compose([
25 |         transforms.ToTensor(),
26 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
27 |     ])
28 |     testdir = args.dir + '/cifar100'
29 | 
30 | if args.data == 'imagenet_sub' or args.data == 'imagenet_full':
31 |     mean_values = [0.485, 0.456, 0.406]
32 |     std_values = [0.229, 0.224, 0.225]
33 |     transform_test = transforms.Compose([
34 |         transforms.CenterCrop(224),  # TODO
35 |         #transforms.RandomHorizontalFlip(),
36 |         transforms.ToTensor(),
37 |         transforms.Normalize(mean=mean_values,
38 |                              std=std_values)
39 |     ])
40 |     testdir = os.path.join(args.dir, 'ILSVRC12_256', 'val')
41 | 
42 | num_classes = args.num_class
43 | num_task = args.num_task
44 | num_class_per_task = (num_classes - args.nb_cl_fg) // num_task
45 | 
46 | np.random.seed(args.seed)
47 | torch.manual_seed(args.seed)
48 | torch.cuda.manual_seed_all(args.seed)
49 | random_perm = list(range(num_classes))
50 | 
51 | print('Test starting -->\t')
52 | acc_all = np.zeros((num_task + 3, num_task + 1), dtype='float')  # saved to csv
53 | 
54 | for task_id in range(num_task + 1):
55 |     if task_id == 0:
56 |         index = random_perm[:args.nb_cl_fg]
57 |     else:
58 |         index = random_perm[:args.nb_cl_fg + task_id * num_class_per_task]
59 |     if 'imagenet' in args.data:
60 |         testfolder = ImageFolder(testdir, transform_test, index=index)
61 |         test_loader = torch.utils.data.DataLoader(testfolder, batch_size=128, shuffle=False, num_workers=4, drop_last=False)
62 |     elif args.data == 'cifar100':
63 |         np.random.seed(args.seed)
64 |         target_transform = np.random.permutation(num_classes)
65 |         testset = CIFAR100(root=testdir, train=False, download=True, transform=transform_test, target_transform=target_transform, index=index)
66 |         test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2, drop_last=False)
67 | 
68 |     print('Test %d\t' % task_id)
69 | 
70 |     model_id = task_id
71 |     model = torch.load(models[model_id])
72 | 
73 |     val_embeddings_cl, val_labels_cl = extract_features_classification_aug(model, test_loader, print_freq=32, metric=None)
74 | 
75 |     # evaluate every seen task without knowing the task ID at test time
76 |     num_class = 0
77 |     ave = 0.0
78 |     weighted_ave = 0.0
79 |     for k in range(task_id + 1):
80 |         if k == 0:
81 |             tmp = random_perm[:args.nb_cl_fg]
82 |         else:
83 |             tmp = random_perm[args.nb_cl_fg + (k - 1) * num_class_per_task:args.nb_cl_fg + k * num_class_per_task]
84 |         gt = np.isin(val_labels_cl, tmp)
85 |         if args.top5:
86 |             estimate = np.argsort(val_embeddings_cl, axis=1)[:, -5:]
87 |             estimate_tmp = np.asarray(estimate)[gt]
88 |             labels_tmp = np.tile(val_labels_cl[gt].reshape([len(val_labels_cl[gt]), 1]), [1, 5])
89 |             acc = np.sum(estimate_tmp == labels_tmp) / float(len(estimate_tmp))
90 |         else:
91 |             estimate = np.argmax(val_embeddings_cl, axis=1)
92 |             estimate_tmp = np.asarray(estimate)[gt]
93 |             acc = np.sum(estimate_tmp == val_labels_cl[gt]) / float(len(estimate_tmp))
94 |         ave += acc
95 |         weighted_ave += acc * len(tmp)
96 |         num_class += len(tmp)
97 |         print("Accuracy of Model %d on Task %d with unknown task boundary is %.3f" % (model_id, k, acc))
98 |         acc_all[k, task_id] = acc
99 |     print('Average: %.3f Weighted Average: %.3f' % (ave / (task_id + 1), weighted_ave / num_class))
100 |     acc_all[num_task + 1, task_id] = ave / (task_id + 1)
101 |     acc_all[num_task + 2, task_id] = weighted_ave / num_class
102 | 
103 | np.savetxt(args.log_dir + '/epoch_%d.csv' % args.epochs, acc_all * 100, delimiter=',')
--------------------------------------------------------------------------------
/models/cvae.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | from torch import nn
4 | import torch, sys
5 | sys.path.append("..")
6 | 
7 | from utils.utils import idx2onehot
8 | 
9 | 
10 | class CVAE_Mnist_v0(nn.Module):
11 | 
12 |     def __init__(self, args):
13 |         super(CVAE_Mnist_v0, self).__init__()
14 |         self.feat_dim = args.feat_dim
15 |         self.latent_dim = args.latent_dim
16 |         self.hidden_dim = args.hidden_dim
17 |         self.class_dim = args.class_dim
18 | 
19 |         self.model_encoder = nn.Sequential(
20 |             nn.Linear(self.feat_dim + self.class_dim, self.hidden_dim),
21 |             nn.LeakyReLU(),
22 |             nn.Linear(self.hidden_dim, self.hidden_dim),
23 |             nn.LeakyReLU(),
24 |         )
25 |         self.fc_mu = nn.Linear(self.hidden_dim, self.latent_dim)
26 |         self.fc_var = nn.Linear(self.hidden_dim, self.latent_dim)
27 |         self.model_decoder = nn.Sequential(
28 |             nn.Linear(self.latent_dim + self.class_dim, self.hidden_dim),
29 |             nn.LeakyReLU(),
30 |             nn.Linear(self.hidden_dim, self.hidden_dim),
31 |             nn.LeakyReLU(),
32 |             nn.Linear(self.hidden_dim, self.feat_dim),
33 |         )
34 |         self.sigmoid = nn.LeakyReLU()  # MNIST pixel values lie in [0, 1], so Sigmoid() or Tanh() is recommended here
35 | 
36 |     def encode(self, x, y):
37 |         x = torch.cat((x, y), 1)
38 |         x = self.model_encoder(x)
39 |         mu = self.fc_mu(x)
40 |         logvar = self.fc_var(x)
41 |         return mu, logvar
42 | 
43 |     def decode(self, z, y):
44 |         z = torch.cat((z, y), 1)
45 |         logit = self.model_decoder(z)
46 |         feat = self.sigmoid(logit)
47 |         return feat
48 | 
49 |     def reparameterize(self, mu, logvar):
50 |         std = logvar.mul(0.5).exp_()
51 |         eps = torch.FloatTensor(std.size()).normal_().to(std.device)
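52 |         # reparameterization trick: z = mu + std * eps keeps the sample differentiable w.r.t. mu and logvar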
53 |         z = eps.mul(std).add_(mu)
54 |         return z
55 | 
56 |     def forward(self, x, y):
57 |         x = x.reshape(x.shape[0], -1)
58 |         y = torch.eye(self.class_dim)[y].to(x.device)
59 |         mu, logvar = self.encode(x, y)
60 |         z = self.reparameterize(mu, logvar)
61 |         x_rec = self.decode(z, y).reshape(x.shape[0], 1, 32, 32)
62 | 
63 |         return x_rec, mu, logvar
64 | 
65 | # test
66 | # x = torch.randn(2, 1, 32, 32)
67 | # y = torch.randint(0, 9, [2, ])
68 | #
69 | # class arg(object):
70 | #     def __init__(self):
71 | #         self.feat_dim = 32 * 32
72 | #         self.latent_dim = 2
73 | #         self.hidden_dim = 100
74 | #         self.class_dim = 10
75 | #
76 | # args = arg()
77 | # model = CVAE_Mnist_v0(args)
78 | # x_rec, _, _ = model(x, y)
79 | 
80 | 
81 | class CVAE_Cifar(nn.Module):
82 | 
83 |     def __init__(self, args):
84 |         super().__init__()
85 |         self.encoder_layer_sizes = args.encoder_layer_sizes
86 |         self.latent_size = args.latent_size
87 |         self.decoder_layer_sizes = args.decoder_layer_sizes
88 |         self.num_labels = args.class_dim
89 | 
90 |         self.encoder = Encoder(self.encoder_layer_sizes, self.latent_size, self.num_labels)
91 |         self.decoder = Decoder(self.decoder_layer_sizes, self.latent_size, self.num_labels)
92 | 
93 |     def forward(self, x, c=None):
94 |         if x.dim() > 2:
95 |             x = x.view(-1, 32*32)
96 | 
97 |         means, log_var = self.encoder(x, c)
98 |         z = self.reparameterize(means, log_var)
99 |         recon_x = self.decoder(z, c)
100 | 
101 |         return recon_x, means, log_var, z
102 | 
103 |     def reparameterize(self, mu, log_var):
104 |         std = torch.exp(0.5 * log_var)
105 |         eps = torch.randn_like(std)
106 |         return mu + eps * std
107 | 
108 |     def inference(self, z, c=None):
109 |         recon_x = self.decoder(z, c)
110 |         return recon_x
111 | 
112 | 
113 | class Encoder(nn.Module):
114 | 
115 |     def __init__(self, layer_sizes, latent_size, num_labels):
116 |         super().__init__()
117 | 
118 |         self.num_labels = num_labels
119 |         self.MLP = nn.Sequential()
120 | 
121 |         for i, (in_size, out_size) in enumerate(zip([layer_sizes[0] + num_labels] + layer_sizes[1:-1], layer_sizes[1:])):
122 |             self.MLP.add_module(
123 |                 name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
124 |             self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
125 | 
126 |         self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
127 |         self.linear_log_var = nn.Linear(layer_sizes[-1], latent_size)
128 | 
129 |     def forward(self, x, c=None):
130 |         c = idx2onehot(c, n=self.num_labels)
131 |         x = torch.cat((x, c), dim=-1)
132 | 
133 |         x = self.MLP(x)
134 | 
135 |         means = self.linear_means(x)
136 |         log_vars = self.linear_log_var(x)
137 | 
138 |         return means, log_vars
139 | 
140 | 
141 | class Decoder(nn.Module):
142 | 
143 |     def __init__(self, layer_sizes, latent_size, num_labels):
144 |         super().__init__()
145 |         self.num_labels = num_labels
146 |         self.MLP = nn.Sequential()
147 | 
148 |         input_size = latent_size + num_labels
149 | 
150 |         for i, (in_size, out_size) in enumerate(zip([input_size] + layer_sizes[:-1], layer_sizes)):
151 |             self.MLP.add_module(
152 |                 name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
153 |             if i + 1 < len(layer_sizes):
154 |                 self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
155 |             else:
156 |                 # self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
157 |                 self.MLP.add_module(name="sigmoid", module=nn.ReLU())
158 | 
159 |     def forward(self, z, c):
160 |         c = idx2onehot(c, n=self.num_labels)
161 |         z = torch.cat((z, c), dim=-1)
162 | 
163 |         x = self.MLP(z)
164 |         # x = x.view(x.size(0), 1, 32, 32)
165 |         return x
166 | 
167 | 
168 | if __name__ == '__main__':
169 |     x = torch.randn(2, 512)
170 |     y = torch.randint(0, 9, [2, ])
171 | 
172 |     class arg(object):
173 |         def __init__(self):
174 |             self.encoder_layer_sizes = [512, 512, 512]
175 |             self.decoder_layer_sizes = [512, 512, 512]
176 |             self.latent_size = 200
177 |             self.class_dim = 100
178 |     args = arg()
179 |     model = CVAE_Cifar(args)
180 |     x_rec, mean, log_var, z = model(x, y)
181 |     print('end')
--------------------------------------------------------------------------------
/evaluations/extract_featrure.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | 
3 | import numpy as np
4 | import torch
5 | 
6 | from utils import to_numpy
7 | 
8 | 
9 | def extract_features(model, data_loader, print_freq=1, metric=None):
10 |     # Stack the multi-level features returned by model(..., is_feat=True)
11 |     # over the whole loader.
12 |     model = model.cuda()
13 |     model.eval()
14 |     feats = []
15 |     labels = []
16 |     for i, data in enumerate(data_loader, 0):
17 |         imgs, label = data
18 |         inputs = imgs.cuda()
19 |         with torch.no_grad():
20 |             feat = model(inputs, is_feat=True)
21 |         for j in range(len(feat)):
22 |             feat[j] = feat[j].cpu().numpy()
23 |         if len(feats) == 0:
24 |             feats = feat
25 |             labels = label
26 |         else:
27 |             for j in range(len(feat)):
28 |                 feats[j] = np.vstack((feats[j], feat[j]))
29 |             labels = np.hstack((labels, label))
30 |     return feats, labels
31 | 
32 | 
33 | def extract_features_classification(model, data_loader, print_freq=1, metric=None):
34 |     # Per-sample logits: a forward pass through the backbone followed by the
35 |     # embedding (classification) head.
36 |     model = model.cuda()
37 |     model.eval()
38 |     features = []
39 |     labels = []
40 |     for i, data in enumerate(data_loader, 0):
41 |         imgs, pids = data
42 |         inputs = imgs.cuda()
43 |         with torch.no_grad():
44 |             outputs = model(inputs)
45 |             outputs = model.embed(outputs)
46 |         outputs = outputs.cpu().numpy()
47 |         if len(features) == 0:
48 |             features = outputs
49 |             labels = pids
50 |         else:
51 |             features = np.vstack((features, outputs))
52 |             labels = np.hstack((labels, pids))
53 |     return features, labels
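54 | 
55 | # Example (sketch): pulling logits for a saved task model, with `model` and
56 | # `test_loader` built as in eTag_eval.py:
57 | #   feats, labels = extract_features_classification(model, test_loader)
58 | #   print(feats.shape, labels.shape)  # (num_samples, num_classes), (num_samples,)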
59 | 
60 | 
61 | def extract_features_logits(model, data_loader, print_freq=1, metric=None):
62 |     # Identical to extract_features_classification; kept under this name for
63 |     # call-site compatibility.
64 |     model = model.cuda()
65 |     model.eval()
66 |     features = []
67 |     labels = []
68 |     for i, data in enumerate(data_loader, 0):
69 |         imgs, pids = data
70 |         inputs = imgs.cuda()
71 |         with torch.no_grad():
72 |             outputs = model(inputs)
73 |             outputs = model.embed(outputs)
74 |         outputs = outputs.cpu().numpy()
75 |         if len(features) == 0:
76 |             features = outputs
77 |             labels = pids
78 |         else:
79 |             features = np.vstack((features, outputs))
80 |             labels = np.hstack((labels, pids))
81 |     return features, labels
82 | 
83 | 
84 | def pairwise_distance(features, metric=None):
85 |     n = len(features)
86 |     x = torch.cat(features)
87 |     x = x.view(n, -1)
88 |     if metric is not None:
89 |         x = metric.transform(x)
90 |     dist = torch.pow(x, 2).sum(dim=1, keepdim=True)
91 |     dist = dist.expand(n, n)
92 |     dist = dist + dist.t()
93 |     # the large diagonal term keeps self-distances out of nearest-neighbour search
94 |     dist = dist - 2 * torch.mm(x, x.t()) + 1e5 * torch.eye(n)
95 |     dist = torch.sqrt(dist)
96 |     return dist
97 | 
98 | 
99 | def pairwise_similarity(features):
100 |     n = len(features)
101 |     x = torch.cat(features)
102 |     x = x.view(n, -1)
103 |     similarity = torch.mm(x, x.t()) - 1e5 * torch.eye(n)
104 |     return similarity
105 | 
106 | # Example:
107 | # features = torch.round(2*torch.rand(4, 2))
108 | # distmat = pairwise_similarity(features)
109 | # distmat = to_numpy(distmat)
110 | # indices = np.argsort(distmat, axis=1)
111 | # print(distmat)
112 | # print(indices)
113 | 
114 | 
115 | def extract_features_classification_aug(model, data_loader, print_freq=1, metric=None):
116 |     # Same as extract_features_classification, but for the augmented model:
117 |     # its forward returns a (features, aux) pair and its embedding head lives
118 |     # on model.backbone.
119 |     model = model.cuda()
120 |     model.eval()
121 |     features = []
122 |     labels = []
123 |     for i, data in enumerate(data_loader, 0):
124 |         imgs, pids = data
125 |         inputs = imgs.cuda()
126 |         with torch.no_grad():
127 |             outputs, _ = model(inputs)
128 |             outputs = model.backbone.embed(outputs)
129 |         outputs = outputs.cpu().numpy()
130 |         if len(features) == 0:
131 |             features = outputs
132 |             labels = pids
133 |         else:
134 |             features = np.vstack((features, outputs))
135 |             labels = np.hstack((labels, pids))
136 |     return features, labels
--------------------------------------------------------------------------------
/ImageFolder.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | 
3 | from PIL import Image
4 | 
5 | import os
6 | import os.path
7 | import sys
8 | import pdb
9 | import
random 10 | import numpy as np 11 | def has_file_allowed_extension(filename, extensions): 12 | """Checks if a file is an allowed extension. 13 | 14 | Args: 15 | filename (string): path to a file 16 | extensions (iterable of strings): extensions to consider (lowercase) 17 | 18 | Returns: 19 | bool: True if the filename ends with one of given extensions 20 | """ 21 | filename_lower = filename.lower() 22 | return any(filename_lower.endswith(ext) for ext in extensions) 23 | 24 | 25 | def is_image_file(filename): 26 | """Checks if a file is an allowed image extension. 27 | 28 | Args: 29 | filename (string): path to a file 30 | 31 | Returns: 32 | bool: True if the filename ends with a known image extension 33 | """ 34 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 35 | 36 | 37 | def make_dataset(dir, class_to_idx, extensions,num_instance_per_class): 38 | images = [] 39 | dir = os.path.expanduser(dir) 40 | for target in sorted(class_to_idx.keys()): 41 | d = os.path.join(dir, target) 42 | if not os.path.isdir(d): 43 | continue 44 | 45 | for root, _, fnames in sorted(os.walk(d)): 46 | if num_instance_per_class==0: 47 | num = len(fnames) 48 | else: 49 | num = min(num_instance_per_class,len(fnames)) 50 | for fname in sorted(fnames)[:num]: 51 | if has_file_allowed_extension(fname, extensions): 52 | path = os.path.join(root, fname) 53 | item = (path, class_to_idx[target]) 54 | images.append(item) 55 | 56 | return images 57 | 58 | 59 | class DatasetFolder(data.Dataset): 60 | """A generic data loader where the samples are arranged in this way: :: 61 | 62 | root/class_x/xxx.ext 63 | root/class_x/xxy.ext 64 | root/class_x/xxz.ext 65 | 66 | root/class_y/123.ext 67 | root/class_y/nsdf3.ext 68 | root/class_y/asd932_.ext 69 | 70 | Args: 71 | root (string): Root directory path. 72 | loader (callable): A function to load a sample given its path. 73 | extensions (list[string]): A list of allowed extensions. 74 | transform (callable, optional): A function/transform that takes in 75 | a sample and returns a transformed version. 76 | E.g, ``transforms.RandomCrop`` for images. 77 | target_transform (callable, optional): A function/transform that takes 78 | in the target and transforms it. 79 | 80 | Attributes: 81 | classes (list): List of the class names. 82 | class_to_idx (dict): Dict with items (class_name, class_index). 
83 | samples (list): List of (sample path, class_index) tuples 84 | targets (list): The class_index value for each image in the dataset 85 | """ 86 | 87 | def __init__(self, root, loader, extensions, transform=None, target_transform=None, index=None,num_instance_per_class=0): 88 | classes, class_to_idx = self._find_classes(root) 89 | #np.random.seed(1) 90 | #class_to_idx = {k: v for k, v in random.choice(list(class_to_idx.items())) if v in index} 91 | #pdb.set_trace() 92 | np.random.seed(1993) 93 | list_permutaion = np.random.permutation( len(class_to_idx.items())) 94 | class_to_idx = {k: list_permutaion[v] for k, v in class_to_idx.items() if list_permutaion[v] in index} 95 | samples = make_dataset(root, class_to_idx, extensions,num_instance_per_class) 96 | # if index is not None: 97 | # samples = [samples[i] for i in index] 98 | if len(samples) == 0: 99 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" 100 | "Supported extensions are: " + ",".join( 101 | extensions))) 102 | 103 | self.root = root 104 | self.loader = loader 105 | self.extensions = extensions 106 | 107 | self.classes = classes 108 | self.class_to_idx = class_to_idx 109 | self.samples = samples 110 | self.targets = [s[1] for s in samples] 111 | 112 | self.transform = transform 113 | self.target_transform = target_transform 114 | 115 | def _find_classes(self, dir): 116 | """ 117 | Finds the class folders in a dataset. 118 | 119 | Args: 120 | dir (string): Root directory path. 121 | 122 | Returns: 123 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 124 | 125 | Ensures: 126 | No class is a subdirectory of another. 127 | """ 128 | if sys.version_info >= (3, 5): 129 | # Faster and available in Python 3.5 and above 130 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 131 | else: 132 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 133 | classes.sort() 134 | class_to_idx = {classes[i]: i for i in range(len(classes))} 135 | return classes, class_to_idx 136 | 137 | def __getitem__(self, index): 138 | """ 139 | Args: 140 | index (int): Index 141 | 142 | Returns: 143 | tuple: (sample, target) where target is class_index of the target class. 
144 |         """
145 |         path, target = self.samples[index]
146 |         sample = self.loader(path)
147 | 
148 |         if self.transform is not None:
149 |             sample = self.transform(sample)
150 |         if self.target_transform is not None:
151 |             target = self.target_transform(target)
152 | 
153 |         return sample, target
154 | 
155 |     def __len__(self):
156 |         return len(self.samples)
157 | 
158 |     def __repr__(self):
159 |         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
160 |         fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
161 |         fmt_str += '    Root Location: {}\n'.format(self.root)
162 |         tmp = '    Transforms (if any): '
163 |         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
164 |         tmp = '    Target Transforms (if any): '
165 |         fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
166 |         return fmt_str
167 | 
168 | 
169 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
170 | 
171 | 
172 | def pil_loader(path):
173 |     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
174 |     with open(path, 'rb') as f:
175 |         img = Image.open(f)
176 |         return img.convert('RGB')
177 | 
178 | 
179 | def accimage_loader(path):
180 |     import accimage
181 |     try:
182 |         return accimage.Image(path)
183 |     except IOError:
184 |         # Potentially a decoding problem, fall back to PIL.Image
185 |         return pil_loader(path)
186 | 
187 | 
188 | def default_loader(path):
189 |     from torchvision import get_image_backend
190 |     if get_image_backend() == 'accimage':
191 |         return accimage_loader(path)
192 |     else:
193 |         return pil_loader(path)
194 | 
195 | 
196 | class ImageFolder(DatasetFolder):
197 |     """A generic data loader where the images are arranged in this way: ::
198 | 
199 |         root/dog/xxx.png
200 |         root/dog/xxy.png
201 |         root/dog/xxz.png
202 | 
203 |         root/cat/123.png
204 |         root/cat/nsdf3.png
205 |         root/cat/asd932_.png
206 | 
207 |     Args:
208 |         root (string): Root directory path.
209 |         transform (callable, optional): A function/transform that takes in an PIL image
210 |             and returns a transformed version. E.g, ``transforms.RandomCrop``
211 |         target_transform (callable, optional): A function/transform that takes in the
212 |             target and transforms it.
213 |         loader (callable, optional): A function to load an image given its path.
214 | 
215 |     Attributes:
216 |         classes (list): List of the class names.
217 |         class_to_idx (dict): Dict with items (class_name, class_index).
218 |         imgs (list): List of (image path, class_index) tuples
219 |     """
220 | 
221 |     def __init__(self, root, transform=None, target_transform=None,
222 |                  loader=default_loader, index=None, num_instance_per_class=0):
223 |         super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
224 |                                           transform=transform,
225 |                                           index=index,
226 |                                           target_transform=target_transform,
227 |                                           num_instance_per_class=num_instance_per_class)
228 |         self.imgs = self.samples
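229 | 
230 | # Example (sketch): ImageNet-subset test loader for the first task, mirroring
231 | # eTag_eval.py (images are expected under <args.dir>/ILSVRC12_256/val):
232 | #   testfolder = ImageFolder(testdir, transform_test, index=list(range(50)))
233 | #   test_loader = torch.utils.data.DataLoader(testfolder, batch_size=128,
234 | #                                             shuffle=False, num_workers=4)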
--------------------------------------------------------------------------------
/CIFAR100.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import numpy as np
6 | import sys
7 | if sys.version_info[0] == 2:
8 |     import cPickle as pickle
9 | else:
10 |     import pickle
11 | 
12 | import torch.utils.data as data
13 | #from torch.utils import download_url, check_integrity
14 | 
15 | import pdb
16 | 
17 | class CIFAR10(data.Dataset):
18 |     """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
19 |     Args:
20 |         root (string): Root directory of dataset where directory
21 |             ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
22 |         train (bool, optional): If True, creates dataset from training set, otherwise
23 |             creates from test set.
24 |         transform (callable, optional): A function/transform that takes in an PIL image
25 |             and returns a transformed version. E.g, ``transforms.RandomCrop``
26 |         target_transform (callable, optional): A function/transform that takes in the
27 |             target and transforms it.
28 |         download (bool, optional): If true, downloads the dataset from the internet and
29 |             puts it in root directory. If dataset is already downloaded, it is not
30 |             downloaded again.
31 |     """
32 |     base_folder = 'cifar-10-batches-py'
33 |     url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
34 |     filename = "cifar-10-python.tar.gz"
35 |     tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
36 |     train_list = [
37 |         ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
38 |         ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
39 |         ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
40 |         ['data_batch_4', '634d18415352ddfa80567beed471001a'],
41 |         ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
42 |     ]
43 | 
44 |     test_list = [
45 |         ['test_batch', '40351d587109b95175f43aff81a1287e'],
46 |     ]
47 |     meta = {
48 |         'filename': 'batches.meta',
49 |         'key': 'label_names',
50 |         'md5': '5ff9c542aee3614f3951f8cda6e48888',
51 |     }
52 | 
53 |     def __init__(self, root, train=True,
54 |                  transform=None, target_transform=None,
55 |                  download=False, index=None, num_instance_per_class=0):
56 |         self.root = os.path.expanduser(root)
57 |         self.transform = transform
58 |         self.target_transform = target_transform
59 |         self.train = train  # training set or test set
60 | 
61 |         # if download:
62 |         #     self.download()
63 | 
64 |         # if not self._check_integrity():
65 |         #     raise RuntimeError('Dataset not found or corrupted.' +
66 |         #                        ' You can use download=True to download it')
67 | 
68 |         if self.train:
69 |             downloaded_list = self.train_list
70 |         else:
71 |             downloaded_list = self.test_list
72 | 
73 |         self.data = []
74 |         self.targets = []
75 | 
76 |         # now load the pickled numpy arrays
77 |         for file_name, checksum in downloaded_list:
78 |             file_path = os.path.join(self.root, self.base_folder, file_name)
79 |             with open(file_path, 'rb') as f:
80 |                 if sys.version_info[0] == 2:
81 |                     entry = pickle.load(f)
82 |                 else:
83 |                     entry = pickle.load(f, encoding='latin1')
84 |                 self.data.append(entry['data'])
85 |                 if 'labels' in entry:
86 |                     self.targets.extend(entry['labels'])
87 |                 else:
88 |                     self.targets.extend(entry['fine_labels'])
89 | 
90 |         self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
91 |         self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
92 |         #self.data = self.data/255.
93 |         #pdb.set_trace()
94 |         self.targets = np.asarray(self.targets)
95 |         #index_sort = np.argsort(self.targets)
96 |         # Sort label and corresponding data from 0-9
97 |         #self.data = self.data[index_sort]
98 |         #self.targets=np.asarray(sorted(self.targets))
99 | 
100 |         # re-index the labels with the given permutation (samples of the same class stay
101 |         # together); the same random seed is used at training time, so these labels stay
102 |         # aligned with the ones used for training
103 |         self.targets = target_transform[self.targets]
104 | 
105 |         if num_instance_per_class == 0:
106 |             self.data, self.targets = self.RandomPercentage(self.data, self.targets, index)
107 |         else:
108 |             self.data, self.targets = self.RandomExemplars(self.data, self.targets, index, num_instance_per_class)
109 | 
110 |         self._load_meta()
111 | 
112 |     def RandomPercentage(self, data, targets, index):
113 |         data_tmp = []
114 |         targets_tmp = []
115 |         for i in index:
116 |             ind_cl = np.where(i == targets)[0]
117 |             if len(data_tmp) == 0:
118 |                 data_tmp = data[ind_cl]
119 |                 targets_tmp = targets[ind_cl]
120 |             else:
121 |                 data_tmp = np.vstack((data_tmp, data[ind_cl]))
122 |                 targets_tmp = np.hstack((targets_tmp, targets[ind_cl]))
123 | 
124 |         return data_tmp, targets_tmp
125 | 
126 |     def RandomExemplars(self, data, targets, index, num):
127 |         data_tmp = []
128 |         targets_tmp = []
129 |         for i in index:
130 |             ind_cl = np.where(i == targets)[0][:num]
131 |             if len(data_tmp) == 0:
132 |                 data_tmp = data[ind_cl]
133 |                 targets_tmp = targets[ind_cl]
134 |             else:
135 |                 data_tmp = np.vstack((data_tmp, data[ind_cl]))
136 |                 targets_tmp = np.hstack((targets_tmp, targets[ind_cl]))
137 | 
138 |         return data_tmp, targets_tmp
139 | 
140 |     def _load_meta(self):
141 |         path = os.path.join(self.root, self.base_folder, self.meta['filename'])
142 |         # if not check_integrity(path, self.meta['md5']):
143 |         #     raise RuntimeError('Dataset metadata file not found or corrupted.' +
144 |         #                        ' You can use download=True to download it')
145 |         with open(path, 'rb') as infile:
146 |             if sys.version_info[0] == 2:
147 |                 data = pickle.load(infile)
148 |             else:
149 |                 data = pickle.load(infile, encoding='latin1')
150 |         self.classes = data[self.meta['key']]
151 |         self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
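152 | 
153 |     # Example (sketch): test split over the first 50 permuted classes, as built
154 |     # in eTag_eval.py (the permutation seed must match training; the path
155 |     # assumes the README's ./dataset layout):
156 |     #   target_transform = np.random.permutation(100)
157 |     #   testset = CIFAR100(root='./dataset/cifar100', train=False,
158 |     #                      transform=transform_test,
159 |     #                      target_transform=target_transform, index=list(range(50)))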
160 | 
161 |     def __getitem__(self, index):
162 |         """
163 |         Args:
164 |             index (int): Index
165 |         Returns:
166 |             tuple: (image, target) where target is index of the target class.
167 |         """
168 |         img, target = self.data[index], self.targets[index]
169 | 
170 |         # doing this so that it is consistent with all other datasets
171 |         # to return a PIL Image
172 |         img = Image.fromarray(img)
173 | 
174 |         if self.transform is not None:
175 |             img = self.transform(img)
176 | 
177 |         # if self.target_transform is not None:
178 |         #     target = self.target_transform(target)
179 | 
180 |         return img, target
181 | 
182 |     def __len__(self):
183 |         return len(self.data)
184 | 
185 |     # def _check_integrity(self):
186 |     #     root = self.root
187 |     #     for fentry in (self.train_list + self.test_list):
188 |     #         filename, md5 = fentry[0], fentry[1]
189 |     #         fpath = os.path.join(root, self.base_folder, filename)
190 |     #         if not check_integrity(fpath, md5):
191 |     #             return False
192 |     #     return True
193 | 
194 |     # def download(self):
195 |     #     import tarfile
196 | 
197 |     #     # if self._check_integrity():
198 |     #     #     print('Files already downloaded and verified')
199 |     #     #     return
200 | 
201 |     #     download_url(self.url, self.root, self.filename, self.tgz_md5)
202 | 
203 |     #     # extract file
204 |     #     with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
205 |     #         tar.extractall(path=self.root)
206 | 
207 |     def __repr__(self):
208 |         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
209 |         fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
210 |         tmp = 'train' if self.train is True else 'test'
211 |         fmt_str += '    Split: {}\n'.format(tmp)
212 |         fmt_str += '    Root Location: {}\n'.format(self.root)
213 |         tmp = '    Transforms (if any): '
214 |         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
215 |         tmp = '    Target Transforms (if any): '
216 |         fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
217 |         return fmt_str
218 | 
219 | class CIFAR100(CIFAR10):
220 |     """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
221 |     This is a subclass of the `CIFAR10` Dataset.
222 |     """
223 |     base_folder = 'cifar-100-python'
224 |     url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
225 |     filename = "cifar-100-python.tar.gz"
226 |     tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
227 |     train_list = [
228 |         ['train', '16019d7e3df5f24257cddd939b257f8d'],
229 |     ]
230 | 
231 |     test_list = [
232 |         ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
233 |     ]
234 |     meta = {
235 |         'filename': 'meta',
236 |         'key': 'fine_label_names',
237 |         'md5': '7973b15100ade9c7d40fb424638fde48',
238 |     }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity.
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /models/resnet_imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ['resnet18_imagenet', 'resnet18_imagenet_aux', 'resnet34_imagenet', 5 | 'resnet34_imagenet_aux', 'resnet50_imagenet', 'resnet50_imagenet_aux'] 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=dilation, groups=groups, bias=False, dilation=dilation) 12 | 13 | 14 | def conv1x1(in_planes, out_planes, stride=1): 15 | """1x1 convolution""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 23 | base_width=64, dilation=1, norm_layer=None): 24 | super(BasicBlock, self).__init__() 25 | if norm_layer is None: 26 | norm_layer = nn.BatchNorm2d 27 | if groups != 1 or base_width != 64: 28 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 29 | if dilation > 1: 30 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 31 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = norm_layer(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = norm_layer(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | identity = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 61 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 62 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 63 | # This variant is also known as ResNet V1.5 and improves accuracy according to 64 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
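# In both cases the block's output width is planes * expansion (4x for Bottleneck), which the classifier heads below account for via 512 * block.expansion.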
65 | 66 | expansion = 4 67 | 68 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 69 | base_width=64, dilation=1, norm_layer=None): 70 | super(Bottleneck, self).__init__() 71 | if norm_layer is None: 72 | norm_layer = nn.BatchNorm2d 73 | width = int(planes * (base_width / 64.)) * groups 74 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 75 | self.conv1 = conv1x1(inplanes, width) 76 | self.bn1 = norm_layer(width) 77 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 78 | self.bn2 = norm_layer(width) 79 | self.conv3 = conv1x1(width, planes * self.expansion) 80 | self.bn3 = norm_layer(planes * self.expansion) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.downsample = downsample 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | identity = x 87 | 88 | out = self.conv1(x) 89 | out = self.bn1(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv2(out) 93 | out = self.bn2(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv3(out) 97 | out = self.bn3(out) 98 | 99 | if self.downsample is not None: 100 | identity = self.downsample(x) 101 | 102 | out += identity 103 | out = self.relu(out) 104 | 105 | return out 106 | 107 | 108 | class ResNet(nn.Module): 109 | 110 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 111 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 112 | norm_layer=None): 113 | super(ResNet, self).__init__() 114 | if norm_layer is None: 115 | norm_layer = nn.BatchNorm2d 116 | self._norm_layer = norm_layer 117 | 118 | self.inplanes = 64 119 | self.dilation = 1 120 | if replace_stride_with_dilation is None: 121 | # each element in the tuple indicates if we should replace 122 | # the 2x2 stride with a dilated convolution instead 123 | replace_stride_with_dilation = [False, False, False] 124 | if len(replace_stride_with_dilation) != 3: 125 | raise ValueError("replace_stride_with_dilation should be None " 126 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 127 | self.groups = groups 128 | self.base_width = width_per_group 129 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 130 | bias=False) 131 | self.bn1 = norm_layer(self.inplanes) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 134 | self.layer1 = self._make_layer(block, 64, layers[0]) 135 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 136 | dilate=replace_stride_with_dilation[0]) 137 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 138 | dilate=replace_stride_with_dilation[1]) 139 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 140 | dilate=replace_stride_with_dilation[2]) 141 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 142 | self.avgpool = nn.AvgPool2d(7, stride=1) 143 | self.embed = nn.Linear(512 * block.expansion, num_classes) 144 | 145 | for m in self.modules(): 146 | if isinstance(m, nn.Conv2d): 147 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 148 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 149 | nn.init.constant_(m.weight, 1) 150 | nn.init.constant_(m.bias, 0) 151 | 152 | # Zero-initialize the last BN in each residual branch, 153 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
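# (concretely, the loop below zeroes bn3.weight in every Bottleneck and bn2.weight in every BasicBlock, so each block initially computes out = relu(identity))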
154 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 155 | if zero_init_residual: 156 | for m in self.modules(): 157 | if isinstance(m, Bottleneck): 158 | nn.init.constant_(m.bn3.weight, 0) 159 | elif isinstance(m, BasicBlock): 160 | nn.init.constant_(m.bn2.weight, 0) 161 | 162 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 163 | norm_layer = self._norm_layer 164 | downsample = None 165 | previous_dilation = self.dilation 166 | if dilate: 167 | self.dilation *= stride 168 | stride = 1 169 | if stride != 1 or self.inplanes != planes * block.expansion: 170 | downsample = nn.Sequential( 171 | conv1x1(self.inplanes, planes * block.expansion, stride), 172 | norm_layer(planes * block.expansion), 173 | ) 174 | 175 | layers = [] 176 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 177 | self.base_width, previous_dilation, norm_layer)) 178 | self.inplanes = planes * block.expansion 179 | for _ in range(1, blocks): 180 | layers.append(block(self.inplanes, planes, groups=self.groups, 181 | base_width=self.base_width, dilation=self.dilation, 182 | norm_layer=norm_layer)) 183 | 184 | return nn.Sequential(*layers) 185 | 186 | def forward(self, x, is_feat=False): 187 | # See note [TorchScript super()] 188 | x = self.conv1(x) 189 | x = self.bn1(x) 190 | x = self.relu(x) 191 | x = self.maxpool(x) 192 | 193 | x = self.layer1(x) 194 | f1 = x 195 | x = self.layer2(x) 196 | f2 = x 197 | x = self.layer3(x) 198 | f3 = x 199 | x = self.layer4(x) 200 | f4 = x 201 | 202 | x = self.avgpool(x) 203 | x = torch.flatten(x, 1) 204 | # x = self.fc(x) 205 | # x = x.view(x.size(0), -1) 206 | 207 | if is_feat: 208 | return [f1, f2, f3, f4], x 209 | else: 210 | return x 211 | 212 | 213 | class Auxiliary_Classifier(nn.Module): 214 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 215 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 216 | norm_layer=None): 217 | super(Auxiliary_Classifier, self).__init__() 218 | 219 | self.dilation = 1 220 | self.groups = groups 221 | self.base_width = width_per_group 222 | self.inplanes = 64 * block.expansion 223 | self.block_extractor1 = nn.Sequential(*[self._make_layer(block, 128, layers[1], stride=2), 224 | self._make_layer(block, 256, layers[2], stride=2), 225 | self._make_layer(block, 512, layers[3], stride=2)]) 226 | 227 | self.inplanes = 128 * block.expansion 228 | self.block_extractor2 = nn.Sequential(*[self._make_layer(block, 256, layers[2], stride=2), 229 | self._make_layer(block, 512, layers[3], stride=2)]) 230 | 231 | self.inplanes = 256 * block.expansion 232 | self.block_extractor3 = nn.Sequential(*[self._make_layer(block, 512, layers[3], stride=2)]) 233 | 234 | self.inplanes = 512 * block.expansion 235 | self.block_extractor4 = nn.Sequential(*[self._make_layer(block, 512, layers[3], stride=1)]) 236 | 237 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 238 | self.fc1 = nn.Linear(512 * block.expansion, num_classes) 239 | self.fc2 = nn.Linear(512 * block.expansion, num_classes) 240 | self.fc3 = nn.Linear(512 * block.expansion, num_classes) 241 | self.fc4 = nn.Linear(512 * block.expansion, num_classes) 242 | 243 | for m in self.modules(): 244 | if isinstance(m, nn.Conv2d): 245 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 246 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 247 | nn.init.constant_(m.weight, 1) 248 | nn.init.constant_(m.bias, 0) 249 | 250 | def _make_layer(self, block, planes, blocks, stride=1, 
dilate=False): 251 | norm_layer = nn.BatchNorm2d 252 | downsample = None 253 | previous_dilation = self.dilation 254 | if dilate: 255 | self.dilation *= stride 256 | stride = 1 257 | if stride != 1 or self.inplanes != planes * block.expansion: 258 | downsample = nn.Sequential( 259 | conv1x1(self.inplanes, planes * block.expansion, stride), 260 | norm_layer(planes * block.expansion), 261 | ) 262 | 263 | layers = [] 264 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 265 | self.base_width, previous_dilation, norm_layer)) 266 | self.inplanes = planes * block.expansion 267 | for _ in range(1, blocks): 268 | layers.append(block(self.inplanes, planes, groups=self.groups, 269 | base_width=self.base_width, dilation=self.dilation, 270 | norm_layer=norm_layer)) 271 | 272 | return nn.Sequential(*layers) 273 | 274 | def forward(self, x): # x: list of block-wise feature maps; returns one auxiliary logit tensor per block 275 | ss_logits = [] 276 | for i in range(len(x)): 277 | idx = i + 1 278 | 279 | out = getattr(self, 'block_extractor' + str(idx))(x[i]) 280 | out = self.avg_pool(out) 281 | out = out.view(out.size(0), -1) 282 | out = getattr(self, 'fc' + str(idx))(out) 283 | ss_logits.append(out) 284 | return ss_logits 285 | 286 | 287 | class ResNet_Auxiliary(nn.Module): 288 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 289 | super(ResNet_Auxiliary, self).__init__() 290 | self.backbone = ResNet(block, layers, num_classes=num_classes, zero_init_residual=zero_init_residual) 291 | self.auxiliary_classifier = Auxiliary_Classifier(block, layers, num_classes=num_classes * 4, zero_init_residual=zero_init_residual) 292 | 293 | def forward(self, x, grad=False): # grad=False detaches the backbone features so only the auxiliary heads receive gradients 294 | if grad is False: 295 | feats, logit = self.backbone(x, is_feat=True) 296 | for i in range(len(feats)): 297 | feats[i] = feats[i].detach() 298 | else: 299 | feats, logit = self.backbone(x, is_feat=True) 300 | 301 | ss_logits = self.auxiliary_classifier(feats) 302 | return logit, ss_logits 303 | 304 | 305 | def resnet18_imagenet(**kwargs): 306 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 307 | 308 | 309 | def resnet18_imagenet_aux(**kwargs): 310 | return ResNet_Auxiliary(BasicBlock, [2, 2, 2, 2], **kwargs) 311 | 312 | 313 | def resnet34_imagenet(**kwargs): 314 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 315 | 316 | 317 | def resnet34_imagenet_aux(**kwargs): 318 | return ResNet_Auxiliary(BasicBlock, [3, 4, 6, 3], **kwargs) 319 | 320 | 321 | def resnet50_imagenet(**kwargs): 322 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 323 | 324 | 325 | def resnet50_imagenet_aux(**kwargs): 326 | return ResNet_Auxiliary(Bottleneck, [3, 4, 6, 3], **kwargs) 327 | 328 | 329 | if __name__ == '__main__': 330 | x = torch.randn(2, 3, 224, 224) 331 | net = resnet18_imagenet(num_classes=100) 332 | y_hat = net(x) 333 | print(y_hat.shape) 334 | -------------------------------------------------------------------------------- /models/resnet_aug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | __all__ = ['resnet18_cifar', 'resnet18_cifar_aux', 'resnet34_cifar', 7 | 'resnet34_cifar_aux', 'resnet50_cifar', 'resnet50_cifar_aux'] 8 | 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 12 | """3x3 convolution with padding""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=dilation, groups=groups, bias=False, dilation=dilation) 15 | 16 | 17 | def conv1x1(in_planes, out_planes, stride=1): 18 | """1x1
convolution""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 26 | base_width=64, dilation=1, norm_layer=None): 27 | super(BasicBlock, self).__init__() 28 | if norm_layer is None: 29 | norm_layer = nn.BatchNorm2d 30 | if groups != 1 or base_width != 64: 31 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 32 | if dilation > 1: 33 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 34 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = norm_layer(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = norm_layer(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | identity = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | identity = self.downsample(x) 55 | 56 | out += identity 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 64 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 65 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 66 | # This variant is also known as ResNet V1.5 and improves accuracy according to 67 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
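# Accordingly, in this implementation self.conv2 (the 3x3 conv) receives the stride argument while self.conv1 keeps stride 1.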
68 | 69 | expansion = 4 70 | 71 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 72 | base_width=64, dilation=1, norm_layer=None): 73 | super(Bottleneck, self).__init__() 74 | if norm_layer is None: 75 | norm_layer = nn.BatchNorm2d 76 | width = int(planes * (base_width / 64.)) * groups 77 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 78 | self.conv1 = conv1x1(inplanes, width) 79 | self.bn1 = norm_layer(width) 80 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 81 | self.bn2 = norm_layer(width) 82 | self.conv3 = conv1x1(width, planes * self.expansion) 83 | self.bn3 = norm_layer(planes * self.expansion) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | identity = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv3(out) 100 | out = self.bn3(out) 101 | 102 | if self.downsample is not None: 103 | identity = self.downsample(x) 104 | 105 | out += identity 106 | out = self.relu(out) 107 | 108 | return out 109 | 110 | 111 | class ResNet(nn.Module): 112 | 113 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 114 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 115 | norm_layer=None): 116 | super(ResNet, self).__init__() 117 | if norm_layer is None: 118 | norm_layer = nn.BatchNorm2d 119 | self._norm_layer = norm_layer 120 | 121 | self.inplanes = 64 122 | self.dilation = 1 123 | if replace_stride_with_dilation is None: 124 | # each element in the tuple indicates if we should replace 125 | # the 2x2 stride with a dilated convolution instead 126 | replace_stride_with_dilation = [False, False, False] 127 | if len(replace_stride_with_dilation) != 3: 128 | raise ValueError("replace_stride_with_dilation should be None " 129 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 130 | self.groups = groups 131 | self.base_width = width_per_group 132 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, 133 | bias=False) 134 | self.bn1 = norm_layer(self.inplanes) 135 | self.relu = nn.ReLU(inplace=True) 136 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 137 | self.layer1 = self._make_layer(block, 64, layers[0]) 138 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 139 | dilate=replace_stride_with_dilation[0]) 140 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 141 | dilate=replace_stride_with_dilation[1]) 142 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 143 | dilate=replace_stride_with_dilation[2]) 144 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 145 | self.embed = nn.Linear(512 * block.expansion, num_classes) 146 | # self.embed = nn.Sequential(nn.Linear(512 * block.expansion, 512 * block.expansion), 147 | # nn.LeakyReLU(0.2), 148 | # nn.Linear(512 * block.expansion, num_classes)) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 154 | nn.init.constant_(m.weight, 1) 155 | nn.init.constant_(m.bias, 0) 156 | 157 | # Zero-initialize the last BN in each residual branch, 158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
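# (implemented below by zeroing the final BN weight of each residual block whenever zero_init_residual is set)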
159 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 160 | if zero_init_residual: 161 | for m in self.modules(): 162 | if isinstance(m, Bottleneck): 163 | nn.init.constant_(m.bn3.weight, 0) 164 | elif isinstance(m, BasicBlock): 165 | nn.init.constant_(m.bn2.weight, 0) 166 | 167 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 168 | norm_layer = self._norm_layer 169 | downsample = None 170 | previous_dilation = self.dilation 171 | if dilate: 172 | self.dilation *= stride 173 | stride = 1 174 | if stride != 1 or self.inplanes != planes * block.expansion: 175 | downsample = nn.Sequential( 176 | conv1x1(self.inplanes, planes * block.expansion, stride), 177 | norm_layer(planes * block.expansion), 178 | ) 179 | 180 | layers = [] 181 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 182 | self.base_width, previous_dilation, norm_layer)) 183 | self.inplanes = planes * block.expansion 184 | for _ in range(1, blocks): 185 | layers.append(block(self.inplanes, planes, groups=self.groups, 186 | base_width=self.base_width, dilation=self.dilation, 187 | norm_layer=norm_layer)) 188 | 189 | return nn.Sequential(*layers) 190 | 191 | def forward(self, x, is_feat=False): 192 | # See note [TorchScript super()] 193 | x = self.conv1(x) 194 | x = self.bn1(x) 195 | x = self.relu(x) 196 | # x = self.maxpool(x) 197 | 198 | x = self.layer1(x) 199 | f1 = x 200 | x = self.layer2(x) 201 | f2 = x 202 | x = self.layer3(x) 203 | f3 = x 204 | x = self.layer4(x) 205 | f4 = x 206 | 207 | # x = self.avgpool(x) 208 | x = nn.functional.avg_pool2d(x, 4) 209 | x = torch.flatten(x, 1) 210 | # x = self.fc(x) 211 | 212 | if is_feat: 213 | return [f1, f2, f3, f4], x 214 | else: 215 | return x 216 | 217 | 218 | class Auxiliary_Classifier(nn.Module): 219 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 220 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 221 | norm_layer=None): 222 | super(Auxiliary_Classifier, self).__init__() 223 | 224 | self.dilation = 1 225 | self.groups = groups 226 | self.base_width = width_per_group 227 | self.inplanes = 64 * block.expansion 228 | self.block_extractor1 = nn.Sequential(*[self._make_layer(block, 128, layers[1], stride=2), 229 | self._make_layer(block, 256, layers[2], stride=2), 230 | self._make_layer(block, 512, layers[3], stride=2)]) 231 | 232 | self.inplanes = 128 * block.expansion 233 | self.block_extractor2 = nn.Sequential(*[self._make_layer(block, 256, layers[2], stride=2), 234 | self._make_layer(block, 512, layers[3], stride=2)]) 235 | 236 | self.inplanes = 256 * block.expansion 237 | self.block_extractor3 = nn.Sequential(*[self._make_layer(block, 512, layers[3], stride=2)]) 238 | 239 | self.inplanes = 512 * block.expansion 240 | self.block_extractor4 = nn.Sequential(*[self._make_layer(block, 512, layers[3], stride=1)]) 241 | 242 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 243 | self.fc1 = nn.Linear(512 * block.expansion, num_classes) 244 | self.fc2 = nn.Linear(512 * block.expansion, num_classes) 245 | self.fc3 = nn.Linear(512 * block.expansion, num_classes) 246 | self.fc4 = nn.Linear(512 * block.expansion, num_classes) 247 | 248 | for m in self.modules(): 249 | if isinstance(m, nn.Conv2d): 250 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 251 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 252 | nn.init.constant_(m.weight, 1) 253 | nn.init.constant_(m.bias, 0) 254 | 255 | def _make_layer(self, block, planes, blocks, 
stride=1, dilate=False): 256 | norm_layer = nn.BatchNorm2d 257 | downsample = None 258 | previous_dilation = self.dilation 259 | if dilate: 260 | self.dilation *= stride 261 | stride = 1 262 | if stride != 1 or self.inplanes != planes * block.expansion: 263 | downsample = nn.Sequential( 264 | conv1x1(self.inplanes, planes * block.expansion, stride), 265 | norm_layer(planes * block.expansion), 266 | ) 267 | 268 | layers = [] 269 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 270 | self.base_width, previous_dilation, norm_layer)) 271 | self.inplanes = planes * block.expansion 272 | for _ in range(1, blocks): 273 | layers.append(block(self.inplanes, planes, groups=self.groups, 274 | base_width=self.base_width, dilation=self.dilation, 275 | norm_layer=norm_layer)) 276 | 277 | return nn.Sequential(*layers) 278 | 279 | 280 | def forward(self, x): 281 | ss_logits = [] 282 | for i in range(len(x)): 283 | idx = i + 1 284 | 285 | out = getattr(self, 'block_extractor'+str(idx))(x[i]) 286 | out = self.avg_pool(out) 287 | out = out.view(out.size(0), -1) 288 | out = getattr(self, 'fc'+str(idx))(out) 289 | ss_logits.append(out) 290 | return ss_logits 291 | 292 | 293 | class ResNet_Auxiliary(nn.Module): 294 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 295 | super(ResNet_Auxiliary, self).__init__() 296 | self.backbone = ResNet(block, layers, num_classes=num_classes, zero_init_residual=zero_init_residual) 297 | self.auxiliary_classifier = Auxiliary_Classifier(block, layers, num_classes=num_classes*4, zero_init_residual=zero_init_residual) 298 | 299 | def forward(self, x, grad=False): 300 | if grad is False: 301 | feats, logit = self.backbone(x, is_feat=True) 302 | for i in range(len(feats)): 303 | feats[i] = feats[i].detach() 304 | else: 305 | feats, logit = self.backbone(x, is_feat=True) 306 | 307 | ss_logits = self.auxiliary_classifier(feats) 308 | return logit, ss_logits 309 | 310 | 311 | def resnet18_cifar(**kwargs): 312 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 313 | 314 | def resnet18_cifar_aux(**kwargs): 315 | return ResNet_Auxiliary(BasicBlock, [2, 2, 2, 2], **kwargs) 316 | 317 | def resnet34_cifar(**kwargs): 318 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 319 | def resnet34_cifar_aux(**kwargs): 320 | return ResNet_Auxiliary(BasicBlock, [3, 4, 6, 3], **kwargs) 321 | 322 | def resnet50_cifar(**kwargs): 323 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 324 | 325 | def resnet50_cifar_aux(**kwargs): 326 | return ResNet_Auxiliary(Bottleneck, [3, 4, 6, 3], **kwargs) 327 | 328 | def init_normal(m): 329 | if type(m) == torch.nn.Linear: 330 | torch.nn.init.uniform_(m.weight) 331 | 332 | class resnet18_pretrain(nn.Module): 333 | 334 | def __init__(self, args): 335 | super(resnet18_pretrain, self).__init__() 336 | self.feat_dim = args.feat_dim 337 | self.num_class = args.num_class 338 | model = torchvision.models.resnet18(pretrained=True) 339 | modules = list(model.children()) 340 | modules = modules[:3] + modules[4:-1] 341 | self.backbone = nn.Sequential(*modules, nn.Flatten()) 342 | self.backbone.eval() 343 | for p in self.backbone.parameters(): 344 | p.requires_grad = False 345 | # self.embed = nn.Sequential( 346 | # nn.Linear(self.feat_dim, self.feat_dim), 347 | # nn.LeakyReLU(), 348 | # nn.Linear(self.feat_dim, self.feat_dim), 349 | # nn.LeakyReLU(), 350 | # nn.Linear(self.feat_dim, self.num_class), 351 | # nn.Sigmoid(), 352 | # ) 353 | self.embed = nn.Sequential( 354 | nn.Linear(self.feat_dim, self.num_class) 355 | 
) 356 | 357 | self.embed.apply(init_normal) 358 | 359 | def forward(self, x): 360 | feas = self.backbone(x) 361 | # torch.flatten(feas, 1) 362 | y_hat = self.embed(feas) 363 | return y_hat 364 | 365 | 366 | 367 | if __name__ == '__main__': 368 | x = torch.randn(2, 3, 32, 32) 369 | 370 | class params(object): 371 | def __init__(self): 372 | self.feat_dim = 512 373 | self.num_class = 100 374 | args = params() 375 | net = resnet18_pretrain(args) 376 | y_hat = net(x) 377 | print(y_hat.shape) 378 | 379 | # x = torch.randn(2, 3, 32, 32) 380 | # net = resnet18_cifar_aux(num_classes=100) 381 | # logit, ss_logits = net(x) 382 | # print(logit.size()) 383 | 384 | # net = resnet34_cifar_aux(num_classes=1000) 385 | # from utils.utils import cal_param_size, cal_multi_adds 386 | # print('Params: %.2fM, Multi-adds: %.3fM' 387 | # % (cal_param_size(net) / 1e6, cal_multi_adds(net, (2, 3, 224, 224)) / 1e6)) 388 | -------------------------------------------------------------------------------- /eTag_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | from __future__ import absolute_import, print_function 4 | import torch.utils.data 5 | from torch.utils.tensorboard import SummaryWriter 6 | from torch.backends import cudnn 7 | from models.resnet_aug import resnet18_cifar_aux 8 | from models.resnet_imagenet import resnet18_imagenet_aux 9 | 10 | from utils import mkdir_if_missing, logging, display 11 | from torch.optim.lr_scheduler import StepLR 12 | from ImageFolder import * 13 | import os # explicit import; os and numpy were previously reachable only through 'from ImageFolder import *' 14 | import torchvision.transforms as transforms 15 | from models.cvae import CVAE_Cifar 16 | from CIFAR100 import CIFAR100 17 | from copy import deepcopy 18 | from opts_eTag import get_train_args 19 | import sys 20 | import numpy as np 21 | cudnn.benchmark = True 22 | 23 | def to_binary(labels, args): 24 | # build a one-hot (binary) encoding of the labels 25 | y_onehot = torch.FloatTensor(len(labels), args.num_class) 26 | y_onehot.zero_() 27 | y_onehot.scatter_(1, labels.cpu()[:, None], 1) 28 | code_binary = y_onehot.cuda() 29 | return code_binary 30 | 31 | def get_model(model): 32 | return deepcopy(model.state_dict()) 33 | 34 | def set_model_(model, state_dict): 35 | model.load_state_dict(deepcopy(state_dict)) 36 | return model 37 | 38 | def freeze_model(model): 39 | for param in model.parameters(): 40 | param.requires_grad = False 41 | return model 42 | 43 | class DistillKL(torch.nn.Module): 44 | """Distilling the Knowledge in a Neural Network""" 45 | def __init__(self, T): 46 | super(DistillKL, self).__init__() 47 | self.T = T 48 | 49 | def forward(self, y_s, y_t): 50 | p_s = torch.nn.functional.log_softmax(y_s/self.T, dim=1) 51 | p_t = torch.nn.functional.softmax(y_t/self.T, dim=1) 52 | loss = torch.nn.functional.kl_div(p_s, p_t, reduction='batchmean') * (self.T**2) 53 | return loss 54 | 55 | def train_task(args, train_loader, current_task, pre_index=0): 56 | num_class_per_task = (args.num_class-args.nb_cl_fg) // args.num_task 57 | if num_class_per_task == 0: 58 | pass # JT: joint training over all classes in a single task, so no old-task factor is needed 59 | else: 60 | old_task_factor = args.nb_cl_fg // num_class_per_task + current_task - 1 61 | log_dir = args.log_dir 62 | mkdir_if_missing(log_dir) 63 | 64 | sys.stdout = logging.Logger(os.path.join(log_dir, 'log_task{}.txt'.format(current_task))) 65 | tb_writer = SummaryWriter(log_dir) 66 | display(args) 67 | # Pick the backbone with auxiliary rotation heads according to the dataset 68 | if 'imagenet' in args.data: 69 | model = resnet18_imagenet_aux(num_classes=args.num_class) 70 | elif 'cifar' in args.data:
71 | model = resnet18_cifar_aux(num_classes=args.num_class) 72 | 73 | if current_task > 0: # TODO 74 | model = torch.load(os.path.join(log_dir, 'task_' + str(current_task - 1).zfill(2) + '_%d_model.pkl' % int(args.epochs - 1))) 75 | model_old = deepcopy(model) 76 | model_old.eval() 77 | model_old = freeze_model(model_old) 78 | 79 | model = model.cuda() 80 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 81 | scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=args.lr_decay) 82 | loss_mse = torch.nn.MSELoss(reduction='sum') 83 | 84 | # Initialize the conditional VAE that replays features of earlier classes (reloaded from the previous task after task 0) 85 | if current_task == 0: 86 | args.encoder_layer_sizes = [args.feat_dim, args.hidden_dim, args.hidden_dim] 87 | args.decoder_layer_sizes = [args.hidden_dim, args.hidden_dim, args.feat_dim] 88 | args.latent_size = args.latent_dim 89 | args.class_dim = args.num_class 90 | cvae = CVAE_Cifar(args) 91 | else: 92 | cvae = torch.load(os.path.join(log_dir, 'task_' + str(current_task - 1).zfill(2) + '_%d_model_vae.pkl' % int(args.epochs_vae - 1))) 93 | cvae_old = deepcopy(cvae) 94 | cvae_old.eval() 95 | cvae_old = freeze_model(cvae_old) 96 | cvae = cvae.cuda() 97 | 98 | optimizer_cvae = torch.optim.Adam(cvae.parameters(), lr=args.vae_lr) # NOTE: the bare name 'vae_lr' was undefined here; args.vae_lr is assumed to be the intended option from opts_eTag 99 | 100 | for p in cvae.parameters(): # freeze the CVAE while the feature extractor trains; re-enabled in the CVAE stage below 101 | p.requires_grad = False 102 | 103 | ###############################################################Feature extractor training#################################################### 104 | if current_task > 0: 105 | model = model.eval() 106 | if not os.path.exists(os.path.join(log_dir, 'task_' + str(current_task).zfill(2) + '_%d_model.pkl' % (args.epochs-1))): 107 | for epoch in range(args.epochs): 108 | 109 | loss_log = {'C/loss': 0.0, 110 | 'C/cls_previous': 0.0, 111 | 'C/cls_current': 0.0, 112 | 'C/aug_blockORkd': 0.0, 113 | 'C/aug_dist': 0.0} 114 | # scheduler.step() 115 | for batch, data in enumerate(train_loader, 0): 116 | inputs1, labels1 = data 117 | inputs1, labels1 = inputs1.cuda(), labels1.cuda() 118 | inputs, labels = inputs1, labels1
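# Rotation-based self-supervision: the lines below replicate every image at 0/90/180/270 degrees
# and remap each label to 4 * label + k, which matches the num_classes * 4 outputs of the
# auxiliary heads in Auxiliary_Classifier.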
119 | size = inputs.shape[1:] 120 | inputs = torch.stack([torch.rot90(inputs, k, (2, 3)) for k in range(4)], 1).view(-1, *size) 121 | labels = torch.stack([labels * 4 + i for i in range(4)], 1).view(-1) 122 | 123 | embed_feat, ss_logits = model(inputs, grad=True) 124 | 125 | loss = torch.zeros(1).cuda() 126 | cls_current = torch.zeros(1).cuda() # classification loss for the current task 127 | cls_previous = torch.zeros(1).cuda() # classification loss for the previous tasks 128 | aug_blockORkd = torch.zeros(1).cuda() # auxiliary loss: block-wise self-supervised CE (task 0) or block-wise KD against the old model (later tasks) 129 | aug_dist = torch.zeros(1).cuda() # distance between the final embeddings of the previous and current model 130 | 131 | optimizer.zero_grad() 132 | if current_task == 0: 133 | soft_feat = model.backbone.embed(embed_feat) 134 | for i in range(len(ss_logits)): 135 | aug_blockORkd = aug_blockORkd + torch.nn.CrossEntropyLoss()(ss_logits[i], labels) 136 | cls_current = torch.nn.CrossEntropyLoss()(soft_feat[0::4], labels1) 137 | loss = cls_current + aug_blockORkd 138 | else: 139 | embed_feat_old, ss_logits_old = model_old(inputs) 140 | aug_dist = torch.dist(embed_feat, embed_feat_old, 2) 141 | aug_dist = old_task_factor * aug_dist 142 | for i in range(len(ss_logits)-1): 143 | aug_blockORkd = aug_blockORkd + DistillKL(args.tau)(ss_logits[i], ss_logits_old[i]) 144 | aug_blockORkd = old_task_factor * aug_blockORkd 145 | 146 | # synthesize features of earlier classes with the frozen CVAE 147 | embed_label_synthesis = [] 148 | ind = list(range(len(pre_index))) 149 | for _ in range(args.BatchSize): 150 | np.random.shuffle(ind) 151 | embed_label_synthesis.append(pre_index[ind[0]]) 152 | embed_label_synthesis = np.asarray(embed_label_synthesis) 153 | embed_label_synthesis = torch.from_numpy(embed_label_synthesis).cuda() 154 | 155 | z = torch.Tensor(np.random.normal(0, 1, (args.BatchSize, args.latent_dim))).cuda() 156 | embed_synthesis = cvae.inference(z, c=embed_label_synthesis) 157 | embed_synthesis = torch.cat((embed_feat[0::4], embed_synthesis)) 158 | embed_label_synthesis = torch.cat((labels1, embed_label_synthesis.cuda())) 159 | soft_feat_syt = model.backbone.embed(embed_synthesis) 160 | batch_size1 = inputs1.shape[0] 161 | cls_current = torch.nn.CrossEntropyLoss()(soft_feat_syt[:batch_size1], embed_label_synthesis[:batch_size1]) 162 | cls_current = cls_current / (old_task_factor + 1) 163 | cls_previous = torch.nn.CrossEntropyLoss()(soft_feat_syt[batch_size1:], embed_label_synthesis[batch_size1:]) 164 | cls_previous = (cls_previous * old_task_factor) / (old_task_factor + 1) 165 | loss = cls_current + cls_previous + (aug_blockORkd + aug_dist) * args.tradeoff 166 | loss.backward() 167 | torch.nn.utils.clip_grad_norm_(model.parameters(), 10000) 168 | optimizer.step() 169 | loss_log['C/loss'] += loss.item() 170 | loss_log['C/cls_current'] += cls_current.item() 171 | loss_log['C/cls_previous'] += cls_previous.item() 172 | loss_log['C/aug_blockORkd'] += aug_blockORkd.item() 173 | loss_log['C/aug_dist'] += aug_dist.item() 174 | if epoch == 0 and batch == 0: 175 | print(50 * '#') 176 | scheduler.step() 177 | print('[CLS %05d]\t C/loss: %.3f \t C/cls_current: %.3f \t C/aug_blockORkd: %.3f \t C/cls_previous: %.3f \t C/aug_dist: %.3f' 178 | % (epoch + 1, loss_log['C/loss'], loss_log['C/cls_current'], loss_log['C/aug_blockORkd'], loss_log['C/cls_previous'], loss_log['C/aug_dist'])) 179 | for k, v in loss_log.items(): 180 | if v != 0: 181 | tb_writer.add_scalar('Task {} - Classifier/{}'.format(current_task, k), v, epoch + 1) 182 | 183 | if epoch ==
args.epochs-1: 184 | torch.save(model, os.path.join(log_dir, 'task_' + str(current_task).zfill(2) + '_%d_model.pkl' % epoch)) 185 | else: 186 | model = torch.load(os.path.join(log_dir, 'task_' + str(current_task).zfill(2) + '_%d_model.pkl' % (args.epochs-1))) 187 | 188 | ################################################################## CVAE Training stage#################################################### 189 | model = model.eval() 190 | for p in model.parameters(): 191 | p.requires_grad = False 192 | for p in cvae.parameters(): 193 | p.requires_grad = True 194 | if current_task != args.num_task: 195 | for epoch in range(args.epochs_vae): 196 | loss_log = {'V/loss': 0.0, 'V/var': 0.0, 'V/rec': 0.0, 'V/cls_current': 0.0, 'V/cls_previous': 0.0, 'V/aug': 0.0} 197 | # scheduler_VAE.step() 198 | for batch, data in enumerate(train_loader, 0): 199 | inputs, labels = data 200 | inputs, labels = inputs.cuda(), labels.cuda() 201 | real_feat = model.backbone(inputs) 202 | optimizer_cvae.zero_grad() 203 | fake_feat, mu, logvar, z = cvae(real_feat, labels) 204 | loss_rec = torch.nn.MSELoss(reduction='sum')(real_feat, fake_feat) / real_feat.size(0) # feature reconstruction loss 205 | loss_var = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / real_feat.size(0) # KL divergence to the standard normal prior 206 | if current_task == 0: 207 | loss_aug, cls_previous = torch.zeros(1).cuda(), torch.zeros(1).cuda() 208 | fake_feat_soft = model.backbone.embed(fake_feat) 209 | cls_current = torch.nn.CrossEntropyLoss()(fake_feat_soft, labels) 210 | else: 211 | # sample labels from the previous tasks 212 | ind = list(range(len(pre_index))) 213 | embed_label_synthesis = [] 214 | for _ in range(args.BatchSize): 215 | np.random.shuffle(ind) 216 | embed_label_synthesis.append(pre_index[ind[0]]) 217 | embed_label_synthesis = np.asarray(embed_label_synthesis) 218 | embed_label_synthesis = torch.from_numpy(embed_label_synthesis).cuda() 219 | z = torch.Tensor(np.random.normal(0, 1, (args.BatchSize, args.latent_dim))).cuda() 220 | pre_feat = cvae.inference(z, embed_label_synthesis) 221 | fake_feat_soft = model.backbone.embed(fake_feat) 222 | cls_current = torch.nn.CrossEntropyLoss()(fake_feat_soft, labels.cuda()) 223 | cls_current = cls_current / (old_task_factor + 1) 224 | fake_feat_soft_ = model.backbone.embed(pre_feat) 225 | cls_previous = torch.nn.CrossEntropyLoss()(fake_feat_soft_, embed_label_synthesis.cuda()) * old_task_factor 226 | cls_previous = cls_previous / (old_task_factor + 1) 227 | pre_feat_old = cvae_old.inference(z, embed_label_synthesis) 228 | loss_aug = loss_mse(pre_feat, pre_feat_old) * old_task_factor 229 | cvae_loss = loss_rec + loss_var + (cls_previous + cls_current) * args.tradeoff + loss_aug * args.vae_tradeoff 230 | loss_log['V/loss'] += cvae_loss.item() 231 | loss_log['V/var'] += loss_var.item() 232 | loss_log['V/rec'] += loss_rec.item() 233 | loss_log['V/cls_current'] += cls_current.item() * args.tradeoff 234 | loss_log['V/cls_previous'] += cls_previous.item() * args.tradeoff 235 | loss_log['V/aug'] += loss_aug.item() * args.vae_tradeoff 236 | cvae_loss.backward() 237 | optimizer_cvae.step() 238 | print('[CVAE %05d]\t V/loss: %.3f \t V/var: %.3f \t V/rec: %.3f \t V/cls_current: %.3f \t V/cls_previous: %.3f \t V/aug: %.3f' % 239 | (epoch + 1, loss_log['V/loss'], loss_log['V/var'], loss_log['V/rec'], loss_log['V/cls_current'], loss_log['V/cls_previous'], loss_log['V/aug'])) 240 | for k, v in loss_log.items(): 241 | if v != 0: 242 | tb_writer.add_scalar('Task {} - VAE/{}'.format(current_task, k), v, epoch + 1) 243 | torch.save(cvae, os.path.join(log_dir, 'task_' +
str(current_task).zfill(2) + '_%d_model_vae.pkl' % (args.epochs_vae - 1))) 244 | tb_writer.close() 245 | 246 | 247 | if __name__ == '__main__': 248 | 249 | args = get_train_args() 250 | 251 | # Data 252 | print('==> Preparing data...') 253 | 254 | if args.data == 'imagenet_sub' or args.data == 'imagenet_full': 255 | mean_values = [0.485, 0.456, 0.406] 256 | std_values = [0.229, 0.224, 0.225] 257 | transform_train = transforms.Compose([ 258 | # transforms.Resize(256), 259 | transforms.RandomResizedCrop(224), 260 | transforms.RandomHorizontalFlip(), 261 | transforms.ToTensor(), 262 | transforms.Normalize(mean=mean_values, 263 | std=std_values) 264 | ]) 265 | traindir = os.path.join(args.dir, 'ILSVRC12_256', 'train') 266 | 267 | if args.data == 'cifar100': 268 | transform_train = transforms.Compose([ 269 | transforms.RandomCrop(32, padding=4), 270 | transforms.RandomHorizontalFlip(), 271 | transforms.ToTensor(), 272 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 273 | ]) 274 | traindir = args.dir + '/cifar100' 275 | 276 | num_classes = args.num_class 277 | num_task = args.num_task 278 | num_class_per_task = (num_classes-args.nb_cl_fg) // num_task 279 | 280 | random_perm = list(range(num_classes)) # multi-head evaluation fails if the classes are randomly permuted here 281 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 282 | 283 | for i in range(args.start, num_task+1): 284 | print("-------------------Get started---------------") 285 | print("Training on Task " + str(i)) 286 | if i == 0: 287 | pre_index = 0 288 | class_index = random_perm[:args.nb_cl_fg] 289 | else: 290 | pre_index = random_perm[:args.nb_cl_fg + (i-1) * num_class_per_task] 291 | class_index = random_perm[args.nb_cl_fg + (i-1) * num_class_per_task:args.nb_cl_fg + i * num_class_per_task] 292 | 293 | if args.data == 'cifar100': 294 | np.random.seed(args.seed) 295 | torch.manual_seed(args.seed) 296 | torch.cuda.manual_seed_all(args.seed) 297 | 298 | target_transform = np.random.permutation(num_classes) 299 | trainset = CIFAR100(root=traindir, train=True, download=True, transform=transform_train, target_transform=target_transform, index=class_index) 300 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.BatchSize, shuffle=True, num_workers=args.nThreads, drop_last=True) 301 | else: 302 | np.random.seed(args.seed) 303 | torch.manual_seed(args.seed) 304 | torch.cuda.manual_seed_all(args.seed) 305 | 306 | trainfolder = ImageFolder(traindir, transform_train, index=class_index) 307 | train_loader = torch.utils.data.DataLoader( 308 | trainfolder, batch_size=args.BatchSize, 309 | shuffle=True, 310 | drop_last=True, num_workers=args.nThreads) 311 | train_task(args, train_loader, i, pre_index=pre_index) --------------------------------------------------------------------------------
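For orientation, here is a minimal smoke test of the pieces above (an editorial sketch, not a file in the repository; it assumes the repo root is on PYTHONPATH and runs on CPU). It checks that the auxiliary wrapper returns one logit tensor per residual block, each over num_classes * 4 rotation classes, matching what eTag_train.py expects:

import torch
from models.resnet_aug import resnet18_cifar_aux

net = resnet18_cifar_aux(num_classes=100).eval()
x = torch.randn(2, 3, 32, 32)                 # a CIFAR-sized batch
with torch.no_grad():
    feat, ss_logits = net(x)                  # default grad=False: backbone features are detached
print(feat.shape)                             # torch.Size([2, 512]) pooled backbone feature
print([tuple(s.shape) for s in ss_logits])    # four (2, 400) tensors: one head per block, 100 classes * 4 rotations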