├── libs ├── loss │ ├── SoftmaxLoss.py │ └── CosineDistill.py ├── utils │ ├── globals.py │ ├── default_config.yaml │ ├── logger.py │ ├── experiments_maker.py │ └── utils.py ├── core │ ├── BaselineCosineCE.py │ ├── CBDistillCE.py │ └── core_base.py ├── models │ ├── ecbd_converter.py │ ├── DotProductClassifier.py │ ├── ResNet50Feature.py │ ├── CosineDotProductClassifier.py │ ├── ResNetFeature.py │ ├── ResNet32Feature.py │ └── ResNet32Featureb4.py ├── samplers │ └── ClassAwareSampler.py └── data │ └── dataloader.py ├── multi_runs.sh ├── Readme.md └── main.py /libs/loss/SoftmaxLoss.py: -------------------------------------------------------------------------------- 1 | # Softmax Loss 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def create_loss(): 7 | print("Loading Softmax Loss.") 8 | return nn.CrossEntropyLoss() 9 | -------------------------------------------------------------------------------- /libs/utils/globals.py: -------------------------------------------------------------------------------- 1 | # Loggers 2 | wandb_log = False 3 | log_offline = False 4 | log_dir = None #For offline logging 5 | 6 | # x-axis 7 | epoch_global = 0 8 | step_global = 0 9 | 10 | # seed 11 | seed = 1 12 | 13 | -------------------------------------------------------------------------------- /libs/loss/CosineDistill.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import wandb 6 | import libs.utils.globals as g 7 | 8 | 9 | 10 | class CosineDistill(nn.Module): 11 | def __init__(self, beta=100): 12 | super(CosineDistill, self).__init__() 13 | self.beta = beta 14 | 15 | def forward(self, student, teacher): 16 | cos = nn.CosineSimilarity(dim=1) 17 | return self.beta*(1-cos(student, teacher)).mean(dim=0) 18 | 19 | def create_loss(*args): 20 | print("Loading Cosine Distance Loss.") 21 | return CosineDistill(*args) 22 | 
-------------------------------------------------------------------------------- /multi_runs.sh: -------------------------------------------------------------------------------- 1 | #-----Train 3 Normal teachers 2 | for seeds in 10 20 3 | do 4 | python main.py --experiment=0.1 --seed=$seeds --gpu="0,1" --log_offline 5 | done 6 | 7 | #-----Train 3 Augmentation teachers 8 | for seeds in 20 30 9 | do 10 | python main.py --experiment=0.2 --seed=$seeds --gpu="0,1" --log_offline 11 | done 12 | 13 | #-----Train CBD_ensemble_K 14 | for seeds in 1 15 | do 16 | for alphas in 0.2 0.4 0.8 17 | do 18 | for betas in 50 100 200 19 | do 20 | python main.py --experiment=0.3 --alpha=$alphas --beta=$betas --seed=$seeds --gpu="0,1" --log_offline --normal_teacher="10,20" --aug_teacher="20,30" 21 | done 22 | done 23 | done 24 | -------------------------------------------------------------------------------- /libs/core/BaselineCosineCE.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.pyplot as plt 3 | from torch.optim import optimizer 4 | import libs.utils.globals as g 5 | import os 6 | import copy 7 | import pickle 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import torch.nn.functional as F 12 | from tqdm import tqdm 13 | from libs.utils.utils import * 14 | from libs.utils.logger import Logger 15 | import time 16 | import numpy as np 17 | import warnings 18 | import pdb 19 | 20 | from libs.core.core_base import model as base_model 21 | 22 | #----------------------------------------------------- 23 | 24 | # This is there so that we can use source_import from the utils to import model 25 | def get_core(*args): 26 | return base_model(*args) -------------------------------------------------------------------------------- /libs/models/ecbd_converter.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from os import path 3 | import 
torch 4 | import torch.nn.functional as F 5 | from libs.utils.utils import * 6 | 7 | class ecbd_converter(nn.Module): 8 | def __init__(self, feat_in, feat_out, *args): 9 | super(ecbd_converter, self).__init__() 10 | self.fc = nn.Linear(feat_in, feat_out) 11 | 12 | def forward(self, x, *args): 13 | return self.fc(x) 14 | 15 | 16 | def create_model(feat_in, feat_out, pretrain=False, pretrain_dir=None, *args): 17 | """Initialize the model 18 | 19 | Args: 20 | feat_dim (int): output dimension of the previous feature extractor 21 | num_classes (int, optional): Number of classes. Defaults to 1000. 22 | 23 | Returns: 24 | Class: Model 25 | """ 26 | print("ECBD Converter.") 27 | clf = ecbd_converter(feat_in, feat_out) 28 | 29 | if pretrain: 30 | if path.exists(pretrain_dir): 31 | print("===> Load Pretrain Initialization for CosineDotProductClassfier") 32 | weights = torch.load(pretrain_dir)["state_dict_best"]["classifier"] 33 | 34 | weights = { 35 | k: weights["module." + k] 36 | if "module." + k in weights 37 | else clf.state_dict()[k] 38 | for k in clf.state_dict() 39 | } 40 | clf.load_state_dict(weights) 41 | else: 42 | raise Exception("Pretrain path doesn't exist!!") 43 | else: 44 | print("===> Train classifier from the scratch") 45 | 46 | return clf 47 | -------------------------------------------------------------------------------- /libs/models/DotProductClassifier.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from os import path 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | 8 | class DotProduct_Classifier(nn.Module): 9 | def __init__(self, num_classes=1000, feat_dim=2048, *args): 10 | super(DotProduct_Classifier, self).__init__() 11 | self.fc = nn.Linear(feat_dim, num_classes) 12 | 13 | def forward(self, x, *args): 14 | x = self.fc(x) 15 | return x 16 | 17 | 18 | def create_model(feat_dim, num_classes=1000, pretrain=False, pretrain_dir=None, *args): 19 | """Initialize the model 
20 | 21 | Args: 22 | feat_dim (int): output dimension of the previous feature extractor 23 | num_classes (int, optional): Number of classes. Defaults to 1000. 24 | 25 | Returns: 26 | Class: Model 27 | """ 28 | print("Loading Dot Product Classifier.") 29 | clf = DotProduct_Classifier(num_classes, feat_dim) 30 | 31 | if pretrain: 32 | if path.exists(pretrain_dir): 33 | print("===> Load Pretrain Initialization for DotProductClassfier") 34 | weights = torch.load(pretrain_dir)["state_dict_best"]["classifier"] 35 | 36 | weights = { 37 | k: weights["module." + k] 38 | if "module." + k in weights 39 | else clf.state_dict()[k] 40 | for k in clf.state_dict() 41 | } 42 | clf.load_state_dict(weights) 43 | else: 44 | raise Exception("Pretrain path doesn't exist!!") 45 | else: 46 | print("===> Train classifier from the scratch") 47 | 48 | return clf 49 | -------------------------------------------------------------------------------- /libs/utils/default_config.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | EmbeddingLoss: 3 | def_file: ./loss/SoftmaxLoss.py 4 | loss_params: {} 5 | optim_params: null 6 | weight: 1.0 7 | PerformanceLoss: 8 | def_file: ./loss/SoftmaxLoss.py 9 | loss_params: {} 10 | optim_params: null 11 | weight: 1.0 12 | ClassifierLoss: 13 | def_file: ./loss/SoftmaxLoss.py 14 | loss_params: {} 15 | optim_params: null 16 | weight: 1.0 17 | endlr: 0.0 18 | last: false 19 | networks: 20 | classifier: 21 | def_file: ./models/DotProductClassifier.py 22 | fix: false 23 | optim_params: {} 24 | scheduler_params: {} 25 | params: {} 26 | embedding: 27 | def_file: ./models/DotProductClassifier.py 28 | fix: false 29 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 30 | scheduler_params: {coslr: true} 31 | params: {feat_dim: 128, embedding_dim: 64, num_classes: , pretrain: False, pretrain_dir: None} 32 | feat_model: 33 | def_file: ./models/ResNet32Feature.py 34 | fix: false 35 | optim_params: {lr: 0.2, 
momentum: 0.9, weight_decay: 0.0005} 36 | scheduler_params: {coslr: true} 37 | params: {pretrain: False, pretrain_dir: None} 38 | shuffle: false 39 | training_opt: 40 | backbone: resnet32 41 | batch_size: 512 42 | accumulation_step: 1 43 | dataset: 44 | display_step: 10 45 | log_dir: ./logs/CIFAR100LT/models/resnet32_normal_learning_CIFAR100_LT_imb10_e100 46 | num_classes: 47 | cifar_imb_ratio: # 0.01, 0.02, 0.1 for 100, 50, 10 48 | # optimizer: adam 49 | num_epochs: 100 50 | num_workers: 12 51 | open_threshold: 0.1 52 | sampler: null 53 | stage: resnet32_normal_learning_CIFAR100_LT_imb10_e100 54 | wandb_tags: ["",""] 55 | 56 | pg: 57 | generate: False 58 | 59 | retrain: 60 | protobias: True 61 | 62 | 63 | -------------------------------------------------------------------------------- /libs/models/ResNet50Feature.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | 9 | from libs.models.ResNetFeature import * 10 | from libs.utils.utils import * 11 | from os import path 12 | from collections import OrderedDict 13 | import torch 14 | 15 | def create_model(pretrain=False, pretrain_dir=None, *args): 16 | """Initialize/load the model 17 | 18 | Args: 19 | pretrain (bool, optional): Use pre-trained model?. Defaults to False. 20 | pretrain_dir (str, optional): Directory of the pre-trained model. Defaults to None. 
21 | 22 | Returns: 23 | class: Model 24 | """ 25 | 26 | print("Loading ResNet 50 Feature Model.") 27 | resnet50 = ResNet(Bottleneck, [3, 4, 6, 3], use_fc=False, dropout=None) 28 | 29 | if pretrain: 30 | if path.exists(pretrain_dir): 31 | print("===> Load Pretrain Initialization for ResNet50") 32 | model_dict = resnet50.state_dict() 33 | new_dict = load_model(pretrain_dir=pretrain_dir) 34 | model_dict.update(new_dict) 35 | resnet50.load_state_dict(model_dict) 36 | print("Backbone model has been loaded......") 37 | 38 | else: 39 | raise Exception(f"Pretrain path doesn't exist!!-{pretrain_dir}") 40 | else: 41 | print("===> Train backbone from the scratch") 42 | 43 | return resnet50 44 | 45 | def load_model(pretrain_dir): 46 | """Load a pre-trained model 47 | 48 | Args: 49 | pretrain_dir (str): path of pretrained model 50 | """ 51 | print(f"Loading Backbone pretrain model from {pretrain_dir}......") 52 | pretrain_dict = torch.load(pretrain_dir)["state_dict_best"]["feat_model"] 53 | 54 | new_dict = OrderedDict() 55 | 56 | # Removing FC and Classifier layers 57 | for k, v in pretrain_dict.items(): 58 | if k.startswith("module"): 59 | k = k[7:] 60 | if "fc" not in k and "classifier" not in k: 61 | new_dict[k] = v 62 | 63 | return new_dict 64 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Pytorch Implementation of [Class-Balanced Distillation for Long-Tailed Visual Recognition](https://arxiv.org/abs/2104.05279) by [Ahmet Iscen](https://cmp.felk.cvut.cz/~iscenahm/), André Araujo, Boqing Gong, Cordelia Schmid 2 | --- 3 | ### Note: 4 | - Implemented only for ImageNetLT 5 | - `normal_teachers` is the `Standard model` from the paper 6 | - `aug_teachers` is the `Data Augmentation model` from the paper 7 | 8 | ## Things to do before you run : 9 | - Change the `data_root` for your dataset in `main.py`. 
10 | - If you are using wandb logging ([Weights & Biases](https://docs.wandb.ai/quickstart)), make sure to change the `wandb.init` in `main.py` accordingly. 11 | 12 | ## How to use? 13 | - Easy to use : Check this script - `multi_runs.sh` 14 | - Train the normal teachers : 15 | ``` 16 | python main.py --experiment=0.1 --seed=1 --gpu="0,1" --train --log_offline 17 | ``` 18 | - Train the augmentation teachers : 19 | ``` 20 | python main.py --experiment=0.2 --seed=1 --gpu="0,1" --train --log_offline 21 | ``` 22 | - Train the Class Balanced Distilled Student : 23 | ``` 24 | python main.py --experiment=0.3 --alpha=0.4 --beta=100 --seed=$seeds --gpu="0,1" --train --log_offline --normal_teacher="10,20" --aug_teacher="20,30" 25 | ``` 26 | 27 | ### Arguments : 28 | (General) 29 | - `--seed`: Seed of your current run 30 | - `--gpu`: GPUs to be used 31 | - `--experiment`: Experiment number (Check `libs/utils/experiment_maker.py` for more details) 32 | - `--wandb_logger`: Does wandb Logging 33 | - `--log_offline`: Does offline Logging 34 | - `--resume`: Resumes the training if the run crashes 35 | 36 | (Specific to Distillation and Student's training) 37 | - `--alpha`: Weightage between Classifier loss and distillation loss 38 | - `--beta`: weightage for the Cosine Similarity between teachers and student 39 | - `--normal_teachers`: What all seed of norma teachers do you want to use? If you want to use only augmentation teachers, just don't use this argument. It is `None` by default. 40 | - `--aug_teachers`: What all seed of augmented teachers do you want to use? If you want to use only normal teachers, just don't use this argument. It is `None` by default. 41 | 42 | ## Raise an issue : 43 | If something is not clear or you found a bug, raise an issue!! 
44 | -------------------------------------------------------------------------------- /libs/models/CosineDotProductClassifier.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from os import path 3 | import torch 4 | import torch.nn.functional as F 5 | from libs.utils.utils import * 6 | 7 | class CosineDotProduct_Classifier(nn.Module): 8 | def __init__(self, num_classes=1000, feat_dim=2048, scale=10.0, *args): 9 | super(CosineDotProduct_Classifier, self).__init__() 10 | self.fc = nn.Linear(feat_dim, num_classes) 11 | self.scale = nn.Parameter(torch.FloatTensor(1).fill_(scale), requires_grad=True) 12 | 13 | def forward(self, x, *args): 14 | # x = torch.hstack((x, torch.ones(x.shape[0]).cuda().unsqueeze(dim=1))) # (Batch, feature) - > (Batch, feature+1) 15 | # x = torch.addmm(0, x, F.normalize(self.fc.weight.T, dim=1), *, 0, self.scale) # (Batch, feature) x (out, feature).T -> (Batch, out) 16 | x = F.softplus(self.scale)*(torch.mm(x, F.normalize(self.fc.weight.T, dim=1))) + self.fc.bias # (Batch, feature) x (out, feature).T -> (Batch, out) 17 | wandb_log({"Classifier Scale": self.scale.item()}) 18 | return x 19 | 20 | 21 | def create_model(feat_dim, num_classes=1000, scale=10.0, pretrain=False, pretrain_dir=None, *args): 22 | """Initialize the model 23 | 24 | Args: 25 | feat_dim (int): output dimension of the previous feature extractor 26 | num_classes (int, optional): Number of classes. Defaults to 1000. 27 | 28 | Returns: 29 | Class: Model 30 | """ 31 | print("Loading Dot Product Classifier.") 32 | clf = CosineDotProduct_Classifier(num_classes, feat_dim, scale) 33 | 34 | if pretrain: 35 | if path.exists(pretrain_dir): 36 | print("===> Load Pretrain Initialization for CosineDotProductClassfier") 37 | weights = torch.load(pretrain_dir)["state_dict_best"]["classifier"] 38 | 39 | weights = { 40 | k: weights["module." + k] 41 | if "module." 
+ k in weights 42 | else clf.state_dict()[k] 43 | for k in clf.state_dict() 44 | } 45 | clf.load_state_dict(weights) 46 | else: 47 | raise Exception("Pretrain path doesn't exist!!") 48 | else: 49 | print("===> Train classifier from the scratch") 50 | 51 | return clf 52 | -------------------------------------------------------------------------------- /libs/core/CBDistillCE.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.pyplot as plt 3 | import libs.utils.globals as g 4 | import os 5 | import copy 6 | import pickle 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from torch.utils.data import TensorDataset, DataLoader 12 | from tqdm import tqdm 13 | from libs.utils.utils import * 14 | from libs.utils.logger import Logger 15 | import time 16 | import numpy as np 17 | import warnings 18 | import pdb 19 | 20 | from libs.core.core_base import model as base_model 21 | 22 | 23 | 24 | class model(base_model): 25 | def batch_forward(self, inputs, labels=None, phase="train", retrain= False): 26 | """Batch Forward 27 | """ 28 | 29 | self.features_temp = self.networks["feat_model"](inputs) 30 | self.features_temp = F.normalize(self.features_temp, p=2, dim=1) 31 | 32 | if len(self.networks.keys()) > 3: 33 | #-----Convert student feature to match with concatenated teachers' features 34 | self.features = self.networks["ecbd_converter"](self.features_temp) 35 | self.features = F.normalize(self.features, p=2, dim=1) 36 | else: 37 | self.features = self.features_temp 38 | 39 | if phase =="train": 40 | # Calculate Features and outputs 41 | self.features_teacher = [] 42 | for i in self.networks.keys(): 43 | if not(("feat_model" in i) or ("classifier" in i) or ("ecbd_converter" in i)): 44 | self.temp = self.networks[i](inputs) 45 | self.features_teacher.append(F.normalize(self.temp, p=2, dim=1)) 46 | self.features_teacher = 
torch.hstack(self.features_teacher) 47 | self.features_teacher = F.normalize(self.features_teacher, p=2, dim=1) 48 | 49 | self.logits = self.networks["classifier"](self.features_temp, labels) 50 | 51 | def batch_loss(self, labels): 52 | """Calculate training loss 53 | """ 54 | self.loss = 0 55 | 56 | # Calculating loss 57 | if "DistillLoss" in self.criterions.keys(): 58 | self.loss_distill = self.criterions["DistillLoss"](self.features, self.features_teacher) 59 | self.loss_distill *= self.criterion_weights["DistillLoss"] 60 | self.loss += self.loss_distill 61 | 62 | # Calculating loss 63 | if "ClassifierLoss" in self.criterions.keys(): 64 | self.loss_classifier = self.criterions["ClassifierLoss"](self.logits, labels) 65 | self.loss_classifier *= self.criterion_weights["ClassifierLoss"] 66 | self.loss += self.loss_classifier 67 | 68 | # This is there so that we can use source_import from the utils to import model 69 | def get_core(*args): 70 | return model(*args) -------------------------------------------------------------------------------- /libs/utils/logger.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import os 9 | import yaml 10 | import csv 11 | import h5py 12 | 13 | import libs.utils.globals as g 14 | 15 | 16 | 17 | class Logger(object): 18 | def __init__(self, logdir): 19 | if g.log_offline: 20 | self.logdir = logdir 21 | if not os.path.isdir(logdir): 22 | os.makedirs(logdir) 23 | self.cfg_file = os.path.join(self.logdir, 'cfg.yaml') 24 | self.acc_file = os.path.join(self.logdir, 'acc.csv') 25 | self.loss_file = os.path.join(self.logdir, 'loss.csv') 26 | self.ws_file = os.path.join(self.logdir, 'ws.h5') 27 | self.acc_keys = None 28 | self.loss_keys = None 29 | self.logging_ws = False 30 | 31 | def log_cfg(self, cfg): 32 | if g.log_offline: 33 | print('===> Saving cfg parameters to: ', self.cfg_file) 34 | with open(self.cfg_file, 'w') as f: 35 | yaml.dump(cfg, f) 36 | 37 | def log_acc(self, accs): 38 | if g.log_offline: 39 | if self.acc_keys is None: 40 | self.acc_keys = [k for k in accs.keys()] 41 | with open(self.acc_file, 'w') as f: 42 | writer = csv.DictWriter(f, fieldnames=self.acc_keys) 43 | writer.writeheader() 44 | writer.writerow(accs) 45 | else: 46 | with open(self.acc_file, 'a') as f: 47 | writer = csv.DictWriter(f, fieldnames=self.acc_keys) 48 | writer.writerow(accs) 49 | 50 | 51 | def log_loss(self, losses): 52 | if g.log_offline: 53 | valid_losses = losses 54 | if self.loss_keys is None: 55 | self.loss_keys = [k for k in valid_losses.keys()] 56 | with open(self.loss_file, 'w') as f: 57 | writer = csv.DictWriter(f, fieldnames=self.loss_keys) 58 | writer.writeheader() 59 | writer.writerow(valid_losses) 60 | else: 61 | with open(self.loss_file, 'a') as f: 62 | writer = csv.DictWriter(f, fieldnames=self.loss_keys) 63 | writer.writerow(valid_losses) 64 | 65 | def log_ws(self, e, ws): 66 | if g.log_offline: 67 | mode = 'a' if self.logging_ws else 'w' 68 | self.logging_ws = True 69 | 70 | key = 'Epoch{:02d}'.format(e) 71 | with h5py.File(self.ws_file, mode) as f: 72 | g = f.create_group(key) 73 | for k, v in ws.items(): 74 | 
g.create_dataset(k, data=v) 75 | -------------------------------------------------------------------------------- /libs/samplers/ClassAwareSampler.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | 7 | Portions of the source code are from the OLTR project which 8 | notice below and in LICENSE in the root directory of 9 | this source tree. 10 | 11 | Copyright (c) 2019, Zhongqi Miao 12 | All rights reserved. 13 | """ 14 | 15 | import random 16 | import numpy as np 17 | from torch.utils.data.sampler import Sampler 18 | import pdb 19 | 20 | ################################## 21 | ## Class-aware sampling, partly implemented by frombeijingwithlove 22 | ################################## 23 | 24 | class RandomCycleIter: 25 | 26 | def __init__ (self, data, test_mode=False): 27 | self.data_list = list(data) 28 | self.length = len(self.data_list) 29 | self.i = self.length - 1 30 | self.test_mode = test_mode 31 | 32 | def __iter__ (self): 33 | return self 34 | 35 | def __next__ (self): 36 | self.i += 1 37 | 38 | if self.i == self.length: 39 | self.i = 0 40 | if not self.test_mode: 41 | random.shuffle(self.data_list) 42 | 43 | return self.data_list[self.i] 44 | 45 | def class_aware_sample_generator (cls_iter, data_iter_list, n, num_samples_cls=1): 46 | 47 | i = 0 48 | j = 0 49 | while i < n: 50 | 51 | # yield next(data_iter_list[next(cls_iter)]) 52 | 53 | if j >= num_samples_cls: 54 | j = 0 55 | 56 | if j == 0: 57 | temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]]*num_samples_cls)) 58 | yield temp_tuple[j] 59 | else: 60 | yield temp_tuple[j] 61 | 62 | i += 1 63 | j += 1 64 | 65 | class ClassAwareSampler (Sampler): 66 | 67 | def __init__(self, data_source, num_samples_cls=1,): 68 | num_classes = 
len(np.unique(data_source.labels)) 69 | self.class_iter = RandomCycleIter(range(num_classes)) 70 | cls_data_list = [list() for _ in range(num_classes)] 71 | for i, label in enumerate(data_source.labels): 72 | cls_data_list[label].append(i) 73 | self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list] 74 | self.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list) 75 | self.num_samples_cls = num_samples_cls 76 | 77 | def __iter__ (self): 78 | return class_aware_sample_generator(self.class_iter, self.data_iter_list, 79 | self.num_samples, self.num_samples_cls) 80 | 81 | def __len__ (self): 82 | return self.num_samples 83 | 84 | def get_sampler(): 85 | return ClassAwareSampler 86 | 87 | ################################## -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pprint 4 | import warnings 5 | import yaml 6 | import libs.utils.globals as g 7 | import resource 8 | 9 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 10 | resource.setrlimit(resource.RLIMIT_NOFILE, (4048, rlimit[1])) 11 | 12 | data_root = { 13 | "ImageNet": "/DATA/datasets/ImageNet", 14 | } 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--seed", default=1, type=int) 18 | parser.add_argument("--gpu", default="0,1,2,3", type=str) 19 | 20 | parser.add_argument("--experiment", default=0.1, type=float) 21 | parser.add_argument("--alpha", type=str, default="1") # always make sure to convert to the desired type in experiment_maker 22 | parser.add_argument("--beta", type=str, default="1") # always make sure to convert to the desired type in experiment_maker 23 | parser.add_argument("--normal_teachers", default=None, type=str) 24 | parser.add_argument("--aug_teachers", default=None, type=str) 25 | 26 | parser.add_argument("--wandb_logger", default=False, action="store_true") 27 | 
parser.add_argument("--log_offline", default=True, action="store_true") 28 | parser.add_argument("--resume", default=False, action="store_true", help="Will resume from the 'latest_model_checkpoint.pth'") 29 | 30 | args = parser.parse_args() 31 | 32 | # global configs 33 | g.wandb_log = args.wandb_logger 34 | g.epoch_global = 0 35 | g.log_offline = args.log_offline 36 | 37 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 38 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 39 | 40 | # custom 41 | from libs.utils.experiments_maker import experiment_maker 42 | from libs.data import dataloader 43 | from libs.utils.utils import * 44 | 45 | # Random Seed 46 | import torch 47 | import random 48 | import numpy as np 49 | 50 | print(f"=======> Using seed: {args.seed} <========") 51 | random.seed(args.seed) 52 | torch.manual_seed(args.seed) 53 | torch.cuda.manual_seed(args.seed) 54 | torch.cuda.manual_seed_all(args.seed) 55 | np.random.seed(args.seed) 56 | torch.backends.cudnn.deterministic = True 57 | torch.backends.cudnn.benchmark = False 58 | g.seed = args.seed 59 | 60 | config = experiment_maker(args.experiment, data_root, normal_teacher=[int(s) for s in args.normal_teachers.split(',')] if args.normal_teachers else [], aug_teacher=[int(s) for s in args.aug_teachers.split(',')] if args.aug_teachers else [], seed=args.seed, custom_var1=args.alpha, custom_var2=args.beta) 61 | 62 | if g.wandb_log: 63 | import wandb 64 | config_dictionary = config 65 | if args.resume: 66 | id = torch.load(config["training_opt"]["log_dir"]+"/latest_model_checkpoint.pth")['wandb_id'] 67 | print(f"\nResuming wandb id: {id}!\n") 68 | else: 69 | id = wandb.util.generate_id() 70 | print(f"\nStarting wandb id: {id}!\n") 71 | wandb.init( 72 | project="long-tail", 73 | entity="long-tail", 74 | reinit=True, 75 | name=f"{config['training_opt']['stage']}", 76 | allow_val_change=True, 77 | save_code=True, 78 | config=config_dictionary, 79 | tags=config["wandb_tags"], 80 | id=id, 81 | resume="allow", 82 | ) 
83 | wandb.config.update(args, allow_val_change=True) 84 | config["wandb_id"] = id 85 | else: 86 | config["wandb_id"] = None 87 | 88 | if not os.path.isdir(config["training_opt"]["log_dir"]): 89 | os.makedirs(config["training_opt"]["log_dir"]) 90 | # else: 91 | # raise Exception("Directory already exists!!") 92 | 93 | g.log_dir = config["training_opt"]["log_dir"] 94 | if g.log_offline: 95 | if not os.path.isdir(f"{g.log_dir}/metrics"): 96 | os.makedirs(f"{g.log_dir}/metrics") 97 | 98 | 99 | splits = ["train", "val"] 100 | 101 | data = { 102 | x: dataloader.load_data( 103 | data_root=data_root[config["training_opt"]["dataset"].rstrip("_LT")], 104 | dataset=config["training_opt"]["dataset"], 105 | phase=x, 106 | batch_size=config["training_opt"]["batch_size"], 107 | sampler_dic=get_sampler_dict(config["training_opt"]["sampler"]), 108 | num_workers=config["training_opt"]["num_workers"], 109 | special_aug=config["training_opt"]["special_aug"] if "special_aug" in config["training_opt"] else False, 110 | ) 111 | for x in splits 112 | } 113 | # Number of samples in each class 114 | config["training_opt"]["data_count"] = data["train"].dataset.img_num_list 115 | 116 | # import appropriate core 117 | if "core" in config: 118 | training_model = source_import(config["core"]).get_core(config, data) 119 | else: 120 | from libs.core.stage_2 import model 121 | training_model = model(config, data, test=False) 122 | 123 | # training sequence 124 | print("\nInitiating training sequence!") 125 | if args.resume: 126 | training_model.resume_run(config["training_opt"]["log_dir"]+"/latest_model_checkpoint.pth") 127 | training_model.train() 128 | 129 | print("=" * 25, " ALL COMPLETED ", "=" * 25) 130 | -------------------------------------------------------------------------------- /libs/models/ResNetFeature.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 
3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | 7 | Portions of the source code are from the OLTR project which 8 | notice below and in LICENSE in the root directory of 9 | this source tree. 10 | 11 | Copyright (c) 2019, Zhongqi Miao 12 | All rights reserved. 13 | """ 14 | 15 | 16 | import math 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d( 24 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 25 | ) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | residual = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(Bottleneck, self).__init__() 65 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(planes) 67 | self.conv2 = nn.Conv2d( 68 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 69 | ) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(planes * 4) 
73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | residual = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | def __init__(self, block, layers, use_fc=False, dropout=None): 102 | self.inplanes = 64 103 | super(ResNet, self).__init__() 104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 105 | self.bn1 = nn.BatchNorm2d(64) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 108 | self.layer1 = self._make_layer(block, 64, layers[0]) 109 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 110 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 111 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 112 | self.avgpool = nn.AvgPool2d(7, stride=1) 113 | 114 | self.use_fc = use_fc 115 | self.use_dropout = True if dropout else False 116 | 117 | if self.use_fc: 118 | print("Using fc.") 119 | self.fc_add = nn.Linear(512 * block.expansion, 512) 120 | 121 | if self.use_dropout: 122 | print("Using dropout.") 123 | self.dropout = nn.Dropout(p=dropout) 124 | 125 | for m in self.modules(): 126 | if isinstance(m, nn.Conv2d): 127 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 128 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 129 | elif isinstance(m, nn.BatchNorm2d): 130 | m.weight.data.fill_(1) 131 | m.bias.data.zero_() 132 | 133 | def _make_layer(self, block, planes, blocks, stride=1): 134 | downsample = None 135 | if stride != 1 or self.inplanes != planes * block.expansion: 136 | downsample 
= nn.Sequential( 137 | nn.Conv2d( 138 | self.inplanes, 139 | planes * block.expansion, 140 | kernel_size=1, 141 | stride=stride, 142 | bias=False, 143 | ), 144 | nn.BatchNorm2d(planes * block.expansion), 145 | ) 146 | 147 | layers = [] 148 | layers.append(block(self.inplanes, planes, stride, downsample)) 149 | self.inplanes = planes * block.expansion 150 | for i in range(1, blocks): 151 | layers.append(block(self.inplanes, planes)) 152 | 153 | return nn.Sequential(*layers) 154 | 155 | def forward(self, x, *args): 156 | x = self.conv1(x) 157 | x = self.bn1(x) 158 | x = self.relu(x) 159 | x = self.maxpool(x) 160 | 161 | x = self.layer1(x) 162 | x = self.layer2(x) 163 | x = self.layer3(x) 164 | x = self.layer4(x) 165 | 166 | x = self.avgpool(x) 167 | 168 | x = x.view(x.size(0), -1) 169 | 170 | if self.use_fc: 171 | x = F.relu(self.fc_add(x)) 172 | 173 | if self.use_dropout: 174 | x = self.dropout(x) 175 | 176 | return x 177 | -------------------------------------------------------------------------------- /libs/models/ResNet32Feature.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.init as init 8 | 9 | 10 | def _weights_init(m): 11 | classname = m.__class__.__name__ 12 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 13 | init.kaiming_normal_(m.weight) 14 | 15 | 16 | class LambdaLayer(nn.Module): 17 | def __init__(self, lambd): 18 | super(LambdaLayer, self).__init__() 19 | self.lambd = lambd 20 | 21 | def forward(self, x): 22 | return self.lambd(x) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, in_planes, planes, stride=1, option="A"): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = nn.Conv2d( 31 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 32 | ) 33 | self.bn1 = 
nn.BatchNorm2d(planes) 34 | self.conv2 = nn.Conv2d( 35 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 36 | ) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | 39 | self.shortcut = nn.Sequential() 40 | if stride != 1 or in_planes != planes: 41 | if option == "A": 42 | """ 43 | For CIFAR10 ResNet paper uses option A. 44 | """ 45 | self.shortcut = LambdaLayer( 46 | lambda x: F.pad( 47 | x[:, :, ::2, ::2], 48 | (0, 0, 0, 0, planes // 4, planes // 4), 49 | "constant", 50 | 0, 51 | ) 52 | ) 53 | elif option == "B": 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d( 56 | in_planes, 57 | self.expansion * planes, 58 | kernel_size=1, 59 | stride=stride, 60 | bias=False, 61 | ), 62 | nn.BatchNorm2d(self.expansion * planes), 63 | ) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.bn2(self.conv2(out)) 68 | out += self.shortcut(x) 69 | out = F.relu(out) 70 | return out 71 | 72 | 73 | class BBN_ResNet_Cifar(nn.Module): 74 | """ResNet32 from the "BBN: Bilateral-Branch Network with Cumulative Learning for Long-Tailed Visual Recognition (CVPR 2020)" """ 75 | 76 | def __init__(self, block, num_blocks): 77 | """Initialize 78 | #FIXME 79 | Args: 80 | block ([type]): [description] 81 | num_blocks ([type]): [description] 82 | """ 83 | super(BBN_ResNet_Cifar, self).__init__() 84 | self.in_planes = 16 85 | 86 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 87 | self.bn1 = nn.BatchNorm2d(16) 88 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 89 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 90 | self.layer3 = self._make_layer(block, 64, num_blocks[2] - 1, stride=2) 91 | self.cb_block = block(self.in_planes, self.in_planes, stride=1) 92 | self.rb_block = block(self.in_planes, self.in_planes, stride=1) 93 | 94 | self.apply(_weights_init) 95 | 96 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 97 | 98 | def load_model(self, pretrain_dir): 99 | """Load a pre-trained 
model 100 | 101 | Args: 102 | pretrain_dir (str): path of pretrained model 103 | """ 104 | print(f"Loading Backbone pretrain model from {pretrain_dir}......") 105 | model_dict = self.state_dict() 106 | pretrain_dict = torch.load(pretrain_dir)["state_dict_best"]["feat_model"] 107 | 108 | new_dict = OrderedDict() 109 | 110 | # Removing FC and Classifier layers 111 | for k, v in pretrain_dict.items(): 112 | if k.startswith("module"): 113 | k = k[7:] 114 | if "fc" not in k and "classifier" not in k: 115 | new_dict[k] = v 116 | 117 | model_dict.update(new_dict) 118 | self.load_state_dict(model_dict) 119 | print("Backbone model has been loaded......") 120 | 121 | def _make_layer(self, block, planes, num_blocks, stride, add_flag=True): 122 | strides = [stride] + [1] * (num_blocks - 1) 123 | layers = [] 124 | for stride in strides: 125 | layers.append(block(self.in_planes, planes, stride)) 126 | self.in_planes = planes * block.expansion 127 | 128 | return nn.Sequential(*layers) 129 | 130 | def forward(self, x, **kwargs): 131 | out = F.relu(self.bn1(self.conv1(x))) 132 | out = self.layer1(out) 133 | out = self.layer2(out) 134 | out = self.layer3(out) 135 | if "feature_cb" in kwargs: 136 | out = self.cb_block(out) 137 | return out 138 | elif "feature_rb" in kwargs: 139 | out = self.rb_block(out) 140 | return out 141 | 142 | out1 = self.cb_block(out) 143 | out2 = self.rb_block(out) 144 | out = torch.cat((out1, out2), dim=1) 145 | 146 | out = self.avgpool(out) 147 | out = out.view(out.shape[0], -1) 148 | 149 | 150 | return out 151 | 152 | def create_model(pretrain=False, pretrain_dir=None, *args): 153 | """Initialize/load the model 154 | 155 | Args: 156 | pretrain (bool, optional): Use pre-trained model?. Defaults to False. 157 | pretrain_dir (str, optional): Directory of the pre-trained model. Defaults to None. 
158 | 159 | Returns: 160 | class: Model 161 | """ 162 | 163 | print("Loading ResNet 32 Feature Model.") 164 | resnet32 = BBN_ResNet_Cifar(BasicBlock, [5, 5, 5]) 165 | 166 | if pretrain: 167 | if path.exists(pretrain_dir): 168 | print("===> Load Pretrain Initialization for ResNet32") 169 | resnet32.load_model(pretrain_dir=pretrain_dir) 170 | else: 171 | raise Exception("Pretrain path doesn't exist!!") 172 | else: 173 | print("===> Train backbone from the scratch") 174 | 175 | return resnet32 176 | -------------------------------------------------------------------------------- /libs/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision 4 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 5 | from torchvision import transforms 6 | import os 7 | from PIL import Image 8 | 9 | # Image statistics 10 | RGB_statistics = { 11 | "iNaturalist18": {"mean": [0.466, 0.471, 0.380], "std": [0.195, 0.194, 0.192]}, 12 | "default": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}, 13 | } 14 | 15 | # Data transformation with augmentation 16 | def get_data_transform(split, rgb_mean, rbg_std, key=False): 17 | data_transforms = { 18 | "train": transforms.Compose( 19 | [ 20 | transforms.RandomResizedCrop(224), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | transforms.Normalize(rgb_mean, rbg_std), 24 | ] 25 | ) 26 | if key == False 27 | else transforms.Compose( 28 | [ transforms.RandomApply([transforms.ColorJitter(brightness=(0.1, 0.3), contrast=(0.1, 0.3), saturation=(0.1, 0.3), hue=(0.1, 0.3))], p=0.8), 29 | transforms.RandomResizedCrop(224), 30 | transforms.RandomHorizontalFlip(), 31 | transforms.ToTensor(), 32 | transforms.Normalize(rgb_mean, rbg_std), 33 | AddGaussianNoise(0., 0.01), 34 | ] 35 | ), 36 | "val": transforms.Compose( 37 | [ 38 | transforms.Resize(256), 39 | transforms.CenterCrop(224), 40 | 
transforms.ToTensor(), 41 | transforms.Normalize(rgb_mean, rbg_std), 42 | ] 43 | ), 44 | "test": transforms.Compose( 45 | [ 46 | transforms.Resize(256), 47 | transforms.CenterCrop(224), 48 | transforms.ToTensor(), 49 | transforms.Normalize(rgb_mean, rbg_std), 50 | ] 51 | ), 52 | } 53 | return data_transforms[split] 54 | 55 | 56 | # Dataset 57 | class LT_Dataset(Dataset): 58 | def __init__(self, root, txt, transform=None, template=None, top_k=None): 59 | self.img_path = [] 60 | self.labels = [] 61 | self.transform = transform 62 | with open(txt) as f: 63 | for line in f: 64 | if "val" in line: 65 | if "iNaturalist18" in txt: 66 | rootalt = root 67 | else: 68 | rootalt = "/home/rahul_intern/" 69 | else: 70 | rootalt = root 71 | self.img_path.append(os.path.join(rootalt, line.split()[0])) 72 | self.labels.append(int(line.split()[1])) 73 | # select top k class 74 | if top_k: 75 | # only select top k in training, in case train/val/test not matching. 76 | if "train" in txt: 77 | max_len = max(self.labels) + 1 78 | dist = [[i, 0] for i in range(max_len)] 79 | for i in self.labels: 80 | dist[i][-1] += 1 81 | dist.sort(key=lambda x: x[1], reverse=True) 82 | # saving 83 | torch.save(dist, template + "_top_{}_mapping".format(top_k)) 84 | else: 85 | # loading 86 | dist = torch.load(template + "_top_{}_mapping".format(top_k)) 87 | selected_labels = {item[0]: i for i, item in enumerate(dist[:top_k])} 88 | # replace original path and labels 89 | self.new_img_path = [] 90 | self.new_labels = [] 91 | for path, label in zip(self.img_path, self.labels): 92 | if label in selected_labels: 93 | self.new_img_path.append(path) 94 | self.new_labels.append(selected_labels[label]) 95 | self.img_path = self.new_img_path 96 | self.labels = self.new_labels 97 | self.img_num_list = list(np.unique(self.labels, return_counts=True)[1]) 98 | def __len__(self): 99 | return len(self.labels) 100 | 101 | def __getitem__(self, index): 102 | 103 | path = self.img_path[index] 104 | label = 
self.labels[index] 105 | 106 | with open(path, "rb") as f: 107 | sample = Image.open(f).convert("RGB") 108 | 109 | if self.transform is not None: 110 | sample = self.transform(sample) 111 | 112 | return sample, label, index 113 | 114 | 115 | # Load datasets 116 | def load_data( 117 | data_root, 118 | dataset, 119 | phase, 120 | batch_size, 121 | top_k_class=None, 122 | sampler_dic=None, 123 | num_workers=4, 124 | shuffle=True, 125 | special_aug=False 126 | ): 127 | 128 | txt_split = phase 129 | txt = "./libs/data/%s/%s_%s.txt" % (dataset, dataset, txt_split) 130 | template = "./libs/data/%s/%s" % (dataset, dataset) 131 | print("Loading data from %s" % (txt)) 132 | key = special_aug 133 | 134 | rgb_mean, rgb_std = RGB_statistics["default"]["mean"], RGB_statistics["default"]["std"] 135 | if phase not in ["train", "val"]: 136 | transform = get_data_transform("test", rgb_mean, rgb_std, key) 137 | else: 138 | transform = get_data_transform(phase, rgb_mean, rgb_std, key) 139 | print("Use data transformation:", transform) 140 | 141 | set_ = LT_Dataset(data_root, txt, transform, template=template) 142 | 143 | if sampler_dic and phase == "train": 144 | print("=====> Using sampler: ", sampler_dic["sampler"]) 145 | print("=====> Sampler parameters: ", sampler_dic["params"]) 146 | return DataLoader( 147 | dataset=set_, 148 | batch_size=batch_size, 149 | shuffle=False, 150 | sampler=sampler_dic["sampler"](set_, **sampler_dic["params"]), 151 | num_workers=num_workers, 152 | ) 153 | elif phase == "train": 154 | print("=====> No sampler.") 155 | print("=====> Shuffle is %s." % (shuffle)) 156 | return DataLoader( 157 | dataset=set_, 158 | batch_size=batch_size, 159 | shuffle=shuffle, 160 | num_workers=num_workers, 161 | ) 162 | else: 163 | print("=====> No sampler.") 164 | print("=====> Shuffle is %s." 
% (shuffle)) 165 | return DataLoader( 166 | dataset=set_, 167 | batch_size=batch_size, 168 | shuffle=True, 169 | num_workers=num_workers, 170 | ) 171 | 172 | class AddGaussianNoise(object): 173 | def __init__(self, mean=0., std=1.): 174 | self.std = std 175 | self.mean = mean 176 | 177 | def __call__(self, tensor): 178 | return tensor + torch.randn(tensor.size()) * self.std + self.mean 179 | 180 | def __repr__(self): 181 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 182 | -------------------------------------------------------------------------------- /libs/models/ResNet32Featureb4.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.init as init 8 | 9 | 10 | def _weights_init(m): 11 | classname = m.__class__.__name__ 12 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 13 | init.kaiming_normal_(m.weight) 14 | 15 | 16 | class LambdaLayer(nn.Module): 17 | def __init__(self, lambd): 18 | super(LambdaLayer, self).__init__() 19 | self.lambd = lambd 20 | 21 | def forward(self, x): 22 | return self.lambd(x) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, in_planes, planes, stride=1, option="A"): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = nn.Conv2d( 31 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 32 | ) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.conv2 = nn.Conv2d( 35 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 36 | ) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | 39 | self.shortcut = nn.Sequential() 40 | if stride != 1 or in_planes != planes: 41 | if option == "A": 42 | """ 43 | For CIFAR10 ResNet paper uses option A. 
44 | """ 45 | self.shortcut = LambdaLayer( 46 | lambda x: F.pad( 47 | x[:, :, ::2, ::2], 48 | (0, 0, 0, 0, planes // 4, planes // 4), 49 | "constant", 50 | 0, 51 | ) 52 | ) 53 | elif option == "B": 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d( 56 | in_planes, 57 | self.expansion * planes, 58 | kernel_size=1, 59 | stride=stride, 60 | bias=False, 61 | ), 62 | nn.BatchNorm2d(self.expansion * planes), 63 | ) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.bn2(self.conv2(out)) 68 | out += self.shortcut(x) 69 | out = F.relu(out) 70 | return out 71 | 72 | 73 | class BBN_ResNet_Cifar(nn.Module): 74 | """ResNet32 from the "BBN: Bilateral-Branch Network with Cumulative Learning for Long-Tailed Visual Recognition (CVPR 2020)" """ 75 | 76 | def __init__(self, block, num_blocks): 77 | """Initialize 78 | #FIXME 79 | Args: 80 | block ([type]): [description] 81 | num_blocks ([type]): [description] 82 | """ 83 | super(BBN_ResNet_Cifar, self).__init__() 84 | self.in_planes = 16 85 | 86 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 87 | self.bn1 = nn.BatchNorm2d(16) 88 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 89 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 90 | self.layer3 = self._make_layer(block, 64, num_blocks[2] - 1, stride=2) 91 | self.cb_block = block(self.in_planes, self.in_planes, stride=1) 92 | self.rb_block = block(self.in_planes, self.in_planes, stride=1) 93 | 94 | self.apply(_weights_init) 95 | 96 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 97 | 98 | def load_model(self, pretrain_dir): 99 | """Load a pre-trained model 100 | 101 | Args: 102 | pretrain_dir (str): path of pretrained model 103 | """ 104 | print(f"Loading Backbone pretrain model from {pretrain_dir}......") 105 | model_dict = self.state_dict() 106 | pretrain_dict = torch.load(pretrain_dir)["state_dict_best"]["feat_model"] 107 | 108 | new_dict = OrderedDict() 109 | 110 | # 
Removing FC and Classifier layers 111 | for k, v in pretrain_dict.items(): 112 | if k.startswith("module"): 113 | k = k[7:] 114 | if "fc" not in k and "classifier" not in k: 115 | new_dict[k] = v 116 | 117 | model_dict.update(new_dict) 118 | self.load_state_dict(model_dict) 119 | print("Backbone model has been loaded......") 120 | 121 | def _make_layer(self, block, planes, num_blocks, stride, add_flag=True): 122 | strides = [stride] + [1] * (num_blocks - 1) 123 | layers = [] 124 | for stride in strides: 125 | layers.append(block(self.in_planes, planes, stride)) 126 | self.in_planes = planes * block.expansion 127 | 128 | return nn.Sequential(*layers) 129 | 130 | def forward(self, x, retrain,**kwargs): 131 | 132 | if retrain: 133 | outb4 = x.reshape(x.shape[0], 64, 8, 8) 134 | else: 135 | out = F.relu(self.bn1(self.conv1(x))) 136 | out = self.layer1(out) 137 | out = self.layer2(out) 138 | outb4 = self.layer3(out) 139 | if "feature_cb" in kwargs: 140 | outb4 = self.cb_block(outb4) 141 | return out 142 | elif "feature_rb" in kwargs: 143 | outb4 = self.rb_block(outb4) 144 | return out 145 | 146 | out1 = self.cb_block(outb4) 147 | out2 = self.rb_block(outb4) 148 | out = torch.cat((out1, out2), dim=1) 149 | 150 | out = self.avgpool(out) 151 | out = out.view(out.shape[0], -1) 152 | 153 | 154 | return outb4.flatten(start_dim=1), out 155 | 156 | # def forward(self, x, **kwargs): 157 | # outb1 = F.relu(self.bn1(self.conv1(x))) 158 | # outb2 = self.layer1(outb1) 159 | # outb3 = self.layer2(outb2) 160 | # outb4 = self.layer3(outb3) 161 | # if "feature_cb" in kwargs: 162 | # out = self.cb_block(outb4) 163 | # return out 164 | # elif "feature_rb" in kwargs: 165 | # out = self.rb_block(outb4) 166 | # return out 167 | 168 | # if kwargs == {}: 169 | # out1 = self.cb_block(outb4) 170 | # out2 = self.rb_block(outb4) 171 | # else: 172 | # out1 = self.cb_block(out) 173 | # out2 = self.rb_block(out) 174 | # outb5 = torch.cat((out1, out2), dim=1) 175 | 176 | # outb5avg = 
self.avgpool(outb5) 177 | # out = outb5avg.view(outb5avg.shape[0], -1) 178 | 179 | # return outb1,outb2,outb3,outb4,outb5,outb5avg, out 180 | 181 | 182 | def create_model(pretrain=False, pretrain_dir=None, *args): 183 | """Initialize/load the model 184 | 185 | Args: 186 | pretrain (bool, optional): Use pre-trained model?. Defaults to False. 187 | pretrain_dir (str, optional): Directory of the pre-trained model. Defaults to None. 188 | 189 | Returns: 190 | class: Model 191 | """ 192 | 193 | print("Loading ResNet 32 Feature Model.") 194 | resnet32 = BBN_ResNet_Cifar(BasicBlock, [5, 5, 5]) 195 | 196 | if pretrain: 197 | if path.exists(pretrain_dir): 198 | print("===> Load Pretrain Initialization for ResNet32") 199 | resnet32.load_model(pretrain_dir=pretrain_dir) 200 | else: 201 | raise Exception("Pretrain path doesn't exist!!") 202 | else: 203 | print("===> Train backbone from the scratch") 204 | 205 | return resnet32 206 | 207 | 208 | # # Use for any debugging/checking 209 | # a = create_model( 210 | # pretrain=True, pretrain_dir="libs/models/sample_final_model_checkpoint.pth" 211 | # ) 212 | # print(a) 213 | -------------------------------------------------------------------------------- /libs/utils/experiments_maker.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import pprint 3 | import libs.utils.globals as g 4 | import torch 5 | 6 | 7 | experiments = { 8 | 0.1 : "BaselineCosineCE", 9 | 0.2 : "BaselineCosineCE_DifferentAug", 10 | 0.3 : "ECBD", 11 | } 12 | 13 | def experiment_maker(experiment, data_root, normal_teacher=1, aug_teacher=None, seed=1, custom_var1="0", custom_var2="0"): 14 | """Creates an experiment and outputs an appropriate yaml file 15 | 16 | Args: 17 | experiment (float): Experiment of choice 18 | dataset (float): Dataset name 19 | data_root (dict): Dict of the root directories of all the datasets 20 | seed (int, optional): Which seed is being used ? Defaults to 1. 
21 | custom_var1 (str, optional): Custom variable to use in experiments - purpose changes according to the experiment 22 | custom_var2 (str, optional): Custom variable to use in experiments - purpose changes according to the experiment 23 | 24 | Returns: 25 | [dictionary]: list of modified config files (length = number of experiments) 26 | """ 27 | assert experiment in experiments.keys(), "Wrong Experiment!" 28 | 29 | # Load Default configuration 30 | with open("libs/utils/default_config.yaml") as f: 31 | config = yaml.load(f, Loader=yaml.FullLoader) 32 | 33 | num_of_classes = 1000 34 | dataset_name = f"ImageNet_LT" 35 | exp_name_template = f'ImageNet' 36 | 37 | # Have a separate root folders and experiments names for all seeds except for seed 1 38 | if seed == 1: 39 | init_dir = "logs" 40 | else: 41 | init_dir = f"logs/other_seeds/seed_{seed}" 42 | exp_name_template = f"seed_{seed}_{exp_name_template}" 43 | 44 | config["training_opt"]["num_classes"] = num_of_classes 45 | config["training_opt"]["dataset"] = dataset_name 46 | 47 | if experiment == 0.1: #BaselineCosineCE 48 | config["core"] = "./libs/core/BaselineCosineCE.py" 49 | 50 | # loss 51 | config["criterions"]["ClassifierLoss"]["def_file"] = "./libs/loss/SoftmaxLoss.py" 52 | config["criterions"]["ClassifierLoss"]["loss_params"] = {} 53 | config["criterions"]["ClassifierLoss"]["optim_params"] = False 54 | config["criterions"]["ClassifierLoss"]["weight"] = 1.0 55 | 56 | # network 57 | # part 1 58 | config["networks"]["feat_model"]["trainable"] = True 59 | config["networks"]["feat_model"]["def_file"] = "./libs/models/ResNet50Feature.py" 60 | config["networks"]["feat_model"]["optim_params"]["lr"] = 0.2 61 | config["networks"]["feat_model"]["optim_params"]["momentum"] = 0.9 62 | config["networks"]["feat_model"]["optim_params"]["weight_decay"] = 0.0005 63 | config["networks"]["feat_model"]["scheduler_params"]["coslr"] = True 64 | config["networks"]["feat_model"]["params"]["pretrain"] = False 65 | 
config["networks"]["feat_model"]["params"]["pretrain_dir"] = None 66 | 67 | # part 2 68 | config["networks"]["classifier"]["trainable"] = True 69 | config["networks"]["classifier"]["def_file"] = "./libs/models/CosineDotProductClassifier.py" 70 | config["networks"]["classifier"]["optim_params"]["lr"] = 0.2 71 | config["networks"]["classifier"]["optim_params"]["momentum"] = 0.9 72 | config["networks"]["classifier"]["optim_params"]["weight_decay"] = 0.0005 73 | config["networks"]["classifier"]["scheduler_params"]["coslr"] = True 74 | config["networks"]["classifier"]["params"]["feat_dim"] = 2048 75 | config["networks"]["classifier"]["params"]["num_classes"] = num_of_classes 76 | config["networks"]["classifier"]["params"]["pretrain"] = False 77 | config["networks"]["classifier"]["params"]["pretrain_dir"] = None 78 | 79 | #delete 80 | del(config["criterions"]["PerformanceLoss"]) 81 | del(config["criterions"]["EmbeddingLoss"]) 82 | del(config["networks"]["embedding"]) 83 | 84 | # force shuffle dataset 85 | config["shuffle"] = False 86 | 87 | # tags for wandb 88 | config["wandb_tags"] = [experiments[experiment]] 89 | 90 | # other training configs 91 | config["training_opt"]["backbone"] = "resnet50" 92 | 93 | #------Effective batch size after considering GPU count and 94 | #------gradient accumulation for GPU memory bottlenech is 512. 95 | #------64 samples per batch, accumulated over 8 iters for GPU memory bottleneck. 96 | #------Since DataParallel is used, to achieve the effective batchsize of 512, the 64 samples per batch 97 | #------is divided by the GPU count. 
98 | config["training_opt"]["batch_size"] = int(64/int(torch.cuda.device_count())) 99 | config["training_opt"]["accumulation_step"] = int(512/config["training_opt"]["batch_size"]) 100 | 101 | config["training_opt"]["feature_dim"] = 2048 102 | config["training_opt"]["num_workers"] = 20 103 | config["training_opt"]["num_epochs"] = 90 104 | config["training_opt"]["sampler"] = False 105 | 106 | # final name of the experiment 107 | exp_name = f'{experiments[experiment]}_{exp_name_template}_{config["training_opt"]["backbone"]}' 108 | 109 | config["training_opt"]["stage"] = exp_name 110 | config["training_opt"]["log_dir"] = f'./{init_dir}/{dataset_name}/{exp_name}' 111 | 112 | elif experiment == 0.2: #BaselineCosineCE_DifferentAug 113 | config["core"] = "./libs/core/BaselineCosineCE.py" 114 | 115 | # loss 116 | config["criterions"]["ClassifierLoss"]["def_file"] = "./libs/loss/SoftmaxLoss.py" 117 | config["criterions"]["ClassifierLoss"]["loss_params"] = {} 118 | config["criterions"]["ClassifierLoss"]["optim_params"] = False 119 | config["criterions"]["ClassifierLoss"]["weight"] = 1.0 120 | 121 | # network 122 | # part 1 123 | config["networks"]["feat_model"]["trainable"] = True 124 | config["networks"]["feat_model"]["def_file"] = "./libs/models/ResNet50Feature.py" 125 | config["networks"]["feat_model"]["optim_params"]["lr"] = 0.2 126 | config["networks"]["feat_model"]["optim_params"]["momentum"] = 0.9 127 | config["networks"]["feat_model"]["optim_params"]["weight_decay"] = 0.0005 128 | config["networks"]["feat_model"]["scheduler_params"]["coslr"] = True 129 | config["networks"]["feat_model"]["params"]["pretrain"] = False 130 | config["networks"]["feat_model"]["params"]["pretrain_dir"] = None 131 | 132 | # part 2 133 | config["networks"]["classifier"]["trainable"] = True 134 | config["networks"]["classifier"]["def_file"] = "./libs/models/CosineDotProductClassifier.py" 135 | config["networks"]["classifier"]["optim_params"]["lr"] = 0.2 136 | 
config["networks"]["classifier"]["optim_params"]["momentum"] = 0.9 137 | config["networks"]["classifier"]["optim_params"]["weight_decay"] = 0.0005 138 | config["networks"]["classifier"]["scheduler_params"]["coslr"] = True 139 | config["networks"]["classifier"]["params"]["feat_dim"] = 2048 140 | config["networks"]["classifier"]["params"]["num_classes"] = num_of_classes 141 | config["networks"]["classifier"]["params"]["pretrain"] = False 142 | config["networks"]["classifier"]["params"]["pretrain_dir"] = None 143 | 144 | #delete 145 | del(config["criterions"]["PerformanceLoss"]) 146 | del(config["criterions"]["EmbeddingLoss"]) 147 | del(config["networks"]["embedding"]) 148 | 149 | # force shuffle dataset 150 | config["shuffle"] = False 151 | 152 | # tags for wandb 153 | config["wandb_tags"] = [experiments[experiment]] 154 | 155 | # other training configs 156 | config["training_opt"]["backbone"] = "resnet50" 157 | 158 | #------Effective batch size after considering GPU count and 159 | #------gradient accumulation for GPU memory bottlenech is 512. 160 | #------64 samples per batch, accumulated over 8 iters for GPU memory bottleneck. 161 | #------Since DataParallel is used, to achieve the effective batchsize of 512, the 64 samples per batch 162 | #------is divided by the GPU count. 
163 | config["training_opt"]["batch_size"] = int(64/int(torch.cuda.device_count())) 164 | config["training_opt"]["accumulation_step"] = int(512/config["training_opt"]["batch_size"]) 165 | 166 | config["training_opt"]["feature_dim"] = 2048 167 | config["training_opt"]["num_workers"] = 20 168 | config["training_opt"]["num_epochs"] = 90 169 | config["training_opt"]["sampler"] = False 170 | config["training_opt"]["special_aug"] = True 171 | 172 | # final name of the experiment 173 | exp_name = f'{experiments[experiment]}_{exp_name_template}_{config["training_opt"]["backbone"]}' 174 | config["training_opt"]["stage"] = exp_name 175 | config["training_opt"]["log_dir"] = f'./{init_dir}/{dataset_name}/{exp_name}' 176 | 177 | 178 | elif experiment == 0.3: #ECBD_BaselineCosineCE 179 | config["core"] = "./libs/core/CBDistillCE.py" 180 | 181 | # loss 182 | config["criterions"]["ClassifierLoss"]["def_file"] = "./libs/loss/SoftmaxLoss.py" 183 | config["criterions"]["ClassifierLoss"]["loss_params"] = {} 184 | config["criterions"]["ClassifierLoss"]["optim_params"] = False 185 | config["criterions"]["ClassifierLoss"]["weight"] = 1.0 - float(custom_var1) 186 | 187 | # Distill loss (Just doing cosine distance between teacher and student features) 188 | config["criterions"]["DistillLoss"] = {} 189 | config["criterions"]["DistillLoss"]["def_file"] = "./libs/loss/CosineDistill.py" 190 | config["criterions"]["DistillLoss"]["loss_params"] = {} 191 | config["criterions"]["DistillLoss"]["loss_params"]["beta"] = float(custom_var2) 192 | config["criterions"]["DistillLoss"]["optim_params"] = False 193 | config["criterions"]["DistillLoss"]["weight"] = float(custom_var1) 194 | 195 | # network 196 | # part 1 197 | config["networks"]["feat_model"]["trainable"] = True 198 | config["networks"]["feat_model"]["def_file"] = "./libs/models/ResNet50Feature.py" 199 | config["networks"]["feat_model"]["optim_params"]["lr"] = 0.2 200 | config["networks"]["feat_model"]["optim_params"]["momentum"] = 0.9 201 | 
config["networks"]["feat_model"]["optim_params"]["weight_decay"] = 0.0005 202 | config["networks"]["feat_model"]["scheduler_params"]["coslr"] = True 203 | config["networks"]["feat_model"]["params"]["pretrain"] = False 204 | config["networks"]["feat_model"]["params"]["pretrain_dir"] = None 205 | 206 | # part 2 207 | config["networks"]["classifier"]["trainable"] = True 208 | config["networks"]["classifier"]["def_file"] = "./libs/models/CosineDotProductClassifier.py" 209 | config["networks"]["classifier"]["optim_params"]["lr"] = 0.2 210 | config["networks"]["classifier"]["optim_params"]["momentum"] = 0.9 211 | config["networks"]["classifier"]["optim_params"]["weight_decay"] = 0.0005 212 | config["networks"]["classifier"]["scheduler_params"]["coslr"] = True 213 | config["networks"]["classifier"]["params"]["feat_dim"] = 2048 214 | config["networks"]["classifier"]["params"]["num_classes"] = num_of_classes 215 | config["networks"]["classifier"]["params"]["pretrain"] = False 216 | config["networks"]["classifier"]["params"]["pretrain_dir"] = None 217 | 218 | config["training_opt"]["backbone"] = "resnet50" 219 | for i,j in zip(range(len(normal_teacher)), normal_teacher): 220 | 221 | exp_name_template_t = f'ImageNet' 222 | seed_t = j 223 | # Have a separate root folders and experiments names for all seeds except for seed 1 224 | if seed_t == 1: 225 | init_dir_t = "logs" 226 | else: 227 | init_dir_t = f"logs/other_seeds/seed_{seed_t}" 228 | exp_name_template_t = f"seed_{seed_t}_{exp_name_template_t}" 229 | 230 | config["networks"][f"normal_t{i}_model"] = {} 231 | config["networks"][f"normal_t{i}_model"]["trainable"] = True 232 | config["networks"][f"normal_t{i}_model"]["def_file"] = "./libs/models/ResNet50Feature.py" 233 | config["networks"][f"normal_t{i}_model"]["optim_params"] = {} 234 | config["networks"][f"normal_t{i}_model"]["optim_params"]["lr"] = 0.2 235 | config["networks"][f"normal_t{i}_model"]["optim_params"]["momentum"] = 0.9 236 | 
config["networks"][f"normal_t{i}_model"]["optim_params"]["weight_decay"] = 0.0005 237 | config["networks"][f"normal_t{i}_model"]["scheduler_params"] = {} 238 | config["networks"][f"normal_t{i}_model"]["scheduler_params"]["coslr"] = True 239 | config["networks"][f"normal_t{i}_model"]["scheduler_params"]["endlr"] = 0.0 240 | config["networks"][f"normal_t{i}_model"]["scheduler_params"]["step_size"] = 30 241 | config["networks"][f"normal_t{i}_model"]["params"] = {} 242 | config["networks"][f"normal_t{i}_model"]["params"]["pretrain"] = True 243 | config["networks"][f"normal_t{i}_model"]["params"]["pretrain_dir"] = f'./{init_dir_t}/{dataset_name}/{experiments[0.1]}_{exp_name_template_t}_{config["training_opt"]["backbone"]}/final_model_checkpoint.pth' 244 | config["networks"][f"normal_t{i}_model"]["fix"] = True 245 | 246 | for i,j in zip(range(len(aug_teacher)), aug_teacher): 247 | 248 | exp_name_template_t = f'ImageNet' 249 | seed_t = j 250 | # Have a separate root folders and experiments names for all seeds except for seed 1 251 | if seed_t == 1: 252 | init_dir_t = "logs" 253 | else: 254 | init_dir_t = f"logs/other_seeds/seed_{seed_t}" 255 | exp_name_template_t = f"seed_{seed_t}_{exp_name_template_t}" 256 | 257 | config["networks"][f"aug_t{i}_model"] = {} 258 | config["networks"][f"aug_t{i}_model"]["trainable"] = True 259 | config["networks"][f"aug_t{i}_model"]["def_file"] = "./libs/models/ResNet50Feature.py" 260 | config["networks"][f"aug_t{i}_model"]["optim_params"] = {} 261 | config["networks"][f"aug_t{i}_model"]["optim_params"]["lr"] = 0.2 262 | config["networks"][f"aug_t{i}_model"]["optim_params"]["momentum"] = 0.9 263 | config["networks"][f"aug_t{i}_model"]["optim_params"]["weight_decay"] = 0.0005 264 | config["networks"][f"aug_t{i}_model"]["scheduler_params"] = {} 265 | config["networks"][f"aug_t{i}_model"]["scheduler_params"]["coslr"] = True 266 | config["networks"][f"aug_t{i}_model"]["scheduler_params"]["endlr"] = 0.0 267 | 
config["networks"][f"aug_t{i}_model"]["scheduler_params"]["step_size"] = 30 268 | config["networks"][f"aug_t{i}_model"]["params"] = {} 269 | config["networks"][f"aug_t{i}_model"]["params"]["pretrain"] = True 270 | config["networks"][f"aug_t{i}_model"]["params"]["pretrain_dir"] = f'./{init_dir_t}/{dataset_name}/{experiments[0.2]}_{exp_name_template_t}_{config["training_opt"]["backbone"]}/final_model_checkpoint.pth' 271 | config["networks"][f"aug_t{i}_model"]["fix"] = True 272 | 273 | if (len(normal_teacher) + len(aug_teacher)) > 1 : 274 | config["networks"]["ecbd_converter"] = {} 275 | config["networks"]["ecbd_converter"]["trainable"] = True 276 | config["networks"]["ecbd_converter"]["def_file"] = "./libs/models/ecbd_converter.py" 277 | config["networks"]["ecbd_converter"]["optim_params"] = {} 278 | config["networks"]["ecbd_converter"]["optim_params"]["lr"] = 0.2 279 | config["networks"]["ecbd_converter"]["optim_params"]["momentum"] = 0.9 280 | config["networks"]["ecbd_converter"]["optim_params"]["weight_decay"] = 0.0005 281 | config["networks"]["ecbd_converter"]["scheduler_params"] = {} 282 | config["networks"]["ecbd_converter"]["scheduler_params"]["coslr"] = True 283 | config["networks"]["ecbd_converter"]["scheduler_params"]["endlr"] = 0.0 284 | config["networks"]["ecbd_converter"]["scheduler_params"]["step_size"] = 30 285 | config["networks"]["ecbd_converter"]["params"] = {} 286 | config["networks"]["ecbd_converter"]["params"]["feat_in"] = config["networks"]["classifier"]["params"]["feat_dim"] 287 | config["networks"]["ecbd_converter"]["params"]["feat_out"] = config["networks"]["classifier"]["params"]["feat_dim"]*(len(normal_teacher) + len(aug_teacher)) 288 | 289 | #delete 290 | del(config["criterions"]["PerformanceLoss"]) 291 | del(config["criterions"]["EmbeddingLoss"]) 292 | del(config["networks"]["embedding"]) 293 | 294 | # force shuffle dataset 295 | config["shuffle"] = False 296 | 297 | # tags for wandb 298 | config["wandb_tags"] = [experiments[experiment]] 
299 | 300 | # other training configs 301 | config["training_opt"]["backbone"] = "resnet50" 302 | 303 | #------Effective batch size after considering GPU count and 304 | #------gradient accumulation for GPU memory bottlenech is 512. 305 | #------64 samples per batch, accumulated over 8 iters for GPU memory bottleneck. 306 | #------Since DataParallel is used, to achieve the effective batchsize of 512, the 64 samples per batch 307 | #------is divided by the GPU count. 308 | config["training_opt"]["batch_size"] = int(64/int(torch.cuda.device_count())) 309 | config["training_opt"]["accumulation_step"] = int(512/config["training_opt"]["batch_size"]) 310 | 311 | config["training_opt"]["feature_dim"] = 2048 312 | config["training_opt"]["num_workers"] = 20 313 | config["training_opt"]["num_epochs"] = 90 314 | 315 | config["training_opt"]["sampler"] = {"def_file": "./libs/samplers/ClassAwareSampler.py", "num_samples_cls": 4, "type": "ClassAwareSampler"} 316 | 317 | # final name of the experiment 318 | exp_name = f'{experiments[experiment]}_{exp_name_template}_{config["training_opt"]["backbone"]}' 319 | config["training_opt"]["stage"] = exp_name 320 | config["training_opt"]["log_dir"] = f'./{init_dir}/{dataset_name}/{exp_name}/alpha_{float(custom_var1)},beta_{float(custom_var2)}_normal_k_{len(normal_teacher)}_aug_k_{len(aug_teacher)}' 321 | 322 | else: 323 | print(f"Wrong experiments setup!-{experiment}") 324 | 325 | if g.log_offline: 326 | g.log_dir = config["training_opt"]["log_dir"] 327 | return config 328 | 329 | -------------------------------------------------------------------------------- /libs/core/core_base.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.pyplot as plt 3 | from torch.optim import optimizer 4 | import libs.utils.globals as g 5 | import os 6 | import copy 7 | import pickle 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import torch.nn.functional as F 12 | 
from tqdm import tqdm 13 | from libs.utils.utils import * 14 | from libs.utils.logger import Logger 15 | import time 16 | import numpy as np 17 | import warnings 18 | import pdb 19 | 20 | 21 | class model: 22 | def __init__(self, config, data): 23 | """Initialize 24 | 25 | Args: 26 | config (Dict): Dictionary of all the configurations 27 | data (list): Train, val, test splits 28 | """ 29 | 30 | self.config = config 31 | self.training_opt = self.config["training_opt"] 32 | self.data = data 33 | self.num_gpus = torch.cuda.device_count() 34 | self.do_shuffle = config["shuffle"] if "shuffle" in config else False 35 | self.start_epoch = 0 36 | self.accumulation_step = self.training_opt["accumulation_step"] 37 | 38 | #-----Offline Logger 39 | self.logger = Logger(self.training_opt["log_dir"]) 40 | self.log_file = os.path.join(self.training_opt["log_dir"], "log.txt") 41 | self.logger.log_cfg(self.config) 42 | 43 | 44 | # If using steps for training, we need to calculate training steps 45 | # for each epoch based on actual number of training data instead of 46 | # oversampled data number 47 | print("Using steps for training.") 48 | self.training_data_num = len(self.data["train"].dataset) 49 | self.epoch_steps = int(self.training_data_num / self.training_opt["batch_size"]) 50 | 51 | # Initialize loss 52 | self.init_criterions() 53 | 54 | # Initialize model 55 | self.init_models() 56 | 57 | # Initialize model optimizer and scheduler 58 | print("Initializing model optimizer.") 59 | self.init_optimizers(self.model_optim_params_dict) 60 | 61 | 62 | def init_models(self): 63 | """Initialize models 64 | """ 65 | networks_defs = self.config["networks"] 66 | self.networks = {} 67 | self.model_optim_params_dict = {} 68 | 69 | print("Using", torch.cuda.device_count(), "GPUs.") 70 | 71 | # Create the models in loop 72 | for key, val in networks_defs.items(): 73 | def_file = val["def_file"] 74 | model_args = val["params"] 75 | 76 | # Create/load model 77 | self.networks[key] = 
source_import(def_file).create_model(**model_args) 78 | if networks_defs[key]["trainable"]: 79 | self.networks[key] = nn.DataParallel(self.networks[key]).cuda() 80 | 81 | # Freezing part or entire model 82 | if "fix" in val and val["fix"]: 83 | print(f"Freezing weights of module {key}") 84 | for param_name, param in self.networks[key].named_parameters(): 85 | param.requires_grad = False 86 | if "fix_set" in val: 87 | for fix_layer in val["fix_set"]: 88 | for param_name, param in self.networks[key].named_parameters(): 89 | if fix_layer == param_name: 90 | param.requires_grad = False 91 | print(f"=====> Freezing: {param_name} | {param.requires_grad}") 92 | 93 | # wandb logging 94 | if g.wandb_log: 95 | wandb.watch(self.networks[key], log="all") 96 | 97 | # Optimizer list to add to the optimizer in the "init_optimizer" step 98 | optim_params = val["optim_params"] 99 | self.model_optim_params_dict[key] = { 100 | "params": self.networks[key].parameters(), 101 | "lr": optim_params["lr"], 102 | "momentum": optim_params["momentum"], 103 | "weight_decay": optim_params["weight_decay"], 104 | } 105 | 106 | def init_optimizers(self, optim_params_dict): 107 | """Init optimizer with/without scheduler for it 108 | 109 | Args: 110 | optim_params_dict (Dict): A dictonary with all the params for the optimizer 111 | """ 112 | networks_defs = self.config["networks"] 113 | self.model_optimizer_dict = {} 114 | self.model_scheduler_dict = {} 115 | 116 | for key, val in networks_defs.items(): 117 | if networks_defs[key]["trainable"]: 118 | # optimizer 119 | if ("optimizer" in self.training_opt and self.training_opt["optimizer"] == "adam"): 120 | print("=====> Using Adam optimizer") 121 | optimizer = optim.Adam([optim_params_dict[key],]) 122 | else: 123 | print("=====> Using SGD optimizer") 124 | optimizer = optim.SGD([optim_params_dict[key],]) 125 | self.model_optimizer_dict[key] = optimizer 126 | 127 | # scheduler 128 | if val["scheduler_params"]: 129 | scheduler_params = 
val["scheduler_params"] 130 | 131 | if scheduler_params["coslr"]: 132 | self.model_scheduler_dict[key] = torch.optim.lr_scheduler.CosineAnnealingLR( 133 | optimizer, 134 | self.training_opt["num_epochs"], 135 | ) 136 | elif scheduler_params['warmup']: 137 | print("===> Module {} : Using warmup".format(key)) 138 | self.model_scheduler_dict[key] = WarmupMultiStepLR(optimizer, scheduler_params['lr_step'], 139 | gamma=scheduler_params['lr_factor'], warmup_epochs=scheduler_params['warm_epoch']) 140 | else: 141 | self.model_scheduler_dict[key] = optim.lr_scheduler.StepLR( 142 | optimizer, 143 | step_size=scheduler_params["step_size"], 144 | gamma=scheduler_params["gamma"], 145 | ) 146 | 147 | def init_criterions(self): 148 | """Initialize criterion (loss) and if required optimizer, scheduler for trainable params in it. 149 | """ 150 | criterion_defs = self.config["criterions"] 151 | self.criterions = {} 152 | self.criterion_weights = {} 153 | 154 | for key, val in criterion_defs.items(): 155 | def_file = val["def_file"] 156 | loss_args = list(val["loss_params"].values()) 157 | 158 | self.criterions[key] = (source_import(def_file).create_loss(*loss_args).cuda()) 159 | self.criterion_weights[key] = val["weight"] 160 | 161 | if val["optim_params"]: 162 | print("Initializing criterion optimizer.") 163 | optim_params = val["optim_params"] 164 | optim_params = [ 165 | { 166 | "params": self.criterions[key].parameters(), 167 | "lr": optim_params["lr"], 168 | "momentum": optim_params["momentum"], 169 | "weight_decay": optim_params["weight_decay"], 170 | } 171 | ] 172 | 173 | # Initialize criterion optimizer 174 | if ("optimizer" in self.training_opt and self.training_opt["optimizer"] == "adam"): 175 | print("=====> Using Adam optimizer") 176 | optimizer = optim.Adam(optim_params) 177 | else: 178 | print("=====> Using SGD optimizer") 179 | optimizer = optim.SGD(optim_params) 180 | self.criterion_optimizer = optimizer 181 | 182 | # Initialize criterion scheduler 183 | if 
"scheduler_params" in val and val["scheduler_params"]: 184 | scheduler_params = val["scheduler_params"] 185 | if scheduler_params["coslr"]: 186 | self.criterion_optimizer_scheduler = ( 187 | torch.optim.lr_scheduler.CosineAnnealingLR( 188 | optimizer, 189 | self.training_opt["num_epochs"], 190 | ) 191 | ) 192 | else: 193 | self.criterion_optimizer_scheduler = optim.lr_scheduler.StepLR( 194 | optimizer, 195 | step_size=scheduler_params["step_size"], 196 | gamma=scheduler_params["gamma"], 197 | ) 198 | 199 | else: 200 | self.criterion_optimizer_scheduler = None 201 | else: 202 | self.criterion_optimizer = None 203 | self.criterion_optimizer_scheduler = None 204 | 205 | 206 | def show_current_lr(self): 207 | """Shows current learning rate 208 | 209 | Returns: 210 | float: Current learning rate 211 | """ 212 | max_lr = 0.0 213 | for key, val in self.model_optimizer_dict.items(): 214 | lr_set = list(set([para["lr"] for para in val.param_groups])) 215 | if max(lr_set) > max_lr: 216 | max_lr = max(lr_set) 217 | lr_set = ",".join([str(i) for i in lr_set]) 218 | print_str = [f"=====> Current Learning Rate of model {key} : {str(lr_set)}"] 219 | 220 | print_write(print_str, self.log_file) 221 | wandb_log({f"LR - {key}": float(lr_set)}) 222 | 223 | if self.criterion_optimizer: 224 | lr_set_rad = list(set([para["lr"] for para in self.criterion_optimizer.param_groups])) 225 | lr_set_rad = ",".join([str(i) for i in lr_set_rad]) 226 | 227 | wandb_log({f"LR - Radius": float(lr_set_rad)}) 228 | 229 | return max_lr 230 | 231 | def batch_forward(self, inputs, labels=None, phase="train", retrain= False): 232 | """Batch Forward 233 | 234 | Args: 235 | inputs (float Tensor): batch_size x image_size 236 | labels (int, optional): Labels. Defaults to None. 237 | phase (str, optional): Train or Test?. Defaults to "train". 
238 | """ 239 | 240 | # Calculate Features and outputs 241 | self.features = self.networks["feat_model"](inputs) 242 | self.features = F.normalize(self.features, p=2, dim=1) 243 | 244 | self.logits = self.networks["classifier"](self.features, labels) 245 | 246 | def batch_backward(self): 247 | """Backprop 248 | """ 249 | if self.accumulation_step == 1: 250 | # Zero out optimizer gradients 251 | for key, optimizer in self.model_optimizer_dict.items(): 252 | optimizer.zero_grad() 253 | if self.criterion_optimizer: 254 | self.criterion_optimizer.zero_grad() 255 | 256 | # Back-propagation from loss outputs 257 | self.loss.backward() 258 | 259 | if (self.step+1) % self.accumulation_step == 0: 260 | # Step optimizers 261 | for key, optimizer in self.model_optimizer_dict.items(): 262 | optimizer.step() 263 | if self.criterion_optimizer: 264 | self.criterion_optimizer.step() 265 | 266 | if self.accumulation_step != 1: 267 | # Zero out optimizer gradients 268 | for key, optimizer in self.model_optimizer_dict.items(): 269 | optimizer.zero_grad() 270 | if self.criterion_optimizer: 271 | self.criterion_optimizer.zero_grad() 272 | 273 | def batch_loss(self, labels): 274 | """Calculate training loss 275 | 276 | Args: 277 | labels (int): Dim = Batch_size 278 | """ 279 | self.loss = 0 280 | 281 | # Calculating loss 282 | if "EmbeddingLoss" in self.criterions.keys(): 283 | self.loss_embed, self.loss_embed_proto, self.loss_embed_biasreduc = self.criterions["EmbeddingLoss"](self.embedding, labels) 284 | self.loss_embed *= self.criterion_weights["EmbeddingLoss"] 285 | self.loss += self.loss_embed 286 | 287 | # Calculating loss 288 | if "ClassifierLoss" in self.criterions.keys(): 289 | self.loss_classifier = self.criterions["ClassifierLoss"](self.logits, labels) 290 | self.loss_classifier *= self.criterion_weights["ClassifierLoss"] 291 | self.loss += self.loss_classifier #------Note here that it is not +=. GPU memory saving measure!!!! 
292 | 293 | self.loss = self.loss / self.accumulation_step 294 | 295 | def shuffle_batch(self, x, y): 296 | """Force shuffle data 297 | 298 | Args: 299 | x (float Tensor): Datapoints 300 | y (int): Labels 301 | 302 | Returns: 303 | floatTensor, int: Return shuffled datapoints and corresponding labels 304 | """ 305 | index = torch.randperm(x.size(0)) 306 | x = x[index] 307 | y = y[index] 308 | return x, y 309 | 310 | def train(self, retrain=False): 311 | # When training the network 312 | 313 | print_str = ["Phase: train"] 314 | print_write(print_str, self.log_file) 315 | 316 | # Initialize best model and other variables 317 | self.best_model_weights = {} 318 | for key, _ in self.config["networks"].items(): 319 | if self.config["networks"][key]["trainable"]: 320 | self.best_model_weights[key] = copy.deepcopy(self.networks[key].state_dict()) 321 | 322 | best_acc = 0.0 323 | best_epoch = 0 324 | self.retrain = retrain 325 | self.end_epoch = self.training_opt["num_epochs"] 326 | 327 | # Loop over epochs 328 | for epoch in range(self.start_epoch, self.end_epoch + 1): 329 | epoch_start_time = time.time() 330 | 331 | g.epoch_global = epoch #---global config 332 | 333 | # train mode 334 | for key, model in self.networks.items(): 335 | if self.config["networks"][key]["trainable"]: 336 | # only train the module with lr > 0 337 | if self.config["networks"][key]["optim_params"]["lr"] == 0.0: 338 | model.eval() 339 | else: 340 | model.train() 341 | 342 | # Empty cuda cache 343 | torch.cuda.empty_cache() 344 | 345 | if self.model_scheduler_dict: 346 | for key, scheduler in self.model_scheduler_dict.items(): 347 | scheduler.step() 348 | if self.criterion_optimizer_scheduler: 349 | self.criterion_optimizer_scheduler.step() 350 | 351 | total_preds = [] 352 | total_labels = [] 353 | 354 | print_write([self.training_opt["log_dir"]], self.log_file) 355 | 356 | # print learning rate 357 | current_lr = self.show_current_lr() 358 | current_lr = min(current_lr * 50, 1.0) 359 | 360 | 
self.step = 0 361 | time1 = time.time() 362 | for inputs, labels, indexes in self.data["train"]: 363 | et_dataloader = time.time() - time1 364 | 365 | # Break when step equal to epoch step 366 | if self.step == self.epoch_steps: 367 | break 368 | 369 | # Force shuffle option 370 | if self.do_shuffle: 371 | inputs, labels = self.shuffle_batch(inputs, labels) 372 | 373 | # Pushing to GPU 374 | inputs, labels = inputs.cuda(), labels.cuda() 375 | 376 | with torch.set_grad_enabled(True): 377 | time2 = time.time() 378 | 379 | # If training, forward with loss, and no top 5 accuracy calculation 380 | self.batch_forward(inputs, labels, phase="train", retrain=retrain) 381 | self.batch_loss(labels) 382 | self.batch_backward() 383 | 384 | et_forback = time.time() - time2 385 | 386 | # Tracking and printing predictions 387 | _, preds = torch.max(self.logits, 1) 388 | total_preds.append(torch2numpy(preds)) 389 | total_labels.append(torch2numpy(labels)) 390 | 391 | # Output minibatch training results 392 | if self.step % self.training_opt['display_step'] == 0: 393 | 394 | minibatch_loss_classifier = self.loss_classifier.item() if 'ClassifierLoss' in self.criterions else None 395 | minibatch_loss_embed = self.loss_embed.item() if 'EmbeddingLoss' in self.criterions else None 396 | minibatch_loss_embed_proto = self.loss_embed_proto.item() if 'EmbeddingLoss' in self.criterions else None 397 | minibatch_loss_embed_biasreduc = self.loss_embed_biasreduc.item() if 'EmbeddingLoss' in self.criterions else None 398 | minibatch_loss_total = self.loss.item() 399 | minibatch_acc = mic_acc_cal(preds, labels) 400 | 401 | 402 | print_str = ['Epoch: [%d/%d]' 403 | % (epoch, self.training_opt['num_epochs']), 404 | 'Step: %5d' 405 | % (self.step), 406 | 'Minibatch_loss_embedding: %.3f' 407 | % (minibatch_loss_embed) if minibatch_loss_embed else '', 408 | 'Minibatch_loss_classifier: %.3f' 409 | % (minibatch_loss_classifier) if minibatch_loss_classifier else '', 410 | 'Minibatch_accuracy_micro: %.3f' 
411 | % (minibatch_acc)] 412 | print_write(print_str, self.log_file) 413 | 414 | loss_info = { 415 | 'epoch': epoch, 416 | 'Step': self.step, 417 | 'Total': minibatch_loss_total, 418 | 'Embedding (Total)': minibatch_loss_embed, 419 | 'Proto': minibatch_loss_embed_proto, 420 | 'BiasReduc': minibatch_loss_embed_biasreduc, 421 | 'Classifier': minibatch_loss_classifier, 422 | } 423 | 424 | self.logger.log_loss(loss_info) 425 | 426 | wandb_log({"Training Loss": minibatch_loss_total}) 427 | 428 | # batch-level: sampler update 429 | if hasattr(self.data["train"].sampler, "update_weights"): 430 | if hasattr(self.data["train"].sampler, "ptype"): 431 | ptype = self.data["train"].sampler.ptype 432 | else: 433 | ptype = "score" 434 | ws = get_priority(ptype, self.logits.detach(), labels) 435 | 436 | inlist = [indexes.cpu().numpy(), ws] 437 | if self.training_opt["sampler"]["type"] == "ClassPrioritySampler": 438 | inlist.append(labels.cpu().numpy()) 439 | self.data["train"].sampler.update_weights(*inlist) 440 | 441 | #----Clear things out 442 | del inputs, labels, self.logits, self.features, preds, indexes 443 | torch.cuda.empty_cache() 444 | 445 | time1_1 = time.time() 446 | step_time = time1_1 - time1 447 | time1 = time1_1 448 | 449 | wandb_log({"Loader": et_dataloader, 450 | "ForBack": et_forback, 451 | "Total": (et_dataloader+et_forback), 452 | "Step_time": step_time, 453 | "Epoch Elap": (time1 - epoch_start_time)/60.0,}) 454 | 455 | self.step+=1 456 | g.step_global += 1 457 | 458 | # epoch-level: reset sampler weight 459 | if hasattr(self.data["train"].sampler, "get_weights"): 460 | self.logger.log_ws(epoch, self.data["train"].sampler.get_weights()) 461 | if hasattr(self.data["train"].sampler, "reset_weights"): 462 | self.data["train"].sampler.reset_weights(epoch) 463 | 464 | # After every epoch, validation 465 | rsls = {'epoch': epoch} 466 | rsls_train = self.eval_with_preds(total_preds, total_labels) 467 | 468 | eval_time1 = time.time() 469 | rsls_eval, _ , _ , _ = 
self.eval(phase='val') 470 | rsls.update(rsls_train) 471 | rsls.update(rsls_eval) 472 | if "test" in list(self.data.keys())[1:]: 473 | rsls_test, _ , _ , _ = self.eval(phase='test') 474 | rsls.update(rsls_test) 475 | del rsls_test 476 | eval_time2 = time.time() 477 | 478 | # Reset class weights for sampling if pri_mode is valid 479 | if hasattr(self.data["train"].sampler, "reset_priority"): 480 | ws = get_priority( 481 | self.data["train"].sampler.ptype, 482 | self.total_logits.detach(), 483 | self.total_labels, 484 | ) 485 | self.data["train"].sampler.reset_priority( 486 | ws, self.total_labels.cpu().numpy() 487 | ) 488 | 489 | self.logger.log_acc(rsls) 490 | 491 | # # Under validation, the best model need to be updated 492 | if rsls_eval["val_all"] > best_acc: 493 | best_epoch = epoch 494 | best_acc = rsls_eval["val_all"] 495 | for key, _ in self.config["networks"].items(): 496 | if self.config["networks"][key]["trainable"]: 497 | self.best_model_weights[key] = copy.deepcopy(self.networks[key].state_dict()) 498 | 499 | wandb_log({"Best Val": 100*best_acc, "Best Epoch": best_epoch}) 500 | wandb_log({'B_val_all': self.eval_acc_mic_top1, 501 | 'B_val_many': self.many_acc_top1, 502 | 'B_val_median': self.median_acc_top1, 503 | 'B_val_low': self.low_acc_top1}) 504 | 505 | print("===> Saving checkpoint") 506 | self.save_latest(epoch) 507 | wandb_log({"Eval Elap": (eval_time2-eval_time1)/60.0, 508 | "Epoch Elap": (time.time() - epoch_start_time)/60.0}) 509 | 510 | #----Clear things out 511 | del rsls_eval 512 | del rsls_train 513 | del rsls 514 | 515 | # Resetting the model with the best weights 516 | self.reset_model(self.best_model_weights) 517 | 518 | # Save the best model 519 | self.save_model(epoch, best_epoch, self.best_model_weights, best_acc) 520 | 521 | print("Done") 522 | print("Training Complete.") 523 | print_str = [f"Best validation accuracy is {best_acc} at epoch {best_epoch}"] 524 | print_write(print_str, self.log_file) 525 | 526 | # Empty cuda cache 527 
| torch.cuda.empty_cache() 528 | 529 | 530 | def eval_with_preds(self, preds, labels): 531 | """Train accuracy 532 | 533 | Args: 534 | preds (int): Predictions 535 | labels (int): Ground Truth 536 | 537 | Returns: 538 | dict: dictionary of all training accuracies 539 | """ 540 | # Count the number of examples 541 | n_total = sum([len(p) for p in preds]) 542 | 543 | # Split the examples into normal and mixup 544 | normal_preds, normal_labels = [], [] 545 | mixup_preds, mixup_labels1, mixup_labels2, mixup_ws = [], [], [], [] 546 | for p, l in zip(preds, labels): 547 | if isinstance(l, tuple): 548 | mixup_preds.append(p) 549 | mixup_labels1.append(l[0]) 550 | mixup_labels2.append(l[1]) 551 | mixup_ws.append(l[2] * np.ones_like(l[0])) 552 | else: 553 | normal_preds.append(p) 554 | normal_labels.append(l) 555 | 556 | # Calculate normal prediction accuracy 557 | rsl = { 558 | "train_all": 0.0, 559 | "train_many": 0.0, 560 | "train_median": 0.0, 561 | "train_low": 0.0, 562 | } 563 | 564 | if len(normal_preds) > 0: 565 | normal_preds, normal_labels = list( 566 | map(np.concatenate, [normal_preds, normal_labels]) 567 | ) 568 | n_top1 = mic_acc_cal(normal_preds, normal_labels) 569 | ( 570 | n_top1_many, 571 | n_top1_median, 572 | n_top1_low, 573 | ) = shot_acc(normal_preds, normal_labels, self.data["train"]) 574 | rsl["train_all"] += len(normal_preds) / n_total * n_top1 575 | rsl["train_many"] += len(normal_preds) / n_total * n_top1_many 576 | rsl["train_median"] += len(normal_preds) / n_total * n_top1_median 577 | rsl["train_low"] += len(normal_preds) / n_total * n_top1_low 578 | 579 | # Calculate mixup prediction accuracy 580 | if len(mixup_preds) > 0: 581 | mixup_preds, mixup_labels, mixup_ws = list( 582 | map( 583 | np.concatenate, 584 | [mixup_preds * 2, mixup_labels1 + mixup_labels2, mixup_ws], 585 | ) 586 | ) 587 | mixup_ws = np.concatenate([mixup_ws, 1 - mixup_ws]) 588 | n_top1 = weighted_mic_acc_cal(mixup_preds, mixup_labels, mixup_ws) 589 | n_top1_many, 
n_top1_median, n_top1_low, = weighted_shot_acc( 590 | mixup_preds, mixup_labels, mixup_ws, self.data["train"] 591 | ) 592 | rsl["train_all"] += len(mixup_preds) / 2 / n_total * n_top1 593 | rsl["train_many"] += len(mixup_preds) / 2 / n_total * n_top1_many 594 | rsl["train_median"] += len(mixup_preds) / 2 / n_total * n_top1_median 595 | rsl["train_low"] += len(mixup_preds) / 2 / n_total * n_top1_low 596 | 597 | # Top-1 accuracy and additional string 598 | print_str = [ 599 | "\n Training acc Top1: %.3f \n" % (rsl["train_all"]), 600 | "Many_top1: %.3f" % (rsl["train_many"]), 601 | "Median_top1: %.3f" % (rsl["train_median"]), 602 | "Low_top1: %.3f" % (rsl["train_low"]), 603 | "\n", 604 | ] 605 | 606 | print_write(print_str, self.log_file) 607 | phase = "train" 608 | wandb_log({phase + '_all': rsl["train_all"]*100, 609 | phase + '_many': rsl["train_many"]*100, 610 | phase + '_median': rsl["train_median"]*100, 611 | phase + '_low': rsl["train_low"]*100, 612 | phase + ' Accuracy': ["train_all"]*100,}) 613 | 614 | return rsl 615 | 616 | def eval(self, phase='val'): 617 | print_str = ['Phase: %s' % (phase)] 618 | print_write(print_str, self.log_file) 619 | 620 | torch.cuda.empty_cache() 621 | 622 | # In validation or testing mode, set model to eval() and initialize running loss/correct 623 | for model in self.networks.values(): 624 | model.eval() 625 | 626 | # self.total_logits = torch.empty((0, self.training_opt['num_classes'])) #.cuda() 627 | self.total_labels = torch.empty(0, dtype=torch.long) #.cuda() 628 | self.total_preds = [] 629 | self.total_paths = np.empty(0) 630 | 631 | feats_all, labels_all, idxs_all, logits_all = [], [], [], [] 632 | featmaps_all = [] 633 | 634 | # Iterate over dataset 635 | stepval = 0 636 | for inputs, labels, paths in tqdm(self.data[phase]): 637 | inputs, labels = inputs.cuda(), labels.cuda() 638 | 639 | # If on training phase, enable gradients 640 | with torch.set_grad_enabled(False): 641 | 642 | # In validation or testing 643 | 
self.batch_forward(inputs, labels, phase=phase) 644 | self.batch_loss(labels) 645 | 646 | # self.total_logits = torch.cat((self.total_logits, self.logits)) 647 | _, preds = F.softmax(self.logits, dim=1).max(dim=1) 648 | self.total_preds.append(preds.cpu()) 649 | self.total_labels = torch.cat((self.total_labels, labels.cpu())) 650 | 651 | #----Clear things out 652 | del preds, inputs, labels 653 | torch.cuda.empty_cache() 654 | 655 | self.total_paths = np.concatenate((self.total_paths, paths)) 656 | stepval+=1 657 | 658 | preds = torch.hstack(self.total_preds) 659 | 660 | # Calculate the overall accuracy and F measurement 661 | self.eval_acc_mic_top1= mic_acc_cal(preds[self.total_labels != -1], 662 | self.total_labels[self.total_labels != -1]) 663 | self.eval_f_measure = F_measure(preds, self.total_labels, theta=self.training_opt['open_threshold']) 664 | 665 | self.many_acc_top1, \ 666 | self.median_acc_top1, \ 667 | self.low_acc_top1, \ 668 | self.cls_accs = shot_acc(preds[self.total_labels != -1], 669 | self.total_labels[self.total_labels != -1], 670 | self.data['train'], 671 | acc_per_cls=True) 672 | 673 | # Top-1 accuracy and additional string 674 | print_str = ['\n\n', 675 | 'Phase: %s' 676 | % (phase), 677 | '\n\n', 678 | 'Evaluation_accuracy_micro_top1: %.3f' 679 | % (self.eval_acc_mic_top1), 680 | '\n', 681 | 'Averaged F-measure: %.3f' 682 | % (self.eval_f_measure), 683 | '\n', 684 | 'Many_shot_accuracy_top1: %.3f' 685 | % (self.many_acc_top1), 686 | 'Median_shot_accuracy_top1: %.3f' 687 | % (self.median_acc_top1), 688 | 'Low_shot_accuracy_top1: %.3f' 689 | % (self.low_acc_top1), 690 | '\n'] 691 | 692 | rsl = {phase + '_all': self.eval_acc_mic_top1, 693 | phase + '_many': self.many_acc_top1, 694 | phase + '_median': self.median_acc_top1, 695 | phase + '_low': self.low_acc_top1, 696 | phase + '_fscore': self.eval_f_measure, 697 | phase + '_loss': self.loss.item()} 698 | 699 | wandb_log({phase + '_all': self.eval_acc_mic_top1*100, 700 | phase + '_many': 
self.many_acc_top1*100, 701 | phase + '_median': self.median_acc_top1*100, 702 | phase + '_low': self.low_acc_top1*100, 703 | phase + ' Accuracy': self.eval_acc_mic_top1*100, 704 | phase + ' Loss': self.loss.item(),}) 705 | 706 | 707 | print_write(print_str, self.log_file) 708 | print(f"------------->{self.eval_acc_mic_top1 * 100}") 709 | 710 | return rsl, preds, self.total_labels, self.cls_accs 711 | 712 | def save_latest(self, epoch): 713 | """Saves latest weights of the model 714 | 715 | Args: 716 | epoch (int): Epoch number 717 | """ 718 | #-----> Model's state_dict 719 | model_weights = {} 720 | for key, _ in self.config["networks"].items(): 721 | if self.config["networks"][key]["trainable"]: 722 | model_weights[key] = copy.deepcopy( 723 | self.networks[key].state_dict() 724 | ) 725 | 726 | #-----> Optimizer's state_dict 727 | optimizer_state_dict = {} 728 | for key, _ in self.model_optimizer_dict.items(): 729 | optimizer_state_dict[key] = copy.deepcopy( 730 | self.model_optimizer_dict[key].state_dict() 731 | ) 732 | 733 | #-----> Criterion's Optimizer's state_dict 734 | criterion_optimizer_state_dict = self.criterion_optimizer.state_dict() if self.criterion_optimizer else None 735 | 736 | #----> Scheduler's state dict 737 | scheduler_state_dict = {} 738 | if self.model_scheduler_dict: 739 | for key, _ in self.model_scheduler_dict.items(): 740 | scheduler_state_dict[key] = copy.deepcopy( 741 | self.model_scheduler_dict[key].state_dict() 742 | ) 743 | else: 744 | scheduler_state_dict = None 745 | 746 | #-----> Criterion's Scheduler's state_dict 747 | criterion_scheduler_state_dict = self.criterion_optimizer_scheduler.state_dict() if self.criterion_optimizer_scheduler else None 748 | 749 | 750 | model_states = { 751 | "epoch": epoch, 752 | "state_dict": model_weights, 753 | "opt_state_dict": optimizer_state_dict, 754 | "opt_crit_state_dict": criterion_optimizer_state_dict, 755 | "sch_state_dict": scheduler_state_dict, 756 | "sch_crit_state_dict": 
criterion_scheduler_state_dict, 757 | "wandb_id": self.config["wandb_id"], 758 | } 759 | 760 | model_dir = os.path.join( 761 | self.training_opt["log_dir"], "latest_model_checkpoint.pth" 762 | ) 763 | torch.save(model_states, model_dir) 764 | 765 | def save_model(self, epoch, best_epoch, best_model_weights, best_acc): 766 | """Saves the best model's weights 767 | 768 | Args: 769 | epoch (int): Epoch number 770 | best_epoch (int): Epoch with the best accuracy or val loss 771 | best_model_weights (float Tensor): Best model's weights 772 | best_acc (float): Best accuracy 773 | """ 774 | 775 | model_states = { 776 | "epoch": epoch, 777 | "best_epoch": best_epoch, 778 | "state_dict_best": best_model_weights, 779 | "best_acc": best_acc, 780 | } 781 | 782 | model_dir = os.path.join( 783 | self.training_opt["log_dir"], "final_model_checkpoint.pth" 784 | ) 785 | 786 | torch.save(model_states, model_dir) 787 | 788 | def reset_model(self, model_state): 789 | """Resets the model with the best weight 790 | 791 | Args: 792 | model_state (dict): dict with best weight 793 | """ 794 | for key, model in self.networks.items(): 795 | if self.config["networks"][key]["trainable"]: 796 | weights = model_state[key] 797 | weights = {k: weights[k] for k in weights if k in model.state_dict()} 798 | model.load_state_dict(weights) 799 | 800 | def resume_run(self, saved_dict): 801 | """Resets the model with the best weight 802 | 803 | Args: 804 | model_state (dict): dict with best weight 805 | """ 806 | loaded_dict = torch.load(saved_dict) 807 | model_state = loaded_dict["state_dict"] 808 | optimizer_state_dict = loaded_dict["opt_state_dict"] 809 | criterion_optimizer_state_dict = loaded_dict["opt_crit_state_dict"] 810 | scheduler_state_dict = loaded_dict["sch_state_dict"] 811 | criterion_scheduler_state_dict = loaded_dict["sch_crit_state_dict"] 812 | 813 | for key, model in self.networks.items(): 814 | if self.config["networks"][key]["trainable"]: 815 | weights = model_state[key] 816 | weights 
= {k: weights[k] for k in weights if k in model.state_dict()} 817 | model.load_state_dict(weights) 818 | 819 | #-----> Optimizer's state_dict 820 | for key, _ in self.model_optimizer_dict.items(): 821 | self.model_optimizer_dict[key].load_state_dict(optimizer_state_dict[key]) 822 | 823 | #-----> Criterion's Optimizer's state_dict 824 | if self.criterion_optimizer : 825 | self.criterion_optimizer.load_state_dict(criterion_optimizer_state_dict) 826 | 827 | #----> Scheduler's state dict 828 | if self.model_scheduler_dict: 829 | for key, _ in self.model_scheduler_dict.items(): 830 | self.model_scheduler_dict[key].load_state_dict(scheduler_state_dict[key]) 831 | 832 | #-----> Criterion's Scheduler's state_dict 833 | if self.criterion_optimizer_scheduler : 834 | self.criterion_optimizer_scheduler.load_state_dict(criterion_scheduler_state_dict) 835 | 836 | self.start_epoch = loaded_dict["epoch"] + 1 837 | 838 | print(f"\nResuming from Epoch: {self.start_epoch}!\n") 839 | 840 | #----------------------------------------------------- 841 | 842 | # This is there so that we can use source_import from the utils to import model 843 | def get_core(*args): 844 | return model(*args) -------------------------------------------------------------------------------- /libs/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from sklearn.metrics import f1_score 5 | import torch.nn.functional as F 6 | import importlib 7 | import pdb 8 | 9 | import libs.utils.globals as g 10 | 11 | 12 | def update(config, args): 13 | # Change parameters 14 | config["model_dir"] = get_value(config["model_dir"], args.model_dir) 15 | config["training_opt"]["batch_size"] = get_value( 16 | config["training_opt"]["batch_size"], args.batch_size 17 | ) 18 | return config 19 | 20 | 21 | def source_import(file_path): 22 | """This function imports python module directly from source code using 
importlib""" 23 | spec = importlib.util.spec_from_file_location("", file_path) 24 | module = importlib.util.module_from_spec(spec) 25 | spec.loader.exec_module(module) 26 | return module 27 | 28 | 29 | def batch_show(inp, title=None): 30 | """Imshow for Tensor.""" 31 | inp = inp.numpy().transpose((1, 2, 0)) 32 | mean = np.array([0.485, 0.456, 0.406]) 33 | std = np.array([0.229, 0.224, 0.225]) 34 | inp = std * inp + mean 35 | inp = np.clip(inp, 0, 1) 36 | plt.figure(figsize=(20, 20)) 37 | plt.imshow(inp) 38 | if title is not None: 39 | plt.title(title) 40 | 41 | 42 | def print_write(print_str, log_file): 43 | print(*print_str) 44 | if g.log_offline: 45 | if log_file is None: 46 | return 47 | with open(log_file, "a") as f: 48 | print(*print_str, file=f) 49 | 50 | 51 | def init_weights(model, weights_path, caffe=False, classifier=False): 52 | """Initialize weights""" 53 | print( 54 | "Pretrained %s weights path: %s" 55 | % ("classifier" if classifier else "feature model", weights_path) 56 | ) 57 | weights = torch.load(weights_path) 58 | if not classifier: 59 | if caffe: 60 | weights = { 61 | k: weights[k] if k in weights else model.state_dict()[k] 62 | for k in model.state_dict() 63 | } 64 | else: 65 | weights = weights["state_dict_best"]["feat_model"] 66 | weights = { 67 | k: weights["module." + k] 68 | if "module." + k in weights 69 | else model.state_dict()[k] 70 | for k in model.state_dict() 71 | } 72 | else: 73 | weights = weights["state_dict_best"]["classifier"] 74 | weights = { 75 | k: weights["module.fc." + k] 76 | if "module.fc." 
+ k in weights 77 | else model.state_dict()[k] 78 | for k in model.state_dict() 79 | } 80 | model.load_state_dict(weights) 81 | return model 82 | 83 | 84 | def init_weights_rahul(model, weights_path, caffe=False, classifier=False): 85 | """Initialize weights""" 86 | print( 87 | "Pretrained %s weights path: %s" 88 | % ("classifier" if classifier else "feature model", weights_path) 89 | ) 90 | weights = torch.load(weights_path) 91 | if not classifier: 92 | if caffe: 93 | weights = { 94 | k: weights[k] if k in weights else model.state_dict()[k] 95 | for k in model.state_dict() 96 | } 97 | else: 98 | weights = weights["state_dict_best"]["feat_model"] 99 | weights = { 100 | k: weights["module." + k] 101 | if "module." + k in weights 102 | else model.state_dict()[k] 103 | for k in model.state_dict() 104 | } 105 | else: 106 | weights = weights["state_dict_best"]["classifier"] 107 | weights = { 108 | k: weights["module." + k] 109 | if "module." + k in weights 110 | else model.state_dict()[k] 111 | for k in model.state_dict() 112 | } 113 | model.load_state_dict(weights) 114 | return model 115 | 116 | 117 | def shot_acc( 118 | preds, labels, train_data, many_shot_thr=100, low_shot_thr=20, acc_per_cls=False 119 | ): 120 | 121 | if isinstance(train_data, np.ndarray): 122 | training_labels = np.array(train_data).astype(int) 123 | else: 124 | training_labels = np.array(train_data.dataset.labels).astype(int) 125 | 126 | if isinstance(preds, torch.Tensor): 127 | preds = preds.detach().cpu().numpy() 128 | labels = labels.detach().cpu().numpy() 129 | elif isinstance(preds, np.ndarray): 130 | pass 131 | else: 132 | raise TypeError("Type ({}) of preds not supported".format(type(preds))) 133 | train_class_count = [] 134 | test_class_count = [] 135 | class_correct = [] 136 | for l in np.unique(labels): 137 | train_class_count.append(len(training_labels[training_labels == l])) 138 | test_class_count.append(len(labels[labels == l])) 139 | class_correct.append((preds[labels == l] == 
labels[labels == l]).sum()) 140 | 141 | many_shot = [] 142 | median_shot = [] 143 | low_shot = [] 144 | for i in range(len(train_class_count)): 145 | if train_class_count[i] > many_shot_thr: 146 | many_shot.append((class_correct[i] / test_class_count[i])) 147 | elif train_class_count[i] < low_shot_thr: 148 | low_shot.append((class_correct[i] / test_class_count[i])) 149 | else: 150 | median_shot.append((class_correct[i] / test_class_count[i])) 151 | 152 | if len(many_shot) == 0: 153 | many_shot.append(0) 154 | if len(median_shot) == 0: 155 | median_shot.append(0) 156 | if len(low_shot) == 0: 157 | low_shot.append(0) 158 | 159 | if acc_per_cls: 160 | class_accs = [c / cnt for c, cnt in zip(class_correct, test_class_count)] 161 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot), class_accs 162 | else: 163 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot) 164 | 165 | 166 | def weighted_shot_acc( 167 | preds, labels, ws, train_data, many_shot_thr=100, low_shot_thr=20 168 | ): 169 | 170 | training_labels = np.array(train_data.dataset.labels).astype(int) 171 | 172 | if isinstance(preds, torch.Tensor): 173 | preds = preds.detach().cpu().numpy() 174 | labels = labels.detach().cpu().numpy() 175 | elif isinstance(preds, np.ndarray): 176 | pass 177 | else: 178 | raise TypeError("Type ({}) of preds not supported".format(type(preds))) 179 | train_class_count = [] 180 | test_class_count = [] 181 | class_correct = [] 182 | for l in np.unique(labels): 183 | train_class_count.append(len(training_labels[training_labels == l])) 184 | test_class_count.append(ws[labels == l].sum()) 185 | class_correct.append( 186 | ((preds[labels == l] == labels[labels == l]) * ws[labels == l]).sum() 187 | ) 188 | 189 | many_shot = [] 190 | median_shot = [] 191 | low_shot = [] 192 | for i in range(len(train_class_count)): 193 | if train_class_count[i] > many_shot_thr: 194 | many_shot.append((class_correct[i] / test_class_count[i])) 195 | elif train_class_count[i] < 
low_shot_thr: 196 | low_shot.append((class_correct[i] / test_class_count[i])) 197 | else: 198 | median_shot.append((class_correct[i] / test_class_count[i])) 199 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot) 200 | 201 | 202 | def F_measure(preds, labels, theta=None): 203 | # Regular f1 score 204 | return f1_score( 205 | labels.detach().cpu().numpy(), preds.detach().cpu().numpy(), average="macro" 206 | ) 207 | 208 | 209 | def mic_acc_cal(preds, labels): 210 | if isinstance(labels, tuple): 211 | assert len(labels) == 3 212 | targets_a, targets_b, lam = labels 213 | acc_mic_top1 = ( 214 | lam * preds.eq(targets_a.data).cpu().sum().float() 215 | + (1 - lam) * preds.eq(targets_b.data).cpu().sum().float() 216 | ) / len(preds) 217 | else: 218 | acc_mic_top1 = (preds == labels).sum().item() / len(labels) 219 | return acc_mic_top1 220 | 221 | 222 | def weighted_mic_acc_cal(preds, labels, ws): 223 | acc_mic_top1 = ws[preds == labels].sum() / ws.sum() 224 | return acc_mic_top1 225 | 226 | 227 | def class_count(data): 228 | labels = np.array(data.dataset.labels) 229 | class_data_num = [] 230 | for l in np.unique(labels): 231 | class_data_num.append(len(labels[labels == l])) 232 | return class_data_num 233 | 234 | 235 | # New Added 236 | def torch2numpy(x): 237 | if isinstance(x, torch.Tensor): 238 | return x.detach().cpu().numpy() 239 | elif isinstance(x, (list, tuple)): 240 | return tuple([torch2numpy(xi) for xi in x]) 241 | else: 242 | return x 243 | 244 | 245 | def logits2score(logits, labels): 246 | scores = F.softmax(logits, dim=1) 247 | score = scores.gather(1, labels.view(-1, 1)) 248 | score = score.squeeze().cpu().numpy() 249 | return score 250 | 251 | 252 | def logits2entropy(logits): 253 | scores = F.softmax(logits, dim=1) 254 | scores = scores.cpu().numpy() + 1e-30 255 | ent = -scores * np.log(scores) 256 | ent = np.sum(ent, 1) 257 | return ent 258 | 259 | 260 | def logits2CE(logits, labels): 261 | scores = F.softmax(logits, dim=1) 262 | score 
= scores.gather(1, labels.view(-1, 1)) 263 | score = score.squeeze().cpu().numpy() + 1e-30 264 | ce = -np.log(score) 265 | return ce 266 | 267 | 268 | def get_priority(ptype, logits, labels): 269 | if ptype == "score": 270 | ws = 1 - logits2score(logits, labels) 271 | elif ptype == "entropy": 272 | ws = logits2entropy(logits) 273 | elif ptype == "CE": 274 | ws = logits2CE(logits, labels) 275 | 276 | return ws 277 | 278 | 279 | def get_value(oldv, newv): 280 | if newv is not None: 281 | return newv 282 | else: 283 | return oldv 284 | 285 | 286 | # Tang Kaihua New Add 287 | def print_grad_norm(named_parameters, logger_func, log_file, verbose=False): 288 | if not verbose: 289 | return None 290 | 291 | total_norm = 0.0 292 | param_to_norm = {} 293 | param_to_shape = {} 294 | for n, p in named_parameters.items(): 295 | if p.grad is not None: 296 | param_norm = p.grad.norm(2) 297 | total_norm += param_norm ** 2 298 | param_to_norm[n] = param_norm 299 | param_to_shape[n] = p.size() 300 | 301 | total_norm = total_norm ** (1.0 / 2) 302 | 303 | logger_func( 304 | ["----------Total norm {:.5f}-----------------".format(total_norm)], log_file 305 | ) 306 | for name, norm in sorted(param_to_norm.items(), key=lambda x: -x[1]): 307 | logger_func( 308 | ["{:<50s}: {:.5f}, ({})".format(name, norm, param_to_shape[name])], log_file 309 | ) 310 | logger_func(["-------------------------------"], log_file) 311 | 312 | return total_norm 313 | 314 | 315 | def smooth_l1_loss(input, target, beta=1.0 / 9, reduction="mean"): 316 | n = torch.abs(input - target) 317 | cond = n < beta 318 | loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta) 319 | if reduction == "mean": 320 | return loss.mean() 321 | elif reduction == "sum": 322 | return loss.sum() 323 | else: 324 | print("XXXXXX Error Reduction Type for smooth_l1_loss, use default mean") 325 | return loss.mean() 326 | 327 | 328 | def l2_loss(input, target, reduction="mean"): 329 | return F.mse_loss(input, target, reduction=reduction) 
def regression_loss(
    input,
    target,
    l2=False,
    pre_mean=True,
    l1=False,
    moving_average=False,
    moving_ratio=0.01,
):
    """Compute a regression loss against a detached target, or update a running mean.

    Exactly one of ``l2``, ``l1`` and ``moving_average`` must be True.

    Args:
        input (Tensor): Prediction. For ``moving_average`` this is the running-mean
            tensor (intended to be a registered buffer), which is updated in place.
        target (Tensor): Regression target. It is always detached, so no gradient
            flows into it.
        l2 (bool): Use mean squared error (``l2_loss``).
        pre_mean (bool): Only used with ``l2``. When ``input`` and ``target`` have
            the same size along dim 0 it must be False. Otherwise it must be True,
            and ``target`` is averaged over dim 0 before comparison.
        l1 (bool): Use ``smooth_l1_loss``.
        moving_average (bool): Update ``input`` in place to
            ``(1 - moving_ratio) * input + moving_ratio * target.mean(0, keepdim=True)``.
        moving_ratio (float): Weight of the new batch mean in the update.

    Returns:
        Tensor or None: The loss, or None when ``moving_average`` is set.
    """
    assert (l2 + l1 + moving_average) == 1
    target = target.detach()
    if l2:
        if input.shape[0] == target.shape[0]:
            assert not pre_mean
            loss = l2_loss(input, target)
        else:
            assert pre_mean
            loss = l2_loss(input, target.mean(0, keepdim=True))
    elif l1:
        loss = smooth_l1_loss(input, target)
    else:
        # The update must be written into the caller's tensor. Rebinding the
        # local name (as the code used to) silently discarded it.
        # reshape() lets an un-batched buffer, e.g. shape (D,), receive the
        # (1, D) result of the broadcasted update.
        with torch.no_grad():
            updated = (1 - moving_ratio) * input + moving_ratio * target.mean(
                0, keepdim=True
            )
            input.copy_(updated.reshape(input.shape))
        loss = None
    return loss


def gumbel_softmax(logits, tau=1, hard=False, gumbel=True, dim=-1):
    """Softmax over ``dim``, optionally with Gumbel noise and straight-through hard samples.

    Args:
        logits (Tensor): Unnormalized scores.
        tau (float): Temperature applied after adding Gumbel noise.
        hard (bool): If True, return a one-hot argmax in the forward pass while
            keeping the soft gradient (straight-through estimator).
        gumbel (bool): If False, no noise is added and ``tau`` is not applied.
        dim (int): Dimension of the softmax.

    Returns:
        Tensor: Same shape as ``logits``.
    """
    if gumbel:
        gumbels = -torch.empty_like(logits).exponential_().log()  # ~Gumbel(0,1)
        gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
        y_soft = gumbels.softmax(dim)
    else:
        y_soft = logits.softmax(dim)

    if not hard:
        # Reparametrization trick.
        return y_soft

    # Straight through: forward value is one-hot, gradient is that of y_soft.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft


def gumbel_sigmoid(logits, tau=1, hard=False, gumbel=True):
    """Element-wise sigmoid, optionally with Gumbel noise and straight-through thresholding.

    Args:
        logits (Tensor): Unnormalized scores.
        tau (float): Temperature applied after adding Gumbel noise.
        hard (bool): If True, the forward value is ``(y_soft > 0.5)`` while the
            gradient is that of the soft sigmoid (straight-through estimator).
        gumbel (bool): If False, no noise is added and ``tau`` is not applied.

    Returns:
        Tensor: Same shape as ``logits``.
    """
    if gumbel:
        gumbels = -torch.empty_like(logits).exponential_().log()  # ~Gumbel(0,1)
        gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
        y_soft = torch.sigmoid(gumbels)
    else:
        y_soft = torch.sigmoid(logits)

    if not hard:
        # Reparametrization trick.
        return y_soft

    # Straight through.
    y_hard = (y_soft > 0.5).float()
    return y_hard - y_soft.detach() + y_soft


# Warmup from BBN
from bisect import bisect_right


class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Multi-step LR decay with an initial warm-up phase (from BBN).

    While ``last_epoch < warmup_epochs`` each base LR is scaled by a warm-up factor:
    the fixed ``warmup_factor`` for ``"constant"``, or a linear ramp from
    ``warmup_factor`` towards 1 for ``"linear"``. Independently, the LR is multiplied
    by ``gamma`` once for every milestone ``<= last_epoch``.

    Raises:
        ValueError: If ``milestones`` is not sorted or ``warmup_method`` is unknown.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_epochs=5,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if list(milestones) != sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_epochs = warmup_epochs
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the LR for every parameter group at ``self.last_epoch``."""
        warmup_factor = 1
        if self.last_epoch < self.warmup_epochs:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = float(self.last_epoch) / self.warmup_epochs
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        # bisect_right counts how many milestones have already been passed.
        decay = self.gamma ** bisect_right(self.milestones, self.last_epoch)
        return [base_lr * warmup_factor * decay for base_lr in self.base_lrs]


