├── README.md ├── base ├── __init__.py ├── base_data_loader.py ├── base_model.py └── base_trainer.py ├── configs ├── cifar100_lt.json ├── cifar10_lt.json └── imagenet_lt.json ├── data_loader ├── cifar_data_loaders.py ├── data_loaders.py ├── imagenet_lt_data_loaders.py └── imbalance_cifar.py ├── logger ├── __init__.py ├── logger.py └── logger_config.json ├── main.py ├── model ├── ResnetCifar.py ├── ResnetImagenet.py ├── loss.py ├── metric.py └── model.py ├── parse_config.py ├── trainer ├── __init__.py └── trainer.py └── utils ├── __init__.py ├── gflops.py └── util.py /README.md: -------------------------------------------------------------------------------- 1 | # Environment requirements 2 | ``` 3 | python >= 3.8 4 | pytorch >= 1.8 5 | ``` 6 | 7 | # Command 8 | ``` 9 | python main.py -c "configs/cifar100_lt.json" 10 | ``` 11 | 12 | # Citation 13 | ``` 14 | @inproceedings{li2022trustworthy, 15 | title={Trustworthy Long-Tailed Classification}, 16 | author={Li, Bolian and Han, Zongbo and Li, Haining and Fu, Huazhu and Zhang, Changqing}, 17 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 18 | pages={6970--6979}, 19 | year={2022} 20 | } 21 | ``` 22 | -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_data_loader import * 2 | from .base_model import * 3 | from .base_trainer import * 4 | -------------------------------------------------------------------------------- /base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | 6 | class BaseDataLoader(DataLoader): 7 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 8 | self.validation_split = validation_split 9 | self.shuffle = shuffle 10 | 11 | self.batch_idx = 0 12 | self.n_samples = len(dataset) 13 | 14 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 15 | 16 | self.init_kwargs = { 17 | 'dataset': dataset, 18 | 'batch_size': batch_size, 19 | 'shuffle': self.shuffle, 20 | 'collate_fn': collate_fn, 21 | 'num_workers': num_workers 22 | } 23 | super().__init__(sampler=self.sampler, **self.init_kwargs) 24 | 25 | def _split_sampler(self, split): 26 | if split == 0.0: 27 | return None,None 28 | 29 | idx_full = np.arange(self.n_samples) 30 | np.random.shuffle(idx_full) 31 | 32 | if isinstance(split, int): 33 | assert split > 0 34 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 35 | len_valid = split 36 | else: 37 | len_valid = int(self.n_samples * split) 38 | 39 | valid_idx = idx_full[0:len_valid] 40 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 41 | 42 | train_sampler = SubsetRandomSampler(train_idx) 43 | valid_sampler = SubsetRandomSampler(valid_idx) 44 | 45 | # turn off shuffle option which is mutually exclusive with sampler 46 | self.shuffle = False 47 | self.n_samples = len(train_idx) 48 | 49 | return train_sampler, valid_sampler 50 | 51 | def split_validation(self): 52 | if self.valid_sampler is None: 53 | return None 54 | else: 55 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 56 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | @abstractmethod 8 | def forward(self, *inputs): 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | from abc import abstractmethod 4 | from numpy import inf 5 | from utils import load_state_dict, rename_parallel_state_dict 6 | 7 | class BaseTrainer: 8 | def __init__(self,model,criterion,opt,config): 9 | 10 | # Check with nvidia-smi about the available GPUs. Only 1 GPU is required. 11 | self.device = torch.device('cuda:0') 12 | 13 | self.config = config 14 | self.model = model.to(self.device) 15 | self.criterion = criterion.to(self.device) 16 | self.opt = opt 17 | self.epochs = config['trainer']['epochs'] 18 | 19 | @abstractmethod 20 | def _train_epoch(self,epoch): 21 | raise NotImplementedError 22 | 23 | def train(self): 24 | for epoch in range(1,self.epochs+1): 25 | result = self._train_epoch(epoch) 26 | -------------------------------------------------------------------------------- /configs/cifar100_lt.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "CIFAR100_LT", 3 | "arch": { 4 | "type": "ResNet32Model", 5 | "args": { 6 | "num_classes": 100, 7 | "num_experts": 3 8 | } 9 | }, 10 | "data_loader": { 11 | "type": "ImbalanceCIFAR100DataLoader", 12 | "args":{ 13 | "data_dir": "~", 14 | "batch_size": 128, 15 | "num_workers": 2 16 | } 17 | }, 18 | "optimizer": { 19 | "type": "SGD", 20 | "args":{ 21 | "lr": 0.1, 22 | "weight_decay": 5e-4, 23 | "momentum": 0.9, 24 | "nesterov": true 25 | } 26 | }, 27 | "loss": { 28 | "type": "TLCLoss", 29 | "args": {"reweight_epoch": 160} 30 | }, 31 | "lr_scheduler": { 32 | "type": "CustomLR", 33 | "args": { 34 | "step1": 160, 35 | "step2": 180, 36 | "gamma": 0.01, 37 | "warmup_epoch": 5 38 | } 39 | }, 40 | "trainer": { 41 | "epochs": 200, 42 | "save_dir": "saved/" 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /configs/cifar10_lt.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "CIFAR10_LT", 3 | "arch": { 4 | "type": "ResNet32Model", 5 | "args": { 6 | "num_classes": 10, 7 | "num_experts": 3 8 | } 9 | }, 10 | "data_loader": { 11 | "type": "ImbalanceCIFAR10DataLoader", 12 | "args":{ 13 | "data_dir": "~", 14 | "batch_size": 128, 15 | "num_workers": 2 16 | } 17 | }, 18 | "optimizer": { 19 | "type": "SGD", 20 | "args":{ 21 | "lr": 0.1, 22 | "weight_decay": 2e-4, 23 | "momentum": 0.9, 24 | "nesterov": true 25 | } 26 | }, 27 | "loss": { 28 | "type": "TLCLoss", 29 | "args": {"reweight_epoch": 160} 30 | }, 31 | "lr_scheduler": { 32 | "type": "CustomLR", 33 | "args": { 34 | "step1": 160, 35 | "step2": 180, 36 | "gamma": 0.01, 37 | "warmup_epoch": 5 38 | } 39 | }, 40 | "trainer": { 41 | "epochs": 200, 42 | "save_dir": "saved/" 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /configs/imagenet_lt.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ImageNet_LT", 3 | "arch": { 4 | "type": "ResNet50Model", 5 | "args": { 6 | "num_classes": 1000, 7 | "num_experts": 3 8 | } 9 | }, 10 | "data_loader": { 11 | "type": "ImageNetLTDataLoader", 12 | "args":{ 13 | "data_dir": "~/ImageNet_LT", 14 | "batch_size": 256, 15 | "num_workers": 10 16 | } 17 | }, 18 | "optimizer": { 19 | "type": "SGD", 20 | "args":{ 21 | "lr": 0.1, 22 | "weight_decay": 2e-4, 23 | "momentum": 0.9, 24 | "nesterov": true 25 | } 26 | }, 27 | "loss": { 28 | "type": "TLCLoss", 29 | "args": { 30 | "reweight_factor": 0.02, 31 | "reweight_epoch": 80 32 | } 33 | }, 34 | "lr_scheduler": { 35 | "type": "CustomLR", 36 | "args": { 37 | "step1": 60, 38 | "step2": 80, 39 | "gamma": 0.1, 40 | "warmup_epoch": 5 41 | } 42 | }, 43 | "trainer": { 44 | "epochs": 100, 45 | "save_dir": "saved/" 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /data_loader/cifar_data_loaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import os, sys 5 | from torchvision import datasets, transforms 6 | from torch.utils.data import DataLoader, Dataset, Sampler 7 | from base import BaseDataLoader 8 | from PIL import Image 9 | from .imbalance_cifar import IMBALANCECIFAR10, IMBALANCECIFAR100 10 | from .imagenet_lt_data_loaders import LT_Dataset 11 | 12 | class ImbalanceCIFAR100DataLoader(DataLoader): 13 | def __init__(self,data_dir,batch_size,num_workers,training=True,retain_epoch_size=True): 14 | normalize = transforms.Normalize(mean=[0.4914,0.4822,0.4465],std=[0.2023,0.1994,0.2010]) 15 | train_trsfm = transforms.Compose([ 16 | transforms.RandomCrop(32,padding=4) , 17 | transforms.RandomHorizontalFlip() , 18 | transforms.RandomRotation(15) , 19 | transforms.ToTensor() , 20 | normalize 21 | ]) 22 | test_trsfm = transforms.Compose([transforms.ToTensor(),normalize]) 23 | 24 | if training: 25 | self.dataset = IMBALANCECIFAR100(data_dir,train=True,download=True,transform=train_trsfm) 26 | self.val_dataset = datasets.CIFAR100(data_dir,train=False,download=True,transform=test_trsfm) 27 | else: 28 | self.dataset = datasets.CIFAR100(data_dir,train=False,download=True,transform=test_trsfm) 29 | self.val_dataset = None 30 | 31 | num_classes = max(self.dataset.targets)+1 32 | assert num_classes == 100 33 | 34 | self.cls_num_list = np.histogram(self.dataset.targets,bins=num_classes)[0].tolist() 35 | 36 | self.init_kwargs = { 37 | 'batch_size' : batch_size , 38 | 'shuffle' : True , 39 | 'num_workers' : num_workers , 40 | 'drop_last' : False 41 | } 42 | super().__init__(dataset=self.dataset,**self.init_kwargs,sampler=None) 43 | 44 | def split_validation(self,type='test'): 45 | return DataLoader( 46 | dataset = self.OOD_dataset if type=='OOD' else self.val_dataset , 47 | batch_size = 4096 , 48 | shuffle = False , 49 | num_workers = 2 , 50 | drop_last = False 51 | ) 52 | 53 | class ImbalanceCIFAR10DataLoader(DataLoader): 54 | def __init__(self,data_dir,batch_size,num_workers,training=True,retain_epoch_size=True): 55 | normalize = transforms.Normalize(mean=[0.4914,0.4822,0.4465],std=[0.2023,0.1994,0.2010]) 56 | train_trsfm = transforms.Compose([ 57 | transforms.RandomCrop(32, padding=4), 58 | transforms.RandomHorizontalFlip() , 59 | transforms.RandomRotation(15) , 60 | transforms.ToTensor() , 61 | normalize 62 | ]) 63 | test_trsfm = transforms.Compose([transforms.ToTensor(),normalize]) 64 | 65 | if training: 66 | self.dataset = IMBALANCECIFAR10(data_dir,train=True,download=True,transform=train_trsfm) 67 | self.val_dataset = datasets.CIFAR10(data_dir,train=False,download=True,transform=test_trsfm) 68 | else: 69 | self.dataset = datasets.CIFAR10(data_dir,train=False,download=True,transform=test_trsfm) 70 | self.val_dataset = None 71 | 72 | # Uncomment to use OOD datasets 73 | self.OOD_dataset = None 74 | # self.OOD_dataset = datasets.SVHN(data_dir,split="test",download=True,transform=test_trsfm) 75 | # self.OOD_dataset = LT_Dataset('../ImageNet_LT/ImageNet_LT_open','../ImageNet_LT/ImageNet_LT_open.txt',train_trsfm) 76 | # self.OOD_dataset = LT_Dataset('../Places_LT/Places_LT_open','../Places_LT/Places_LT_open.txt',train_trsfm) 77 | 78 | num_classes = max(self.dataset.targets)+1 79 | assert num_classes == 10 80 | 81 | self.cls_num_list = np.histogram(self.dataset.targets,bins=num_classes)[0].tolist() 82 | 83 | self.init_kwargs = { 84 | 'batch_size' : batch_size , 85 | 'shuffle' : True , 86 | 'num_workers' : num_workers , 87 | 'drop_last' : False 88 | } 89 | super().__init__(dataset=self.dataset,**self.init_kwargs,sampler=None) 90 | 91 | def split_validation(self,type='test'): 92 | return DataLoader( 93 | dataset = self.OOD_dataset if type=='OOD' else self.val_dataset , 94 | batch_size = 4096 , 95 | shuffle = False , 96 | num_workers = 2 , 97 | drop_last = False 98 | ) -------------------------------------------------------------------------------- /data_loader/data_loaders.py: -------------------------------------------------------------------------------- 1 | from .cifar_data_loaders import ImbalanceCIFAR100DataLoader, ImbalanceCIFAR10DataLoader 2 | from .imagenet_lt_data_loaders import ImageNetLTDataLoader 3 | -------------------------------------------------------------------------------- /data_loader/imagenet_lt_data_loaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import os, sys 5 | from torchvision import datasets, transforms 6 | from torch.utils.data import DataLoader, Dataset, Sampler 7 | from base import BaseDataLoader 8 | from PIL import Image 9 | 10 | class LT_Dataset(Dataset): 11 | def __init__(self,root,txt,transform): 12 | self.img_paths = [] 13 | self.labels = [] 14 | self.transform = transform 15 | with open(txt) as f: 16 | for line in f: 17 | self.img_paths.append(os.path.join(root, line.split()[0])) 18 | self.labels.append(int(line.split()[1])) 19 | self.targets = self.labels 20 | 21 | def __len__(self): 22 | return len(self.labels) 23 | 24 | def __getitem__(self,index): 25 | path,label = self.img_paths[index],self.labels[index] 26 | with open(path,'rb') as f: 27 | img = Image.open(f).convert('RGB') 28 | img = self.transform(img) 29 | return img,label 30 | 31 | class ImageNetLTDataLoader(DataLoader): 32 | def __init__(self,data_dir,batch_size,num_workers,training=True,retain_epoch_size=True): 33 | train_trsfm = transforms.Compose([ 34 | transforms.RandomResizedCrop(224) , 35 | transforms.RandomHorizontalFlip() , 36 | transforms.ColorJitter(brightness=0.4,contrast=0.4,saturation=0.4,hue=0), 37 | transforms.ToTensor() , 38 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 39 | ]) 40 | test_trsfm = transforms.Compose([ 41 | transforms.Resize(256) , 42 | transforms.CenterCrop(224) , 43 | transforms.ToTensor() , 44 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 45 | ]) 46 | 47 | # We use relative path to avoid potential bugs. It is recommended to check the paths below to ensure data loading. 48 | if training: 49 | self.dataset = LT_Dataset('../ImageNet_LT','../ImageNet_LT/ImageNet_LT_train.txt',train_trsfm) 50 | self.val_dataset = LT_Dataset('../ImageNet_LT','../ImageNet_LT/ImageNet_LT_val.txt',test_trsfm) 51 | else: # test 52 | self.dataset = LT_Dataset(data_dir,data_dir+'/ImageNet_LT_val.txt',test_trsfm) 53 | self.val_dataset = None 54 | 55 | self.n_samples = len(self.dataset) 56 | 57 | num_classes = max(self.dataset.targets)+1 58 | assert num_classes == 1000 59 | 60 | self.cls_num_list = np.histogram(self.dataset.targets,bins=num_classes)[0].tolist() 61 | 62 | self.init_kwargs = { 63 | 'batch_size' : batch_size , 64 | 'shuffle' : True , 65 | 'num_workers' : num_workers , 66 | 'drop_last' : False 67 | } 68 | super().__init__(dataset=self.dataset,**self.init_kwargs,sampler=None) 69 | 70 | def split_validation(self,type='test'): 71 | return DataLoader( 72 | dataset = self.OOD_dataset if type=='OOD' else self.val_dataset , 73 | batch_size = 512 , 74 | shuffle = False , 75 | num_workers = 10 , 76 | drop_last = False 77 | ) 78 | -------------------------------------------------------------------------------- /data_loader/imbalance_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | from math import * 6 | 7 | class IMBALANCECIFAR10(torchvision.datasets.CIFAR10): 8 | num_class = 10 9 | decay_stride = 2.1971 10 | 11 | def __init__(self,root,imb_type='exp',train=True,transform=None,target_transform=None,download=False): 12 | super(IMBALANCECIFAR10,self).__init__(root, train, transform, target_transform, download) 13 | img_num_list = self.get_img_num_per_cls(self.num_class,imb_type) 14 | self.gen_imbalanced_data(img_num_list) 15 | 16 | def get_img_num_per_cls(self,num_class,imb_type): 17 | img_max = len(self.data)/num_class 18 | img_num_per_cls = [] 19 | if imb_type == 'exp': 20 | for cls_idx in range(num_class): 21 | num = img_max*exp(-cls_idx/self.decay_stride) 22 | img_num_per_cls.append(int(num+0.5)) 23 | else: 24 | img_num_per_cls.extend([int(img_max)]*num_class) 25 | return img_num_per_cls 26 | 27 | def gen_imbalanced_data(self,img_num_per_cls): 28 | img_max = len(self.data)/self.num_class 29 | new_data,new_targets = [],[] 30 | targets_np = np.array(self.targets,dtype=np.int64) 31 | classes = np.arange(self.num_class) 32 | 33 | self.num_per_cls = np.zeros(self.num_class) 34 | for class_i,volume_i in zip(classes,img_num_per_cls): 35 | self.num_per_cls[class_i] = volume_i 36 | idx = np.where(targets_np==class_i)[0] 37 | np.random.shuffle(idx) 38 | keep_num = volume_i+1 39 | selec_idx = idx[:keep_num] 40 | new_data.append(self.data[selec_idx,...]) 41 | new_targets.extend([class_i]*keep_num) 42 | new_data = np.vstack(new_data) 43 | self.data = new_data 44 | self.targets = new_targets 45 | 46 | def get_cls_num_list(self): 47 | return self.num_per_cls.tolist() 48 | 49 | class IMBALANCECIFAR100(IMBALANCECIFAR10): 50 | base_folder = 'cifar-100-python' 51 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 52 | filename = "cifar-100-python.tar.gz" 53 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 54 | train_list = [ 55 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 56 | ] 57 | test_list = [ 58 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 59 | ] 60 | meta = { 61 | 'filename': 'meta', 62 | 'key': 'fine_label_names', 63 | 'md5': '7973b15100ade9c7d40fb424638fde48', 64 | } 65 | num_class = 100 66 | decay_stride = 21.9714 67 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from utils import read_json 5 | 6 | 7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 8 | """ 9 | Setup logging configuration 10 | """ 11 | log_config = Path(log_config) 12 | if log_config.is_file(): 13 | config = read_json(log_config) 14 | # modify logging paths based on run config 15 | for _, handler in config['handlers'].items(): 16 | if 'filename' in handler: 17 | handler['filename'] = str(save_dir / handler['filename']) 18 | 19 | logging.config.dictConfig(config) 20 | else: 21 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 22 | logging.basicConfig(level=default_level) 23 | -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import torch 4 | import numpy as np 5 | import data_loader.data_loaders as module_data 6 | import model.loss as module_loss 7 | import model.metric as module_metric 8 | import model.model as module_arch 9 | from parse_config import ConfigParser 10 | from trainer import Trainer 11 | import random 12 | from time import time 13 | 14 | def random_seed_setup(seed:int=None): 15 | torch.backends.cudnn.enabled = True 16 | if seed: 17 | print('Set random seed as',seed) 18 | torch.backends.cudnn.deterministic = True 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | else: 24 | torch.backends.cudnn.benchmark = True 25 | 26 | def main(config): 27 | logger = config.get_logger('train') 28 | 29 | # setup data_loader instances 30 | data_loader = config.init_obj('data_loader',module_data) 31 | valid_data_loader = data_loader.split_validation() 32 | 33 | # build model architecture, then print to console 34 | model = config.init_obj('arch',module_arch) 35 | 36 | # get loss 37 | loss_class = getattr(module_loss, config["loss"]["type"]) 38 | criterion = config.init_obj('loss',module_loss, cls_num_list=data_loader.cls_num_list) 39 | 40 | # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler 41 | optimizer = config.init_obj('optimizer',torch.optim,model.parameters()) 42 | 43 | if "type" in config._config["lr_scheduler"]: 44 | lr_scheduler_args = config["lr_scheduler"]["args"] 45 | gamma = lr_scheduler_args["gamma"] if "gamma" in lr_scheduler_args else 0.1 46 | print("step1, step2, warmup_epoch, gamma:",(lr_scheduler_args["step1"],lr_scheduler_args["step2"],lr_scheduler_args["warmup_epoch"],gamma)) 47 | 48 | def lr_lambda(epoch): 49 | if epoch >= lr_scheduler_args["step2"]: 50 | lr = gamma*gamma 51 | elif epoch >= lr_scheduler_args["step1"]: 52 | lr = gamma 53 | else: 54 | lr = 1 55 | warmup_epoch = lr_scheduler_args["warmup_epoch"] 56 | if epoch < warmup_epoch: 57 | lr = lr*float(1+epoch)/warmup_epoch 58 | return lr 59 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda) 60 | else: 61 | lr_scheduler = None 62 | 63 | trainer = Trainer( 64 | model , 65 | criterion , 66 | optimizer , 67 | config = config , 68 | data_loader = data_loader , 69 | valid_data_loader = valid_data_loader , 70 | lr_scheduler = lr_scheduler 71 | ) 72 | random_seed_setup() 73 | trainer.train() 74 | 75 | if __name__=='__main__': 76 | args = argparse.ArgumentParser(description='PyTorch Template') 77 | args.add_argument('-c','--config',default=None,type=str,help='config file path (default: None)') 78 | 79 | # custom cli options to modify configuration from default values given in json file. 80 | CustomArgs = collections.namedtuple('CustomArgs','flags type target') 81 | options = [ 82 | CustomArgs(['--name'],type=str,target='name'), 83 | CustomArgs(['--save_period'],type=int,target='trainer;save_period'), 84 | CustomArgs(['--distribution_aware_diversity_factor'],type=float,target='loss;args;additional_diversity_factor'), 85 | CustomArgs(['--pos_weight'],type=float,target='arch;args;pos_weight'), 86 | CustomArgs(['--collaborative_loss'],type=int,target='loss;args;collaborative_loss'), 87 | ] 88 | config = ConfigParser.from_args(args,options) 89 | 90 | # Training 91 | start = time() 92 | main(config) 93 | end = time() 94 | 95 | # Show used time 96 | minute = (end-start)/60 97 | hour = minute/60 98 | if minute<60: 99 | print('Training finished in %.1f min'%minute) 100 | else: 101 | print('Training finished in %.1f h'%hour) 102 | -------------------------------------------------------------------------------- /model/ResnetCifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.nn import Parameter 6 | import random 7 | 8 | def _weights_init(m): 9 | classname = m.__class__.__name__ 10 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 11 | init.kaiming_normal_(m.weight) 12 | 13 | class NormedLinear(nn.Module): 14 | def __init__(self,in_features,out_features): 15 | super(NormedLinear, self).__init__() 16 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 17 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 18 | def forward(self, x): 19 | return F.normalize(x,dim=1).mm(F.normalize(self.weight,dim=0)) 20 | 21 | class BasicBlock(nn.Module): 22 | def __init__(self,in_planes,planes,stride=1): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = nn.Conv2d(in_planes,planes,kernel_size=3,stride=stride,padding=1,bias=False) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.conv2 = nn.Conv2d(planes,planes,kernel_size=3,stride=1,padding=1,bias=False) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.shortcut = lambda x: x 29 | if stride!=1 or in_planes!=planes: 30 | self.planes = planes 31 | self.in_planes = in_planes 32 | self.shortcut = lambda x: F.pad(x[:,:,::2,::2],(0,0,0,0,(planes-in_planes)//2,(planes-in_planes)//2),"constant",0) 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = self.bn2(self.conv2(out))+self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | class ResNet_s(nn.Module): 40 | def __init__(self,block,num_blocks,num_experts,num_classes,reweight_temperature=0.2): 41 | super(ResNet_s,self).__init__() 42 | 43 | self.in_planes = 16 44 | self.num_classes = num_classes 45 | self.num_experts = num_experts 46 | self.eta = reweight_temperature 47 | 48 | self.conv1 = nn.Conv2d(3,16,kernel_size=3,stride=1,padding=1,bias=False) 49 | self.bn1 = nn.BatchNorm2d(16) 50 | self.layer1s = nn.ModuleList([self._make_layer(block,16,num_blocks[0],stride=1) for _ in range(num_experts)]) 51 | self.in_planes = self.next_in_planes 52 | 53 | self.layer2s = nn.ModuleList([self._make_layer(block,32,num_blocks[1],stride=2) for _ in range(num_experts)]) 54 | self.in_planes = self.next_in_planes 55 | self.layer3s = nn.ModuleList([self._make_layer(block,64,num_blocks[2],stride=2) for _ in range(num_experts)]) 56 | self.in_planes = self.next_in_planes 57 | 58 | self.linears = nn.ModuleList([NormedLinear(64,num_classes) for _ in range(num_experts)]) 59 | 60 | self.use_experts = list(range(num_experts)) 61 | self.apply(_weights_init) 62 | 63 | def _make_layer(self, block, planes, num_blocks, stride): 64 | strides = [stride] + [1]*(num_blocks-1) 65 | layers = [] 66 | self.next_in_planes = self.in_planes 67 | for stride in strides: 68 | layers.append(block(self.next_in_planes, planes, stride)) 69 | self.next_in_planes = planes 70 | return nn.Sequential(*layers) 71 | 72 | def _hook_before_iter(self): 73 | assert self.training, "_hook_before_iter should be called at training time only, after train() is called" 74 | 75 | for module in self.modules(): 76 | if isinstance(module, nn.BatchNorm2d): 77 | if not module.weight.requires_grad: 78 | module.eval() 79 | 80 | def forward(self,x): 81 | x = F.relu(self.bn1(self.conv1(x))) 82 | 83 | outs = [] 84 | self.logits = outs 85 | b0 = None 86 | self.w = [torch.ones(len(x),dtype=torch.bool,device=x.device)] 87 | 88 | for i in self.use_experts: 89 | xi = self.layer1s[i](x) 90 | xi = self.layer2s[i](xi) 91 | xi = self.layer3s[i](xi) 92 | xi = F.avg_pool2d(xi,xi.shape[3]) 93 | xi = xi.flatten(1) 94 | xi = self.linears[i](xi) 95 | xi = xi*30 96 | outs.append(xi) 97 | 98 | # evidential 99 | alpha = torch.exp(xi)+1 100 | S = alpha.sum(dim=1,keepdim=True) 101 | b = (alpha-1)/S 102 | u = self.num_classes/S.squeeze(-1) 103 | 104 | # update w 105 | if b0 is None: 106 | C = 0 107 | else: 108 | bb = b0.view(-1,b0.shape[1],1)@b.view(-1,1,b.shape[1]) 109 | C = bb.sum(dim=[1,2])-bb.diagonal(dim1=1,dim2=2).sum(dim=1) 110 | b0 = b 111 | self.w.append(self.w[-1]*u/(1-C)) 112 | 113 | # dynamic reweighting 114 | exp_w = [torch.exp(wi/self.eta) for wi in self.w] 115 | exp_w = [wi/wi.sum() for wi in exp_w] 116 | exp_w = [wi.unsqueeze(-1) for wi in exp_w] 117 | 118 | reweighted_outs = [outs[i]*exp_w[i] for i in self.use_experts] 119 | return sum(reweighted_outs) 120 | -------------------------------------------------------------------------------- /model/ResnetImagenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils import autocast 6 | 7 | class NormedLinear(nn.Module): 8 | def __init__(self, in_features, out_features): 9 | super(NormedLinear, self).__init__() 10 | self.weight = nn.Parameter(torch.Tensor(in_features, out_features)) 11 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 12 | def forward(self, x): 13 | return F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 14 | 15 | class Bottleneck(nn.Module): 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(Bottleneck, self).__init__() 18 | self.conv1 = nn.Conv2d(inplanes,planes,kernel_size=1,bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes,planes,kernel_size=3,stride=stride,padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv3 = nn.Conv2d(planes,planes*4,kernel_size=1,bias=False) 23 | self.bn3 = nn.BatchNorm2d(planes*4) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | out = self.conv1(x) 32 | out = self.bn1(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv3(out) 40 | out = self.bn3(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | class ResNet(nn.Module): 51 | def __init__(self,block,layers,num_experts,num_classes=1000,layer3_output_dim=None,layer4_output_dim=None,reweight_temperature=0.5): 52 | self.inplanes = 64 53 | self.num_classes = num_classes 54 | self.num_experts = num_experts 55 | self.eta = reweight_temperature 56 | self.use_experts = list(range(num_experts)) 57 | 58 | super(ResNet, self).__init__() 59 | self.conv1 = nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False) 60 | self.bn1 = nn.BatchNorm2d(64) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 63 | self.layer1 = self._make_layer(block, 64, layers[0]) 64 | self.inplanes = self.next_inplanes 65 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 66 | self.inplanes = self.next_inplanes 67 | 68 | if layer3_output_dim is None: 69 | layer3_output_dim = 256 70 | if layer4_output_dim is None: 71 | layer4_output_dim = 512 72 | 73 | self.layer3 = self._make_layer(block, layer3_output_dim, layers[2], stride=2) 74 | self.inplanes = self.next_inplanes 75 | self.layer4s = nn.ModuleList([self._make_layer(block, layer4_output_dim, layers[3], stride=2) for _ in range(num_experts)]) 76 | self.inplanes = self.next_inplanes 77 | self.avgpool = nn.AvgPool2d(7, stride=1) 78 | 79 | for m in self.modules(): 80 | if isinstance(m, nn.Conv2d): 81 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 82 | m.weight.data.normal_(0, math.sqrt(2. / n)) 83 | elif isinstance(m, nn.BatchNorm2d): 84 | m.weight.data.fill_(1) 85 | m.bias.data.zero_() 86 | 87 | self.linears = nn.ModuleList([NormedLinear(layer4_output_dim * 4, num_classes) for _ in range(num_experts)]) 88 | 89 | def _hook_before_iter(self): 90 | assert self.training, "_hook_before_iter should be called at training time only, after train() is called" 91 | count = 0 92 | for module in self.modules(): 93 | if isinstance(module, nn.BatchNorm2d): 94 | if module.weight.requires_grad == False: 95 | module.eval() 96 | count += 1 97 | if count > 0: 98 | print("Warning: detected at least one frozen BN, set them to eval state. Count:", count) 99 | 100 | def _make_layer(self, block, planes, blocks, stride=1): 101 | downsample = None 102 | if stride!=1 or self.inplanes!=planes*4: 103 | downsample = nn.Sequential( 104 | nn.Conv2d(self.inplanes,planes*4,kernel_size=1,stride=stride,bias=False), 105 | nn.BatchNorm2d(planes*4), 106 | ) 107 | layers = [] 108 | layers.append(block(self.inplanes, planes, stride, downsample)) 109 | self.next_inplanes = planes*4 110 | for i in range(1,blocks): 111 | layers.append(block(self.next_inplanes, planes)) 112 | return nn.Sequential(*layers) 113 | 114 | def expert_forward(self,x,ind): 115 | x = self.layer4s[ind](x) 116 | x = self.avgpool(x) 117 | x = x.view(x.size(0),-1) 118 | x = self.linears[ind](x) 119 | return x*30 120 | 121 | def forward(self, x): 122 | with autocast(): 123 | x = self.conv1(x) 124 | x = self.bn1(x) 125 | x = self.relu(x) 126 | x = self.maxpool(x) 127 | x = self.layer1(x) 128 | x = self.layer2(x) 129 | x = self.layer3(x) 130 | 131 | outs = [] 132 | self.logits = outs 133 | b0 = None 134 | self.w = [torch.ones(len(x),dtype=torch.bool,device=x.device)] 135 | 136 | for i in range(self.num_experts): 137 | xi = self.expert_forward(x,i) 138 | outs.append(xi) 139 | 140 | # evidential 141 | alpha = torch.exp(xi)+1 142 | S = alpha.sum(dim=1,keepdim=True) 143 | b = (alpha-1)/S 144 | u = self.num_classes/S.squeeze(-1) 145 | 146 | # update w 147 | if b0 is None: 148 | C = 0 149 | else: 150 | bb = b0.view(-1,b0.shape[1],1)@b.view(-1,1,b.shape[1]) 151 | C = bb.sum(dim=[1,2])-bb.diagonal(dim1=1,dim2=2).sum(dim=1) 152 | b0 = b 153 | self.w.append(self.w[-1]*u/(1-C)) 154 | 155 | # dynamic reweighting 156 | exp_w = [torch.exp(wi/self.eta) for wi in self.w] 157 | exp_w = [wi/wi.sum() for wi in exp_w] 158 | exp_w = [wi.unsqueeze(-1) for wi in exp_w] 159 | 160 | reweighted_outs = [outs[i]*exp_w[i] for i in self.use_experts] 161 | return sum(reweighted_outs) 162 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from math import exp 6 | 7 | class TLCLoss(nn.Module): 8 | def __init__(self,cls_num_list=None,max_m=0.5,reweight_epoch=-1,reweight_factor=0.05,annealing=500,tau=0.54): 9 | super(TLCLoss,self).__init__() 10 | self.reweight_epoch = reweight_epoch 11 | 12 | m_list = 1./np.sqrt(np.sqrt(cls_num_list)) 13 | m_list = m_list*(max_m/np.max(m_list)) 14 | m_list = torch.tensor(m_list,dtype=torch.float,requires_grad=False) 15 | self.m_list = m_list 16 | 17 | if reweight_epoch!=-1: 18 | idx = 1 19 | betas = [0,0.9999] 20 | effective_num = 1.0 - np.power(betas[idx], cls_num_list) 21 | per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num) 22 | per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list) 23 | self.per_cls_weights_enabled = torch.tensor(per_cls_weights, dtype=torch.float, requires_grad=False) 24 | else: 25 | self.per_cls_weights_enabled = None 26 | cls_num_list = np.array(cls_num_list) / np.sum(cls_num_list) 27 | C = len(cls_num_list) 28 | per_cls_weights = C * cls_num_list * reweight_factor + 1 - reweight_factor 29 | per_cls_weights = per_cls_weights / np.max(per_cls_weights) 30 | 31 | # save diversity per_cls_weights 32 | self.per_cls_weights_enabled_diversity = torch.tensor(per_cls_weights,dtype=torch.float,requires_grad=False).to("cuda:0") 33 | self.T = (reweight_epoch+annealing)/reweight_factor 34 | self.tau = tau 35 | 36 | def to(self,device): 37 | super().to(device) 38 | self.m_list = self.m_list.to(device) 39 | if self.per_cls_weights_enabled is not None: 40 | self.per_cls_weights_enabled = self.per_cls_weights_enabled.to(device) 41 | if self.per_cls_weights_enabled_diversity is not None: 42 | self.per_cls_weights_enabled_diversity = self.per_cls_weights_enabled_diversity.to(device) 43 | return self 44 | 45 | def _hook_before_epoch(self,epoch): 46 | if self.reweight_epoch != -1: 47 | self.epoch = epoch 48 | if epoch > self.reweight_epoch: 49 | self.per_cls_weights_base = self.per_cls_weights_enabled 50 | self.per_cls_weights_diversity = self.per_cls_weights_enabled_diversity 51 | else: 52 | self.per_cls_weights_base = None 53 | self.per_cls_weights_diversity = None 54 | 55 | def get_final_output(self,x,y): 56 | index = torch.zeros_like(x,dtype=torch.uint8,device=x.device) 57 | index.scatter_(1,y.data.view(-1,1),1) 58 | index_float = index.float() 59 | batch_m = torch.matmul(self.m_list[None,:],index_float.transpose(0,1)) 60 | batch_m = batch_m.view((-1, 1)) 61 | x_m = x-30*batch_m 62 | return torch.exp(torch.where(index,x_m,x)) 63 | 64 | def forward(self,x,y,epoch,extra_info=None): 65 | loss = 0 66 | for i in range(extra_info["num_expert"]): 67 | alpha = self.get_final_output(extra_info["logits"][i],y) 68 | S = alpha.sum(dim=1,keepdim=True) 69 | l = F.nll_loss(torch.log(alpha)-torch.log(S),y,weight=self.per_cls_weights_base,reduction="none") 70 | 71 | # KL 72 | yi = F.one_hot(y,num_classes=alpha.shape[1]) 73 | 74 | # adjusted parameters of D(p|alpha) 75 | alpha_tilde = yi+(1-yi)*(alpha+1) 76 | S_tilde = alpha_tilde.sum(dim=1,keepdim=True) 77 | kl = torch.lgamma(S_tilde)-torch.lgamma(torch.tensor(alpha_tilde.shape[1]))-torch.lgamma(alpha_tilde).sum(dim=1,keepdim=True) \ 78 | +((alpha_tilde-1)*(torch.digamma(alpha_tilde)-torch.digamma(S_tilde))).sum(dim=1,keepdim=True) 79 | l += epoch/self.T*kl.squeeze(-1) 80 | 81 | # diversity 82 | if self.per_cls_weights_diversity is not None: 83 | diversity_temperature = self.per_cls_weights_diversity.view((1,-1)) 84 | temperature_mean = diversity_temperature.mean().item() 85 | else: 86 | diversity_temperature = 1 87 | temperature_mean = 1 88 | output_dist = F.log_softmax(extra_info["logits"][i]/diversity_temperature,dim=1) 89 | with torch.no_grad(): 90 | mean_output_dist = F.softmax(x/diversity_temperature,dim=1) 91 | l -= 0.01*temperature_mean*temperature_mean*F.kl_div(output_dist,mean_output_dist,reduction="none").sum(dim=1) 92 | 93 | # dynamic engagement 94 | w = extra_info['w'][i]/extra_info['w'][i].max() 95 | w = torch.where(w>self.tau,True,False) 96 | l = (w*l).sum()/w.sum() 97 | loss += l.mean() 98 | 99 | return loss 100 | -------------------------------------------------------------------------------- /model/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from sklearn.metrics import * 4 | from scipy import interpolate 5 | 6 | def ACC(output,target,u=None,region_len=100/3): 7 | with torch.no_grad(): 8 | pred = torch.argmax(output,dim=1) 9 | correct = (pred==target) 10 | region_correct = (pred/region_len).long()==(target/region_len).long() 11 | acc = correct.sum().item()/len(target) 12 | region_acc = region_correct.sum().item()/len(target) 13 | split_acc = [0,0,0] 14 | 15 | # count number of classes for each region 16 | num_class = int(3*region_len) 17 | region_idx = (torch.arange(num_class)/region_len).long() 18 | region_vol = [ 19 | num_class-torch.count_nonzero(region_idx).item(), 20 | torch.where(region_idx==1,True,False).sum().item(), 21 | torch.where(region_idx==2,True,False).sum().item() 22 | ] 23 | target_count = target.bincount().cpu().numpy() 24 | region_vol = [target_count[:region_vol[0]].sum(), target_count[region_vol[0]:(region_vol[0]+region_vol[1])].sum(),target_count[-region_vol[2]:].sum()] 25 | for i in range(len(target)): 26 | split_acc[region_idx[target[i].item()]] += correct[i].item() 27 | split_acc = [split_acc[i]/region_vol[i] for i in range(3)] 28 | 29 | print('Classification ACC:') 30 | print('\t all \t =',acc) 31 | print('\t region =',region_acc) 32 | print('\t head \t =',split_acc[0]) 33 | print('\t med \t =',split_acc[1]) 34 | print('\t tail \t =',split_acc[2]) 35 | return acc, region_acc, split_acc[0], split_acc[1], split_acc[2] 36 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from base import BaseModel 4 | from . import ResnetImagenet 5 | from . import ResnetCifar 6 | 7 | class Model(BaseModel): 8 | requires_target = False 9 | 10 | def __init__(self, num_classes, backbone_class=None): 11 | super().__init__() 12 | if backbone_class is not None: # Do not init backbone here if None 13 | self.backbone = backbone_class(num_classes) 14 | 15 | def _hook_before_iter(self): 16 | self.backbone._hook_before_iter() 17 | 18 | def forward(self,x): 19 | return self.backbone(x) 20 | 21 | class ResNet32Model(Model): 22 | def __init__(self,num_classes,num_experts=1,**kwargs): 23 | super().__init__(num_classes,None) 24 | self.backbone = ResnetCifar.ResNet_s( 25 | ResnetCifar.BasicBlock , 26 | [5,5,5] , 27 | num_classes = num_classes , 28 | num_experts = num_experts , 29 | **kwargs 30 | ) 31 | 32 | class ResNet50Model(Model): 33 | def __init__(self,num_classes,layer3_output_dim=None,layer4_output_dim=None,num_experts=1,**kwargs): 34 | super().__init__(num_classes,None) 35 | self.backbone = ResnetImagenet.ResNet( 36 | ResnetImagenet.Bottleneck , 37 | [3,4,6,3] , 38 | num_classes = num_classes , 39 | layer3_output_dim = layer3_output_dim , 40 | layer4_output_dim = layer4_output_dim , 41 | num_experts = num_experts , 42 | **kwargs 43 | ) 44 | -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce, partial 5 | from operator import getitem 6 | from datetime import datetime 7 | from logger import setup_logging 8 | from utils import read_json, write_json 9 | 10 | 11 | class ConfigParser: 12 | def __init__(self, config, resume=None, modification=None, load_crt=None, run_id=None): 13 | """ 14 | class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving 15 | and logging module. 16 | :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example. 17 | :param resume: String, path to the checkpoint being loaded. 18 | :param modification: Dict keychain:value, specifying position values to be replaced from config dict. 19 | :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default 20 | """ 21 | # load config file and apply modification 22 | self._config = _update_config(config, modification) 23 | self.resume = resume 24 | 25 | self.load_crt = load_crt 26 | 27 | # set save_dir where trained model and log will be saved. 28 | save_dir = Path(self.config['trainer']['save_dir']) 29 | 30 | exper_name = self.config['name'] 31 | if run_id is None: # use timestamp as default run-id 32 | run_id = datetime.now().strftime(r'%m%d_%H%M%S') 33 | self._save_dir = save_dir / 'models' / exper_name / run_id 34 | self._log_dir = save_dir / 'log' / exper_name / run_id 35 | 36 | # make directory for saving checkpoints and log. 37 | exist_ok = run_id == '' 38 | self.save_dir.mkdir(parents=True, exist_ok=exist_ok) 39 | self.log_dir.mkdir(parents=True, exist_ok=exist_ok) 40 | 41 | # save updated config file to the checkpoint dir 42 | write_json(self.config, self.save_dir / 'config.json') 43 | 44 | # configure logging module 45 | setup_logging(self.log_dir) 46 | self.log_levels = { 47 | 0: logging.WARNING, 48 | 1: logging.INFO, 49 | 2: logging.DEBUG 50 | } 51 | 52 | @classmethod 53 | def from_args(cls, args, options=''): 54 | """ 55 | Initialize this class from some cli arguments. Used in train, test. 56 | """ 57 | for opt in options: 58 | args.add_argument(*opt.flags, default=None, type=opt.type) 59 | if not isinstance(args, tuple): 60 | args = args.parse_args() 61 | 62 | if hasattr(args, "load_crt"): 63 | load_crt = args.load_crt 64 | else: 65 | load_crt = None 66 | 67 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 68 | assert args.config is not None, msg_no_cfg 69 | cfg_fname = Path(args.config) 70 | 71 | config = read_json(cfg_fname) 72 | 73 | # parse custom cli options into dictionary 74 | modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options} 75 | return cls(config,modification,load_crt=load_crt) 76 | 77 | def init_obj(self, name, module, *args, **kwargs): 78 | """ 79 | Finds a function handle with the name given as 'type' in config, and returns the 80 | instance initialized with corresponding arguments given. 81 | 82 | `object = config.init_obj('name', module, a, b=1)` 83 | is equivalent to 84 | `object = module.name(a, b=1)` 85 | """ 86 | module_name = self[name]['type'] 87 | module_args = dict(self[name]['args']) if 'args' in self[name] else dict() 88 | 89 | module_args.update(kwargs) 90 | return getattr(module, module_name)(*args, **module_args) 91 | 92 | def init_ftn(self, name, module, *args, **kwargs): 93 | """ 94 | Finds a function handle with the name given as 'type' in config, and returns the 95 | function with given arguments fixed with functools.partial. 96 | 97 | `function = config.init_ftn('name', module, a, b=1)` 98 | is equivalent to 99 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`. 100 | """ 101 | module_name = self[name]['type'] 102 | module_args = dict(self[name]['args']) 103 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 104 | module_args.update(kwargs) 105 | return partial(getattr(module, module_name), *args, **module_args) 106 | 107 | def __getitem__(self, name): 108 | """Access items like ordinary dict.""" 109 | return self.config[name] 110 | 111 | def get_logger(self, name, verbosity=2): 112 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys()) 113 | assert verbosity in self.log_levels, msg_verbosity 114 | logger = logging.getLogger(name) 115 | logger.setLevel(self.log_levels[verbosity]) 116 | return logger 117 | 118 | # setting read-only attributes 119 | @property 120 | def config(self): 121 | return self._config 122 | 123 | @property 124 | def save_dir(self): 125 | return self._save_dir 126 | 127 | @property 128 | def log_dir(self): 129 | return self._log_dir 130 | 131 | # helper functions to update config dict with custom cli options 132 | def _update_config(config, modification): 133 | if modification is None: 134 | return config 135 | 136 | for k, v in modification.items(): 137 | if v is not None: 138 | _set_by_path(config, k, v) 139 | return config 140 | 141 | def _get_opt_name(flags): 142 | for flg in flags: 143 | if flg.startswith('--'): 144 | return flg.replace('--', '') 145 | return flags[0].replace('--', '') 146 | 147 | def _set_by_path(tree, keys, value): 148 | """Set a value in a nested object in tree by sequence of keys.""" 149 | keys = keys.split(';') 150 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 151 | 152 | def _get_by_path(tree, keys): 153 | """Access a nested object in tree by sequence of keys.""" 154 | return reduce(getitem, keys, tree) 155 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torchvision.utils import make_grid 5 | from base import BaseTrainer 6 | from utils import inf_loop, MetricTracker,load_state_dict,rename_parallel_state_dict,autocast 7 | import model.model as module_arch 8 | from model.metric import * 9 | from tqdm import tqdm 10 | 11 | class Trainer(BaseTrainer): 12 | def __init__(self,model,criterion,opt,config,data_loader,valid_data_loader,lr_scheduler): 13 | super().__init__(model,criterion,opt,config) 14 | self.config = config 15 | self.data_loader = data_loader 16 | self.len_epoch = len(self.data_loader) 17 | self.valid_data_loader = valid_data_loader 18 | self.val_targets = torch.tensor(valid_data_loader.dataset.targets,device=self.device).long() 19 | self.num_class = self.val_targets.max().item()+1 20 | self.lr_scheduler = lr_scheduler 21 | 22 | def _train_epoch(self,epoch): 23 | self.model.train() 24 | self.model._hook_before_iter() 25 | self.criterion._hook_before_epoch(epoch) 26 | 27 | total_loss = [] 28 | for batch_id,(data,target) in tqdm(enumerate(self.data_loader)): 29 | data,target = data.to(self.device),target.to(self.device) 30 | self.opt.zero_grad() 31 | 32 | with autocast(): 33 | output = self.model(data) 34 | extra_info = { 35 | "num_expert" : len(self.model.backbone.logits) , 36 | "logits" : self.model.backbone.logits , 37 | 'w' : self.model.backbone.w 38 | } 39 | loss = self.criterion(x=output,y=target,epoch=epoch,extra_info=extra_info) 40 | loss.backward() 41 | 42 | self.opt.step() 43 | total_loss.append(loss.item()) 44 | 45 | self._valid_epoch(epoch) 46 | print("loss =",sum(total_loss)/len(total_loss)) 47 | 48 | if self.lr_scheduler is not None: 49 | self.lr_scheduler.step() 50 | 51 | def _valid_epoch(self,epoch): 52 | self.model.eval() 53 | output = torch.empty(0,self.num_class,dtype=torch.float32,device=self.device) 54 | uncertainty = torch.empty(0,dtype=torch.float32,device=self.device) 55 | for _,(data,_) in enumerate(self.valid_data_loader): 56 | data = data.to(self.device) 57 | 58 | with torch.no_grad(): 59 | o = self.model(data) 60 | u = self.model.backbone.w[-1] 61 | output = torch.cat([output,o.detach()],dim=0) 62 | uncertainty = torch.cat([uncertainty,u.detach()],dim=0) 63 | 64 | print(f'================ Epoch: {epoch:03d} ================') 65 | ACC(output,self.val_targets,uncertainty,region_len=self.num_class/3) 66 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | -------------------------------------------------------------------------------- /utils/gflops.py: -------------------------------------------------------------------------------- 1 | # Note: the gflops.py is experimental and may be inaccurate. If you experience any problems, please tell us. 2 | # Please run this in project directory: python -m utils.gflops 3 | 4 | import torch 5 | from torchvision.models import resnet50 6 | from thop import profile 7 | import sys 8 | import model.model as models 9 | import argparse 10 | parser = argparse.ArgumentParser() 11 | 12 | # Examples: 13 | # ImageNet-LT 14 | # ResNeXt50 15 | # python -m utils.gflops ResNeXt50Model 0 --num_experts 4 --reduce_dim True --use_norm False 16 | 17 | # iNaturalist 18 | ## LDAM 19 | # python -m utils.gflops ResNet50Model 1 --num_experts 3 --reduce_dim True --use_norm True 20 | 21 | # Imbalance CIFAR 100 22 | ## LDAM 23 | # python -m utils.gflops ResNet32Model 2 --num_experts 3 --reduce_dim True --use_norm True 24 | 25 | 26 | parser.add_argument("model_name", type=str) 27 | parser.add_argument("dataset", type=str, help="0: ImageNet-LT, 1: iNaturalist, 2: Imbalance CIFAR 100") 28 | parser.add_argument("--num_experts", type=int, default=1) 29 | parser.add_argument("--layer2_dim", type=int, default=0) 30 | parser.add_argument("--layer3_dim", type=int, default=0) 31 | parser.add_argument("--layer4_dim", type=int, default=0) 32 | parser.add_argument("--reduce_dim", type=str, default="False", help="True: reduce dimension") 33 | parser.add_argument("--use_norm", type=str, default="False", help="True: use_norm") 34 | parser.add_argument("--ea_percentage", type=str, default=None, help="Percentage of passing each expert: only use this if you are calculating GFLOPs for an EA module. Example: 40.99,9.47,49.54") 35 | 36 | args = parser.parse_args() 37 | 38 | model_name = args.model_name 39 | num_classes_dict = { 40 | "0": 1000, 41 | "1": 8142, 42 | "2": 100 43 | } 44 | dataset_name_dict = { 45 | "0": "ImageNet-LT", 46 | "1": "iNaturalist 2018", 47 | "2": "Imbalanced CIFAR 100" 48 | } 49 | print("Using dataset", dataset_name_dict[args.dataset]) 50 | 51 | num_classes = num_classes_dict[args.dataset] 52 | 53 | def gflops_normed_linear(m, x, y): 54 | # per output element 55 | num_instance = y.size(0) 56 | total_ops = m.weight.numel() * num_instance + m.weight.size(0) # weight normalization can be ignored 57 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 58 | 59 | num_experts = args.num_experts 60 | layer2_dim = args.layer2_dim 61 | layer3_dim = args.layer3_dim 62 | layer4_dim = args.layer4_dim 63 | reduced_dim = True if args.reduce_dim == "True" else False 64 | use_norm = True if args.use_norm == "True" else False 65 | ea_percentage = args.ea_percentage 66 | if ea_percentage is not None: 67 | ea_percentage = [float(item) for item in ea_percentage.split(",")] 68 | force_all = True 69 | loop = range(num_experts) 70 | else: 71 | force_all = False 72 | loop = [num_experts-1] 73 | 74 | total_macs = 0 75 | 76 | for i in loop: # i: num_experts - 1 so we need to add one 77 | model_arg = { 78 | "num_classes": num_classes, 79 | "num_experts": i+1, 80 | **({"layer2_output_dim": layer2_dim} if layer2_dim else {}), 81 | **({"layer3_output_dim": layer3_dim} if layer3_dim else {}), 82 | **({"layer4_output_dim": layer4_dim} if layer4_dim else {}), 83 | **({"reduce_dimension": reduced_dim} if reduced_dim else {}), 84 | **({"use_norm": use_norm} if use_norm else {}), 85 | **({"force_all": force_all} if force_all else {}) 86 | } 87 | 88 | print("Model Name: {}, Model Arg: {}".format(model_name, model_arg)) 89 | 90 | model = (getattr(models, model_name))(**model_arg) 91 | 92 | model = model.backbone 93 | model = model.eval() 94 | model = model 95 | 96 | if num_classes == 10 or num_classes == 100: # model inputs are different for CIFAR 97 | input = torch.randn(1, 3, 32, 32) 98 | else: 99 | input = torch.randn(1, 3, 224, 224) 100 | 101 | input_dim = input.shape 102 | print("Using input size", input_dim) 103 | macs, _ = profile(model, inputs=(input, ), verbose=False, custom_ops={ 104 | models.resnet_cifar.NormedLinear: gflops_normed_linear, 105 | models.ea_resnet_cifar.NormedLinear: gflops_normed_linear, 106 | models.ResNet.NormedLinear: gflops_normed_linear, 107 | models.EAResNet.NormedLinear: gflops_normed_linear 108 | }) 109 | if force_all: 110 | percentage_curr = ea_percentage[i] 111 | total_macs += percentage_curr * macs / 100 112 | else: 113 | total_macs += macs 114 | 115 | print("macs(G):", total_macs/1000/1000/1000) 116 | print() 117 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import pandas as pd 4 | from pathlib import Path 5 | from itertools import repeat 6 | from collections import OrderedDict 7 | 8 | # WARNING: 9 | # There is no guarantee that it will work or be used on a model. Please do use it with caution unless you make sure everything is working. 10 | use_fp16 = False 11 | 12 | if use_fp16: 13 | from torch.cuda.amp import autocast 14 | else: 15 | class Autocast(): # This is a dummy autocast class 16 | def __init__(self): 17 | pass 18 | def __enter__(self, *args, **kwargs): 19 | pass 20 | def __call__(self, arg=None): 21 | if arg is None: 22 | return self 23 | return arg 24 | def __exit__(self, *args, **kwargs): 25 | pass 26 | 27 | autocast = Autocast() 28 | 29 | def rename_parallel_state_dict(state_dict): 30 | count = 0 31 | for k in list(state_dict.keys()): 32 | if k.startswith('module.'): 33 | v = state_dict.pop(k) 34 | renamed = k[7:] 35 | state_dict[renamed] = v 36 | count += 1 37 | if count > 0: 38 | print("Detected DataParallel: Renamed {} parameters".format(count)) 39 | return count 40 | 41 | def load_state_dict(model, state_dict, no_ignore=False): 42 | own_state = model.state_dict() 43 | count = 0 44 | for name, param in state_dict.items(): 45 | if name not in own_state: # ignore 46 | print("Warning: {} ignored because it does not exist in state_dict".format(name)) 47 | assert not no_ignore, "Ignoring param that does not exist in model's own state dict is not allowed." 48 | continue 49 | if isinstance(param, torch.nn.Parameter): 50 | # backwards compatibility for serialized parameters 51 | param = param.data 52 | try: 53 | own_state[name].copy_(param) 54 | except RuntimeError as e: 55 | print("Error in copying parameter {}, source shape: {}, destination shape: {}".format(name, param.shape, own_state[name].shape)) 56 | raise e 57 | count += 1 58 | if count != len(own_state): 59 | print("Warning: Model has {} parameters, copied {} from state dict".format(len(own_state), count)) 60 | return count 61 | 62 | def ensure_dir(dirname): 63 | dirname = Path(dirname) 64 | if not dirname.is_dir(): 65 | dirname.mkdir(parents=True, exist_ok=False) 66 | 67 | def read_json(fname): 68 | fname = Path(fname) 69 | with fname.open('rt') as handle: 70 | return json.load(handle, object_hook=OrderedDict) 71 | 72 | def write_json(content, fname): 73 | fname = Path(fname) 74 | with fname.open('wt') as handle: 75 | json.dump(content, handle, indent=4, sort_keys=False) 76 | 77 | def inf_loop(data_loader): 78 | ''' wrapper function for endless data loader. ''' 79 | for loader in repeat(data_loader): 80 | yield from loader 81 | 82 | class MetricTracker: 83 | def __init__(self, *keys): 84 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) 85 | self.reset() 86 | 87 | def reset(self): 88 | for col in self._data.columns: 89 | self._data[col].values[:] = 0 90 | 91 | def update(self, key, value, n=1): 92 | if isinstance(value, tuple) and len(value) == 2: 93 | value, n = value 94 | self._data.total[key] += value * n 95 | self._data.counts[key] += n 96 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 97 | 98 | def avg(self, key): 99 | return self._data.average[key] 100 | 101 | def result(self): 102 | return dict(self._data.average) 103 | --------------------------------------------------------------------------------