# Rahul Vigneswaran
import plotly.express as px
import pandas as pd
import plotly.figure_factory as ff
from sklearn.decomposition import IncrementalPCA
from sklearn.metrics import confusion_matrix
# from cuml.manifold import TSNE
from sklearn.manifold import TSNE
# from tsnecuda import TSNE
from more_itertools import sort_together

import libs.utils.globals as g


# The plotting helpers below are hard-wired to CIFAR-10: labels are expected
# to be class indices 0-9.
_CIFAR10_CLASSES = {
    0: "plane",
    1: "car",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

# plotly marker symbol ids, one per class, for the per-class t-SNE plots.
_CIFAR10_TSNE_SYMBOLS = [0, 1, 4, 24, 5, 3, 17, 13, 26, 20]

# Axis tick labels of the confusion-matrix heatmaps.
_CIFAR10_AXIS_LABELS = [
    "0 : plane",
    "1: car",
    "2: bird",
    "3: cat",
    "4: deer",
    "5: dog",
    "6: frog",
    "7: horse",
    "8: ship",
    "9: truck",
]

# Prefix used in logged keys / file names for each data split.
_PHASE_TITLES = {"train": "Train", "val": "Eval", "test": "Test"}


def _phase_title(phase):
    """Return the display prefix for ``phase``; raise for an unknown split."""
    try:
        return _PHASE_TITLES[phase]
    except KeyError:
        raise Exception("Invalid data split!!") from None


def _tsne_figure(embedding, labels):
    """Project ``embedding`` to 2-D with t-SNE and build a per-class scatter plot."""
    X_tsne = TSNE(n_components=2).fit_transform(embedding)
    # sort_together keys on the first iterable only and is stable, so the
    # x/y coordinates of each point stay paired while grouping by class.
    labels, tsne_x, tsne_y = sort_together([labels, X_tsne[:, 0], X_tsne[:, 1]])

    class_label = [_CIFAR10_CLASSES[i] for i in labels]

    df = pd.DataFrame(
        list(zip(tsne_x, tsne_y, class_label)), columns=["x", "y", "Class"]
    )

    return px.scatter(
        df,
        x="x",
        y="y",
        color="Class",
        symbol="Class",
        symbol_sequence=_CIFAR10_TSNE_SYMBOLS,
        hover_name=class_label,
        labels={"color": "Class"},
    )


def plot_tsne(embedding, labels, phase="train"):
    """Plot a t-SNE projection of ``embedding`` coloured by class and log it.

    Args:
        embedding (float Tensor): Embedding of data. Batch Size x Embedding Size
        labels (int): Ground truth (CIFAR-10 class indices).
        phase (str, optional): "train", "val" or "test". Defaults to "train".

    Raises:
        Exception: If logging is enabled and ``phase`` is not a known split.
    """
    fig = _tsne_figure(embedding, labels)

    if g.wandb_log:
        import wandb  # imported lazily, as in wandb_log(), so wandb stays optional

        wandb_keys = {"train": "t-SNE", "val": "t-SNE Eval", "test": "t-SNE Test"}
        if phase not in wandb_keys:
            raise Exception("Invalid data split!!")
        wandb.log({wandb_keys[phase]: fig, "epoch": g.epoch_global})

    if g.log_offline:
        file_names = {"train": "tsne.png", "val": "tsneEval.png", "test": "tsneTest.png"}
        if phase not in file_names:
            raise Exception("Invalid data split!!")
        fig.write_image(f"{g.log_dir}/metrics/{file_names[phase]}")


def plot_tsne_with_name(embedding, labels, name="train"):
    """Plot a t-SNE projection of ``embedding`` and log it as "<name> t-SNE".

    Args:
        embedding (float Tensor): Embedding of data. Batch Size x Embedding Size
        labels (int): Ground truth (CIFAR-10 class indices).
        name (str, optional): Prefix of the logged key / file name. Defaults to "train".
    """
    fig = _tsne_figure(embedding, labels)
    actual_name = f"{name} t-SNE"

    if g.wandb_log:
        import wandb  # imported lazily, as in wandb_log(), so wandb stays optional

        wandb.log({actual_name: fig})

    if g.log_offline:
        fig.write_image(f"{g.log_dir}/metrics/{actual_name}.png")


def plot_tsne_with_name_with_mark_key(embedding, labels, mark_key, name="all", dir=None):
    """Plot a t-SNE projection coloured by class, with marker symbol chosen by ``mark_key``.

    Args:
        embedding (float Tensor): Embedding of data. Batch Size x Embedding Size
        labels (int Tensor): Ground truth (CIFAR-10 class indices), one tensor per sample.
        mark_key (int Tensor): Per-sample integer selecting the marker symbol.
        name (str, optional): Prefix of the logged key / file name. Defaults to "all".
        dir (str, optional): Directory for the offline PNG.
    """
    X_tsne = TSNE(n_components=2).fit_transform(embedding)
    # Stable sort keyed on ``labels`` only, so every column stays aligned.
    labs, tsne_x, tsne_y, marker_keys = sort_together(
        [labels, X_tsne[:, 0], X_tsne[:, 1], mark_key]
    )

    sym = [0, 26, 29, 41]  # previously [0, 1, 4, 24, 5, 3, 17, 13, 26, 20]

    class_label = [_CIFAR10_CLASSES[int(i.cpu().numpy())] for i in labs]
    marker_keys = [int(i.cpu().numpy()) for i in marker_keys]

    df = pd.DataFrame(
        list(zip(tsne_x, tsne_y, class_label, marker_keys)),
        columns=["x", "y", "Class", "Keys"],
    )

    fig = px.scatter(
        df,
        x="x",
        y="y",
        color="Class",
        symbol="Keys",
        symbol_sequence=sym,
        hover_name=class_label,
        labels={"color": "Keys"},
    )

    actual_name = f"{name} t-SNE"
    if g.wandb_log:
        # Goes through the module helper, which also attaches epoch/iter.
        wandb_log({actual_name: fig})

    if g.log_offline:
        fig.write_image(f"{dir}/{actual_name}.png")


def _confusion_figure(z):
    """Build an annotated CIFAR-10 heatmap (with axis titles and colour bar) for matrix ``z``."""
    axis_labels = _CIFAR10_AXIS_LABELS

    # Each cell is annotated with its value as a string.
    z_text = [[str(cell) for cell in row] for row in z]

    fig = ff.create_annotated_heatmap(
        z,
        x=axis_labels,
        y=axis_labels,
        annotation_text=z_text,
        colorscale="Viridis",
    )

    # Custom x-axis title.
    fig.add_annotation(
        dict(
            font=dict(color="black", size=14),
            x=0.5,
            y=-0.15,
            showarrow=False,
            text="Predicted Label",
            xref="paper",
            yref="paper",
        )
    )

    # Custom y-axis title.
    fig.add_annotation(
        dict(
            font=dict(color="black", size=14),
            x=-0.35,
            y=0.5,
            showarrow=False,
            text="True Label",
            textangle=-90,
            xref="paper",
            yref="paper",
        )
    )

    # Make room for the y-axis title and show the colour bar.
    fig.update_layout(margin=dict(t=50, l=200))
    fig["data"][0]["showscale"] = True
    return fig


def _log_confusion(fig, phase, normalize, name):
    """Log a confusion figure to wandb and/or write it under ``g.log_dir/metrics``.

    ``name`` of None means no name suffix. The "Martix" spelling is kept on
    purpose: existing dashboards and files use these exact keys.
    """
    if g.wandb_log:
        import wandb  # imported lazily, as in wandb_log(), so wandb stays optional

        key = f"{_phase_title(phase)} Confusion Martix"
        if normalize:
            key += " (Normalized)"
        if name is not None:
            key += f" - {name}"
        wandb.log({key: fig})

    if g.log_offline:
        stem = f"{_phase_title(phase)}_Confusion_Martix"
        if normalize:
            stem += "_(Normalized)"
        if name is not None:
            stem += f"_{name}"
        fig.write_image(f"{g.log_dir}/metrics/{stem}.png")


def _plot_confusion(preds, labels, phase, name):
    """Plot and log the raw and the min-max normalized confusion matrix."""
    cfm = confusion_matrix(labels, preds)

    for normalize in [False, True]:
        if normalize:
            # New array: the raw matrix is left untouched.
            z = cfm - cfm.min()
            peak = z.max()
            # A constant matrix has peak 0; show zeros instead of NaNs.
            z = z / peak if peak else z.astype(float)
            z = np.around(z, 4)
        else:
            z = cfm

        _log_confusion(_confusion_figure(z), phase, normalize, name)


def plot_confusion(preds, labels, phase="train"):
    """Function to plot confusion matrix (both usual and normalized plots)

    Args:
        labels (int): N
        preds (int): N
        phase (str, optional): "train", "val" or "test". Defaults to "train".

    Raises:
        Exception: If logging is enabled and ``phase`` is not a known split.
    """
    _plot_confusion(preds, labels, phase, None)


def plot_confusion_with_name(preds, labels, phase="train", name="none"):
    """Same as ``plot_confusion`` but appends ``name`` to every logged key and file name.

    Args:
        labels (int): N
        preds (int): N
        phase (str, optional): "train", "val" or "test". Defaults to "train".
        name (str, optional): Extra name for certain logging scenarios. Defaults to "none".

    Raises:
        Exception: If logging is enabled and ``phase`` is not a known split.
    """
    # str() so that name=None is still rendered as "None", as before.
    _plot_confusion(preds, labels, phase, str(name))


def prediction_change_finder(old, new, labels, class_names, phase):
    """Compare two sets of predictions per class and log the transition statistics.

    Args:
        old (Python list): Predictions of the Nearest Neighbour on the features
            taken directly from the feature extractor.
        new (Python list): Predictions on the embedding after training on any
            desired loss (e.g. protoloss).
        labels (Python list): Ground-truth labels (ints or 0-d tensors).
        class_names (Python list): Class names; its length is the number of classes.
        phase (str): "train", "val" or "test".

    Per class, the logged table contains the counts and percentages of samples that went
    wrong->correct, correct->wrong, stayed correct and stayed wrong, plus the class total
    and both accuracies. Percentages are NaN for classes with no samples.

    Raises:
        Exception: If logging is enabled and ``phase`` is not a known split.
    """
    nums = len(class_names)

    corrected = [0] * nums
    wronged = [0] * nums
    stayed_correct = [0] * nums
    stayed_wrong = [0] * nums
    classwise_correct_after = [0] * nums
    classwise_correct_before = [0] * nums
    class_total = [0] * nums

    for i, label in enumerate(labels):
        cls = int(label)
        if old[i] == cls:
            classwise_correct_before[cls] += 1
            if new[i] == cls:
                classwise_correct_after[cls] += 1
                stayed_correct[cls] += 1
            else:
                wronged[cls] += 1
        elif new[i] == cls:
            classwise_correct_after[cls] += 1
            corrected[cls] += 1
        else:
            stayed_wrong[cls] += 1
        class_total[cls] += 1

    def _percent(counts):
        # Empty classes get NaN instead of raising ZeroDivisionError.
        return [
            (count / total) * 100 if total else float("nan")
            for count, total in zip(counts, class_total)
        ]

    metrics = {}
    metrics["Wrong -> Correct"] = corrected
    metrics["(W->C)%"] = _percent(corrected)
    metrics["Correct -> Wrong"] = wronged
    metrics["(C->W)%"] = _percent(wronged)
    metrics["Stayed Correct"] = stayed_correct
    metrics["(Stayed Correct)%"] = _percent(stayed_correct)
    metrics["Stayed Wrong"] = stayed_wrong
    metrics["(Stayed Wrong)%"] = _percent(stayed_wrong)
    metrics["Class Total"] = class_total
    metrics["Accuracy_FeatureSpace"] = _percent(classwise_correct_before)
    metrics["Accuracy_EmbeddingSpace"] = _percent(classwise_correct_after)

    def _metrics_frame():
        df = pd.DataFrame(metrics).astype("float").round(5)
        df.insert(0, column="Class", value=class_names)
        return df

    if g.wandb_log:
        import wandb  # imported lazily, as in wandb_log(), so wandb stays optional

        title = _phase_title(phase)
        wandb.log({f"{title}: Raw to Trained metrics": wandb.Table(dataframe=_metrics_frame())})

    if g.log_offline:
        from pandas.plotting import table

        title = _phase_title(phase)
        df = _metrics_frame()
        fig, ax = plt.subplots(figsize=(25, 5))
        # Render only the table: no axes and no frame.
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.set_frame_on(False)
        tabla = table(ax, df, loc="upper right", colWidths=[0.05] * len(df.columns))
        tabla.auto_set_font_size(True)
        tabla.set_fontsize(15)
        tabla.scale(1.7, 2)
        plt.savefig(f"{g.log_dir}/metrics/{title}RawtoTrainedmetrics.png", transparent=False)
        # Release the figure; repeated calls would otherwise accumulate open figures.
        plt.close(fig)


def wandb_log(dict1):
    """Log ``dict1`` to wandb, adding the global epoch and step, if wandb logging is on.

    Note: ``dict1`` is modified in place (the "epoch" and "iter" keys are added).
    """
    if not g.wandb_log:
        return
    import wandb

    dict1["epoch"] = g.epoch_global
    dict1["iter"] = g.step_global
    wandb.log(dict1)


def get_sampler_dict(sampler_defs):
    """Build the sampler description used by the data loader from a config entry.

    Args:
        sampler_defs (dict or None): Sampler config with at least "type" and "def_file".

    Returns:
        dict or None: ``{"sampler": <sampler factory>, "params": {...}}``, or None when
        no sampler is configured or its type is not recognised.
    """
    # Default to None: an empty/None config previously raised UnboundLocalError.
    sampler_dic = None
    if sampler_defs:
        if sampler_defs["type"] == "ClassAwareSampler":  # Inverse Sampler
            sampler_dic = {
                "sampler": source_import(sampler_defs["def_file"]).get_sampler(),
                "params": {"num_samples_cls": sampler_defs["num_samples_cls"]},
            }
        elif sampler_defs["type"] in ["MixedPrioritizedSampler", "ClassPrioritySampler"]:
            sampler_dic = {
                "sampler": source_import(sampler_defs["def_file"]).get_sampler(),
                "params": {
                    k: v for k, v in sampler_defs.items() if k not in ["type", "def_file"]
                },
            }

    return sampler_dic

# --------------------------------------------------------------------------------