├── figure.PNG ├── loss └── SoftmaxLoss.py ├── config └── ImageNet_LT │ ├── clip_A_rn50.yaml │ ├── test.yaml │ └── clip_B_rn50.yaml ├── data ├── ImageNet │ └── gen_txt.py ├── ClassAwareSampler.py ├── dataloader.py ├── MixedPrioritizedSampler.py └── ClassPrioritySampler.py ├── README.md ├── logger.py ├── main.py ├── utils.py ├── classes.py └── train.py /figure.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaopengcuhk/BALLAD/HEAD/figure.PNG -------------------------------------------------------------------------------- /loss/SoftmaxLoss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def create_loss (): 4 | print('Loading Softmax Loss.') 5 | return nn.CrossEntropyLoss() 6 | 7 | -------------------------------------------------------------------------------- /config/ImageNet_LT/clip_A_rn50.yaml: -------------------------------------------------------------------------------- 1 | coslr: true 2 | criterions: 3 | PerformanceLoss: 4 | def_file: ./loss/SoftmaxLoss.py 5 | loss_params: {} 6 | optim_params: null 7 | weight: 1.0 8 | endlr: 0.0 9 | last: false 10 | model: 11 | clip: 12 | optim_params: {lr: 0.00001, momentum: 0.9, weight_decay: 0.0005} 13 | params: {visual_backbone: 'RN50'} 14 | adapter: 15 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 16 | params: {feat_dim: 1024} 17 | shuffle: true 18 | training_opt: 19 | phaseA: true 20 | batch_size: 512 21 | dataset: ImageNet_LT 22 | display_step: 10 23 | feature_dim: 1024 24 | log_dir: ./output_A 25 | num_classes: 1000 26 | num_epochs: 50 27 | num_workers: 12 28 | open_threshold: 0.1 29 | sampler: null 30 | scheduler_params: {gamma: 0.1, step_size: 3} 31 | -------------------------------------------------------------------------------- /data/ImageNet/gen_txt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm 4 | 5 | root = '/datasets01_101/imagenet_full_size/061417' 6 | split2txt = { 7 | 'train': 'ImageNet_train.txt', 8 | 'val': 'ImageNet_val.txt', 9 | # 'test': 'ImageNet_test.txt', 10 | } 11 | 12 | def convert(split, txt_file): 13 | clsnames = os.listdir(os.path.join(root, split)) 14 | clsnames.sort() 15 | 16 | lines = [] 17 | for i, name in enumerate(clsnames): 18 | imgs = os.listdir(os.path.join(root, split, name)) 19 | imgs.sort() 20 | for img in imgs: 21 | lines.append(os.path.join(split, name, img) + ' ' + str(i) + '\n') 22 | 23 | with open(txt_file, 'w') as f: 24 | f.writelines(lines) 25 | 26 | for k, v in split2txt.items(): 27 | print('===> Converting {} to {}'.format(k, v)) 28 | convert(k, v) 29 | -------------------------------------------------------------------------------- /config/ImageNet_LT/test.yaml: -------------------------------------------------------------------------------- 1 | coslr: true 2 | criterions: 3 | PerformanceLoss: 4 | def_file: ./loss/SoftmaxLoss.py 5 | loss_params: {} 6 | optim_params: null 7 | weight: 1.0 8 | endlr: 0.0 9 | last: false 10 | model_dir: ./output_B/final_model_checkpoint.pth 11 | model: 12 | clip: 13 | optim_params: {lr: 0.00001, momentum: 0.9, weight_decay: 0.0005} 14 | params: {visual_backbone: 'RN50'} 15 | adapter: 16 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 17 | params: {feat_dim: 1024} 18 | shuffle: true 19 | training_opt: 20 | phaseA: False 21 | batch_size: 512 22 | dataset: ImageNet_LT 23 | display_step: 10 24 | feature_dim: 1024 25 | log_dir: ./output_B 26 | num_classes: 1000 27 | num_epochs: 10 28 | num_workers: 12 29 | open_threshold: 0.1 30 | sampler: {def_file: ./data/ClassAwareSampler.py, num_samples_cls: 4, type: ClassAwareSampler} 31 | scheduler_params: {gamma: 0.1, step_size: 3} 32 | -------------------------------------------------------------------------------- /config/ImageNet_LT/clip_B_rn50.yaml: -------------------------------------------------------------------------------- 1 | coslr: true 2 | criterions: 3 | PerformanceLoss: 4 | def_file: ./loss/SoftmaxLoss.py 5 | loss_params: {} 6 | optim_params: null 7 | weight: 1.0 8 | endlr: 0.0 9 | last: false 10 | model_dir: ./output_A/final_model_checkpoint.pth 11 | model: 12 | clip: 13 | optim_params: {lr: 0.00001, momentum: 0.9, weight_decay: 0.0005} 14 | params: {visual_backbone: 'RN50'} 15 | adapter: 16 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 17 | params: {feat_dim: 1024} 18 | shuffle: true 19 | training_opt: 20 | phaseA: False 21 | batch_size: 512 22 | dataset: ImageNet_LT 23 | display_step: 10 24 | feature_dim: 1024 25 | log_dir: ./output_B 26 | num_classes: 1000 27 | num_epochs: 10 28 | num_workers: 12 29 | open_threshold: 0.1 30 | sampler: {def_file: ./data/ClassAwareSampler.py, num_samples_cls: 4, type: ClassAwareSampler} 31 | scheduler_params: {gamma: 0.1, step_size: 3} 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BALLAD 2 | This is the official code repository for [*A Simple Long-Tailed Rocognition Baseline via Vision-Language Model.*](https://arxiv.org/pdf/2111.14745.pdf) 3 | 4 | ![image](https://github.com/gaopengcuhk/BALLAD/blob/main/figure.PNG) 5 | 6 | ## Requirements 7 | * Python3 8 | * Pytorch(1.7.1 recommended) 9 | * yaml 10 | * other necessary packages 11 | 12 | ## Datasets 13 | * ImageNet_LT 14 | * Places_LT 15 | 16 | Download the [ImageNet_2014](http://image-net.org/index) and [Places_365](http://places2.csail.mit.edu/download.html). 17 | 18 | Modify the data_root in [main.py](main.py) to refer to your own dataset path. 19 | 20 | ## Training 21 | 22 | #### Phase A 23 | ``` 24 | python main.py --cfg ./config/ImageNet_LT/clip_A_rn50.yaml 25 | ``` 26 | 27 | #### Phase B 28 | ``` 29 | python main.py --cfg ./config/ImageNet_LT/clip_B_rn50.yaml 30 | ``` 31 | 32 | ## Testing 33 | ``` 34 | python main.py --cfg ./config/ImageNet_LT/test.yaml --test 35 | ``` 36 | 37 | ## Acknowledgments 38 | 39 | The codes is based on [https://github.com/zhmiao/OpenLongTailRecognition-OLTR](https://github.com/zhmiao/OpenLongTailRecognition-OLTR) and motivated by [https://github.com/facebookresearch/classifier-balancing](https://github.com/facebookresearch/classifier-balancing). 40 | 41 | 42 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import csv 4 | import h5py 5 | 6 | 7 | class Logger(object): 8 | def __init__(self, logdir): 9 | self.logdir = logdir 10 | if not os.path.isdir(logdir): 11 | os.makedirs(logdir) 12 | self.cfg_file = os.path.join(self.logdir, 'cfg.yaml') 13 | self.acc_file = os.path.join(self.logdir, 'acc.csv') 14 | self.loss_file = os.path.join(self.logdir, 'loss.csv') 15 | self.ws_file = os.path.join(self.logdir, 'ws.h5') 16 | self.acc_keys = None 17 | self.loss_keys = None 18 | self.logging_ws = False 19 | 20 | def log_cfg(self, cfg): 21 | print('===> Saving cfg parameters to: ', self.cfg_file) 22 | with open(self.cfg_file, 'w') as f: 23 | yaml.dump(cfg, f) 24 | 25 | def log_acc(self, accs): 26 | if self.acc_keys is None: 27 | self.acc_keys = [k for k in accs.keys()] 28 | with open(self.acc_file, 'w') as f: 29 | writer = csv.DictWriter(f, fieldnames=self.acc_keys) 30 | writer.writeheader() 31 | writer.writerow(accs) 32 | else: 33 | with open(self.acc_file, 'a') as f: 34 | writer = csv.DictWriter(f, fieldnames=self.acc_keys) 35 | writer.writerow(accs) 36 | 37 | def log_loss(self, losses): 38 | # valid_losses = {k: v for k, v in losses.items() if v is not None} 39 | valid_losses = losses 40 | if self.loss_keys is None: 41 | self.loss_keys = [k for k in valid_losses.keys()] 42 | with open(self.loss_file, 'w') as f: 43 | writer = csv.DictWriter(f, fieldnames=self.loss_keys) 44 | writer.writeheader() 45 | writer.writerow(valid_losses) 46 | else: 47 | with open(self.loss_file, 'a') as f: 48 | writer = csv.DictWriter(f, fieldnames=self.loss_keys) 49 | writer.writerow(valid_losses) 50 | 51 | def log_ws(self, e, ws): 52 | mode = 'a' if self.logging_ws else 'w' 53 | self.logging_ws = True 54 | 55 | key = 'Epoch{:02d}'.format(e) 56 | with h5py.File(self.ws_file, mode) as f: 57 | g = f.create_group(key) 58 | for k, v in ws.items(): 59 | g.create_dataset(k, data=v) 60 | -------------------------------------------------------------------------------- /data/ClassAwareSampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from torch.utils.data.sampler import Sampler 4 | import pdb 5 | 6 | ################################## 7 | ## Class-aware sampling, partly implemented by frombeijingwithlove 8 | ################################## 9 | 10 | class RandomCycleIter: 11 | 12 | def __init__ (self, data, test_mode=False): 13 | self.data_list = list(data) 14 | self.length = len(self.data_list) 15 | self.i = self.length - 1 16 | self.test_mode = test_mode 17 | 18 | def __iter__ (self): 19 | return self 20 | 21 | def __next__ (self): 22 | self.i += 1 23 | 24 | if self.i == self.length: 25 | self.i = 0 26 | if not self.test_mode: 27 | random.shuffle(self.data_list) 28 | 29 | return self.data_list[self.i] 30 | 31 | def class_aware_sample_generator (cls_iter, data_iter_list, n, num_samples_cls=1): 32 | 33 | i = 0 34 | j = 0 35 | while i < n: 36 | 37 | # yield next(data_iter_list[next(cls_iter)]) 38 | 39 | if j >= num_samples_cls: 40 | j = 0 41 | 42 | if j == 0: 43 | temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]]*num_samples_cls)) 44 | yield temp_tuple[j] 45 | else: 46 | yield temp_tuple[j] 47 | 48 | i += 1 49 | j += 1 50 | 51 | class ClassAwareSampler (Sampler): 52 | 53 | def __init__(self, data_source, num_samples_cls=1,): 54 | num_classes = len(np.unique(data_source.labels)) 55 | self.class_iter = RandomCycleIter(range(num_classes)) 56 | cls_data_list = [list() for _ in range(num_classes)] 57 | for i, label in enumerate(data_source.labels): 58 | cls_data_list[label].append(i) 59 | self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list] 60 | self.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list) 61 | self.num_samples_cls = num_samples_cls 62 | 63 | def __iter__ (self): 64 | return class_aware_sample_generator(self.class_iter, self.data_iter_list, 65 | self.num_samples, self.num_samples_cls) 66 | 67 | def __len__ (self): 68 | return self.num_samples 69 | 70 | def get_sampler(): 71 | return ClassAwareSampler 72 | 73 | ################################## -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchvision 3 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 4 | from torchvision import transforms 5 | import os 6 | from PIL import Image 7 | import io 8 | import logging 9 | logger = logging.getLogger('global') 10 | import requests 11 | import time 12 | import os.path as osp 13 | import json 14 | 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | def pil_loader(img_bytes, filepath): 23 | buff = io.BytesIO(img_bytes) 24 | try: 25 | with Image.open(buff) as img: 26 | img = img.convert('RGB') 27 | except IOError: 28 | logger.info('Failed in loading {}'.format(filepath)) 29 | return img 30 | 31 | # Image statistics 32 | RGB_statistics = { 33 | 'default': { 34 | 'mean': [0.485, 0.456, 0.406], 35 | 'std':[0.229, 0.224, 0.225] 36 | }, 37 | 'clip': { 38 | 'mean': [0.48145466, 0.4578275, 0.40821073], 39 | 'std': [0.26862954, 0.26130258, 0.27577711] 40 | } 41 | } 42 | 43 | # Data transformation with augmentation 44 | def get_data_transform(split, rgb_mean, rbg_std, key='default'): 45 | data_transforms = { 46 | 'train': transforms.Compose([ 47 | transforms.RandomResizedCrop(size=224, scale=(0.5,1), interpolation=BICUBIC), 48 | transforms.RandomHorizontalFlip(p=0.5), 49 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0), 50 | transforms.ToTensor(), 51 | transforms.Normalize(rgb_mean, rbg_std) 52 | ]), 53 | 'val': transforms.Compose([ 54 | transforms.Resize(224, interpolation=BICUBIC), 55 | transforms.CenterCrop(224), 56 | transforms.ToTensor(), 57 | transforms.Normalize(rgb_mean, rbg_std) 58 | ]), 59 | 'test': transforms.Compose([ 60 | transforms.Resize(224, interpolation=BICUBIC), 61 | transforms.CenterCrop(224), 62 | transforms.ToTensor(), 63 | transforms.Normalize(rgb_mean, rbg_std) 64 | ]) 65 | } 66 | return data_transforms[split] 67 | 68 | # Dataset 69 | class LT_Dataset(Dataset): 70 | 71 | def __init__(self, root, txt, transform=None): 72 | self.img_path = [] 73 | self.labels = [] 74 | self.transform = transform 75 | with open(txt) as f: 76 | for line in f: 77 | self.img_path.append(os.path.join(root, line.split()[0])) 78 | self.labels.append(int(line.split()[1])) 79 | 80 | def __len__(self): 81 | return len(self.labels) 82 | 83 | def __getitem__(self, index): 84 | 85 | path = self.img_path[index] 86 | label = self.labels[index] 87 | 88 | with open(path, 'rb') as f: 89 | sample = Image.open(f).convert('RGB') 90 | 91 | if self.transform is not None: 92 | sample = self.transform(sample) 93 | 94 | return sample, label, index 95 | 96 | # Load datasets 97 | def load_data(data_root, dataset, phase, batch_size, sampler_dic=None, num_workers=4, test_open=False, shuffle=True): 98 | 99 | if phase == 'train_plain': 100 | txt_split = 'train' 101 | elif phase == 'train_val': 102 | txt_split = 'val' 103 | phase = 'train' 104 | else: 105 | txt_split = phase 106 | txt = './data/%s/%s_%s.txt'%(dataset, dataset, txt_split) 107 | # txt = './data/%s/%s_%s.txt'%(dataset, dataset, (phase if phase != 'train_plain' else 'train')) 108 | 109 | print('Loading data from %s' % (txt)) 110 | 111 | key = 'clip' 112 | rgb_mean, rgb_std = RGB_statistics[key]['mean'], RGB_statistics[key]['std'] 113 | 114 | if phase not in ['train', 'val']: 115 | transform = get_data_transform('test', rgb_mean, rgb_std, key) 116 | else: 117 | transform = get_data_transform(phase, rgb_mean, rgb_std, key) 118 | 119 | print('Use data transformation:', transform) 120 | 121 | set_ = LT_Dataset(data_root, txt, transform) 122 | print(len(set_)) 123 | if phase == 'test' and test_open: 124 | open_txt = './data/%s/%s_open.txt'%(dataset, dataset) 125 | print('Testing with opensets from %s'%(open_txt)) 126 | open_set_ = LT_Dataset('./data/%s/%s_open'%(dataset, dataset), open_txt, transform) 127 | set_ = ConcatDataset([set_, open_set_]) 128 | 129 | if sampler_dic and phase == 'train': 130 | print('Using sampler: ', sampler_dic['sampler']) 131 | # print('Sample %s samples per-class.' % sampler_dic['num_samples_cls']) 132 | print('Sampler parameters: ', sampler_dic['params']) 133 | return DataLoader(dataset=set_, batch_size=batch_size, shuffle=False, 134 | sampler=sampler_dic['sampler'](set_, **sampler_dic['params']), 135 | num_workers=num_workers) 136 | else: 137 | print('No sampler.') 138 | print('Shuffle is %s.' % (shuffle)) 139 | return DataLoader(dataset=set_, batch_size=batch_size, 140 | shuffle=shuffle, num_workers=num_workers) 141 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pprint 4 | from data import dataloader 5 | from train import model 6 | import warnings 7 | import yaml 8 | from utils import source_import, get_value 9 | 10 | ##change your data root here 11 | data_root = {'ImageNet': './datasets/ImageNet/', 12 | 'Places': './datasets/Places/'} 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--cfg', default=None, type=str) 16 | parser.add_argument('--test', default=False, action='store_true') 17 | parser.add_argument('--batch_size', type=int, default=None) 18 | parser.add_argument('--test_open', default=False, action='store_true') 19 | parser.add_argument('--output_logits', default=False) 20 | parser.add_argument('--model_dir', type=str, default=None) 21 | parser.add_argument('--save_feat', type=str, default='') 22 | 23 | # KNN testing parameters 24 | parser.add_argument('--knn', default=False, action='store_true') 25 | parser.add_argument('--feat_type', type=str, default='cl2n') 26 | parser.add_argument('--dist_type', type=str, default='l2') 27 | 28 | # Learnable tau 29 | parser.add_argument('--val_as_train', default=False, action='store_true') 30 | 31 | args = parser.parse_args() 32 | 33 | def update(config, args): 34 | # Change parameters 35 | config['training_opt']['batch_size'] = \ 36 | get_value(config['training_opt']['batch_size'], args.batch_size) 37 | 38 | # Testing with KNN 39 | if args.knn and args.test: 40 | training_opt = config['training_opt'] 41 | classifier_param = { 42 | 'feat_dim': training_opt['feature_dim'], 43 | 'num_classes': training_opt['num_classes'], 44 | 'feat_type': args.feat_type, 45 | 'dist_type': args.dist_type, 46 | 'log_dir': training_opt['log_dir']} 47 | classifier = { 48 | 'def_file': './models/KNNClassifier.py', 49 | 'params': classifier_param, 50 | 'optim_params': config['networks']['classifier']['optim_params']} 51 | config['networks']['classifier'] = classifier 52 | 53 | return config 54 | 55 | # ============================================================================ 56 | # LOAD CONFIGURATIONS 57 | with open(args.cfg) as f: 58 | config = yaml.safe_load(f) 59 | config = update(config, args) 60 | 61 | test_mode = args.test 62 | test_open = args.test_open 63 | if test_open: 64 | test_mode = True 65 | output_logits = args.output_logits 66 | training_opt = config['training_opt'] 67 | dataset = training_opt['dataset'] 68 | 69 | if not os.path.isdir(training_opt['log_dir']): 70 | os.makedirs(training_opt['log_dir']) 71 | 72 | print('Loading dataset from: %s' % data_root[dataset.rstrip('_LT')]) 73 | pprint.pprint(config) 74 | 75 | def split2phase(split): 76 | if split == 'train' and args.val_as_train: 77 | return 'train_val' 78 | else: 79 | return split 80 | 81 | if not test_mode: 82 | 83 | sampler_defs = training_opt['sampler'] 84 | if sampler_defs: 85 | if sampler_defs['type'] == 'ClassAwareSampler': 86 | sampler_dic = { 87 | 'sampler': source_import(sampler_defs['def_file']).get_sampler(), 88 | 'params': {'num_samples_cls': sampler_defs['num_samples_cls']} 89 | } 90 | elif sampler_defs['type'] in ['MixedPrioritizedSampler', 91 | 'ClassPrioritySampler']: 92 | sampler_dic = { 93 | 'sampler': source_import(sampler_defs['def_file']).get_sampler(), 94 | 'params': {k: v for k, v in sampler_defs.items() \ 95 | if k not in ['type', 'def_file']} 96 | } 97 | else: 98 | sampler_dic = None 99 | 100 | splits = ['train', 'train_plain', 'val'] 101 | if dataset not in ['ImageNet']: 102 | splits.append('test') 103 | data = {x: dataloader.load_data(data_root=data_root[dataset.rstrip('_LT')], 104 | dataset=dataset, phase=split2phase(x), 105 | batch_size=training_opt['batch_size'], 106 | sampler_dic=sampler_dic, 107 | num_workers=training_opt['num_workers']) 108 | for x in splits} 109 | 110 | training_model = model(config, data, test=False) 111 | 112 | training_model.train() 113 | 114 | else: 115 | 116 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", 117 | UserWarning) 118 | 119 | print('Under testing phase, we load training data simply to calculate \ 120 | training data number for each class.') 121 | 122 | splits = ['train', 'val', 'test'] 123 | test_split = 'test' 124 | if 'ImageNet' == training_opt['dataset']: 125 | splits = ['train', 'val'] 126 | test_split = 'val' 127 | if args.knn or True: 128 | splits.append('train_plain') 129 | 130 | data = {x: dataloader.load_data(data_root=data_root[dataset.rstrip('_LT')], 131 | dataset=dataset, phase=x, 132 | batch_size=training_opt['batch_size'], 133 | sampler_dic=None, 134 | test_open=test_open, 135 | num_workers=training_opt['num_workers'], 136 | shuffle=False) 137 | for x in splits} 138 | 139 | training_model = model(config, data, test=True) 140 | # training_model.load_model() 141 | #training_model.load_model(args.model_dir) 142 | if args.save_feat in ['train_plain', 'val', 'test']: 143 | saveit = True 144 | test_split = args.save_feat 145 | else: 146 | saveit = False 147 | 148 | training_model.eval(phase=test_split, openset=test_open, save_feat=saveit) 149 | 150 | if output_logits: 151 | training_model.output_logits(openset=test_open) 152 | 153 | print('ALL COMPLETED.') 154 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from sklearn.metrics import f1_score 5 | import torch.nn.functional as F 6 | import importlib 7 | import pdb 8 | 9 | def source_import(file_path): 10 | """This function imports python module directly from source code using importlib""" 11 | spec = importlib.util.spec_from_file_location('', file_path) 12 | module = importlib.util.module_from_spec(spec) 13 | spec.loader.exec_module(module) 14 | return module 15 | 16 | def batch_show(inp, title=None): 17 | """Imshow for Tensor.""" 18 | inp = inp.numpy().transpose((1, 2, 0)) 19 | mean = np.array([0.485, 0.456, 0.406]) 20 | std = np.array([0.229, 0.224, 0.225]) 21 | inp = std * inp + mean 22 | inp = np.clip(inp, 0, 1) 23 | plt.figure(figsize=(20,20)) 24 | plt.imshow(inp) 25 | if title is not None: 26 | plt.title(title) 27 | 28 | def print_write(print_str, log_file): 29 | print(*print_str) 30 | if log_file is None: 31 | return 32 | with open(log_file, 'a') as f: 33 | print(*print_str, file=f) 34 | 35 | def init_weights(model, weights_path, caffe=False, classifier=False): 36 | """Initialize weights""" 37 | print('Pretrained %s weights path: %s' % ('classifier' if classifier else 'feature model', 38 | weights_path)) 39 | weights = torch.load(weights_path) 40 | if not classifier: 41 | if caffe: 42 | weights = {k: weights[k] if k in weights else model.state_dict()[k] 43 | for k in model.state_dict()} 44 | else: 45 | weights = weights['state_dict_best']['feat_model'] 46 | weights = {k: weights['module.' + k] if 'module.' + k in weights else model.state_dict()[k] 47 | for k in model.state_dict()} 48 | else: 49 | weights = weights['state_dict_best']['classifier'] 50 | weights = {k: weights['module.fc.' + k] if 'module.fc.' + k in weights else model.state_dict()[k] 51 | for k in model.state_dict()} 52 | model.load_state_dict(weights) 53 | return model 54 | 55 | def shot_acc (preds, labels, train_data, many_shot_thr=100, low_shot_thr=20, acc_per_cls=False): 56 | 57 | if isinstance(train_data, np.ndarray): 58 | training_labels = np.array(train_data).astype(int) 59 | else: 60 | training_labels = np.array(train_data.dataset.labels).astype(int) 61 | 62 | if isinstance(preds, torch.Tensor): 63 | preds = preds.detach().cpu().numpy() 64 | labels = labels.detach().cpu().numpy() 65 | elif isinstance(preds, np.ndarray): 66 | pass 67 | else: 68 | raise TypeError('Type ({}) of preds not supported'.format(type(preds))) 69 | train_class_count = [] 70 | test_class_count = [] 71 | class_correct = [] 72 | for l in np.unique(labels): 73 | train_class_count.append(len(training_labels[training_labels == l])) 74 | test_class_count.append(len(labels[labels == l])) 75 | class_correct.append((preds[labels == l] == labels[labels == l]).sum()) 76 | 77 | many_shot = [] 78 | median_shot = [] 79 | low_shot = [] 80 | for i in range(len(train_class_count)): 81 | if train_class_count[i] > many_shot_thr: 82 | many_shot.append((class_correct[i] / test_class_count[i])) 83 | elif train_class_count[i] < low_shot_thr: 84 | low_shot.append((class_correct[i] / test_class_count[i])) 85 | else: 86 | median_shot.append((class_correct[i] / test_class_count[i])) 87 | 88 | if len(many_shot) == 0: 89 | many_shot.append(0) 90 | if len(median_shot) == 0: 91 | median_shot.append(0) 92 | if len(low_shot) == 0: 93 | low_shot.append(0) 94 | 95 | if acc_per_cls: 96 | class_accs = [c / cnt for c, cnt in zip(class_correct, test_class_count)] 97 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot), class_accs 98 | else: 99 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot) 100 | 101 | def weighted_shot_acc (preds, labels, ws, train_data, many_shot_thr=100, low_shot_thr=20): 102 | 103 | training_labels = np.array(train_data.dataset.labels).astype(int) 104 | 105 | if isinstance(preds, torch.Tensor): 106 | preds = preds.detach().cpu().numpy() 107 | labels = labels.detach().cpu().numpy() 108 | elif isinstance(preds, np.ndarray): 109 | pass 110 | else: 111 | raise TypeError('Type ({}) of preds not supported'.format(type(preds))) 112 | train_class_count = [] 113 | test_class_count = [] 114 | class_correct = [] 115 | for l in np.unique(labels): 116 | train_class_count.append(len(training_labels[training_labels == l])) 117 | test_class_count.append(ws[labels==l].sum()) 118 | class_correct.append(((preds[labels==l] == labels[labels==l]) * ws[labels==l]).sum()) 119 | 120 | many_shot = [] 121 | median_shot = [] 122 | low_shot = [] 123 | for i in range(len(train_class_count)): 124 | if train_class_count[i] > many_shot_thr: 125 | many_shot.append((class_correct[i] / test_class_count[i])) 126 | elif train_class_count[i] < low_shot_thr: 127 | low_shot.append((class_correct[i] / test_class_count[i])) 128 | else: 129 | median_shot.append((class_correct[i] / test_class_count[i])) 130 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot) 131 | 132 | def F_measure(preds, labels, openset=False, theta=None): 133 | 134 | if openset: 135 | # f1 score for openset evaluation 136 | true_pos = 0. 137 | false_pos = 0. 138 | false_neg = 0. 139 | 140 | for i in range(len(labels)): 141 | true_pos += 1 if preds[i] == labels[i] and labels[i] != -1 else 0 142 | false_pos += 1 if preds[i] != labels[i] and labels[i] != -1 and preds[i] != -1 else 0 143 | false_neg += 1 if preds[i] != labels[i] and labels[i] == -1 else 0 144 | 145 | precision = true_pos / (true_pos + false_pos) 146 | recall = true_pos / (true_pos + false_neg) 147 | return 2 * ((precision * recall) / (precision + recall + 1e-12)) 148 | else: 149 | # Regular f1 score 150 | return f1_score(labels.detach().cpu().numpy(), preds.detach().cpu().numpy(), average='macro') 151 | 152 | def mic_acc_cal(preds, labels): 153 | if isinstance(labels, tuple): 154 | assert len(labels) == 3 155 | targets_a, targets_b, lam = labels 156 | acc_mic_top1 = (lam * preds.eq(targets_a.data).cpu().sum().float() \ 157 | + (1 - lam) * preds.eq(targets_b.data).cpu().sum().float()) / len(preds) 158 | else: 159 | acc_mic_top1 = (preds == labels).sum().item() / len(labels) 160 | return acc_mic_top1 161 | 162 | 163 | def weighted_mic_acc_cal(preds, labels, ws): 164 | acc_mic_top1 = ws[preds == labels].sum() / ws.sum() 165 | return acc_mic_top1 166 | 167 | def class_count (data): 168 | labels = np.array(data.dataset.labels) 169 | class_data_num = [] 170 | for l in np.unique(labels): 171 | class_data_num.append(len(labels[labels == l])) 172 | return class_data_num 173 | 174 | # New Added 175 | def torch2numpy(x): 176 | if isinstance(x, torch.Tensor): 177 | return x.detach().cpu().numpy() 178 | elif isinstance(x, (list, tuple)): 179 | return tuple([torch2numpy(xi) for xi in x]) 180 | else: 181 | return x 182 | 183 | def logits2score(logits, labels): 184 | scores = F.softmax(logits, dim=1) 185 | score = scores.gather(1, labels.view(-1, 1)) 186 | score = score.squeeze().cpu().numpy() 187 | return score 188 | 189 | 190 | def logits2entropy(logits): 191 | scores = F.softmax(logits, dim=1) 192 | scores = scores.cpu().numpy() + 1e-30 193 | ent = -scores * np.log(scores) 194 | ent = np.sum(ent, 1) 195 | return ent 196 | 197 | 198 | def logits2CE(logits, labels): 199 | scores = F.softmax(logits, dim=1) 200 | score = scores.gather(1, labels.view(-1, 1)) 201 | score = score.squeeze().cpu().numpy() + 1e-30 202 | ce = -np.log(score) 203 | return ce 204 | 205 | 206 | def get_priority(ptype, logits, labels): 207 | if ptype == 'score': 208 | ws = 1 - logits2score(logits, labels) 209 | elif ptype == 'entropy': 210 | ws = logits2entropy(logits) 211 | elif ptype == 'CE': 212 | ws = logits2CE(logits, labels) 213 | 214 | return ws 215 | 216 | def get_value(oldv, newv): 217 | if newv is not None: 218 | return newv 219 | else: 220 | return oldv 221 | -------------------------------------------------------------------------------- /data/MixedPrioritizedSampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class PriorityTree(object): 7 | def __init__(self, capacity, fixed_weights=None, fixed_scale=1.0, 8 | init_weight=1.0): 9 | """ 10 | fixed_weights: weights that wont be updated by self.update() 11 | """ 12 | assert fixed_weights is None or len(fixed_weights) == capacity 13 | self._capacity = capacity 14 | self._tree_size = 2 * capacity - 1 15 | self.fixed_scale = fixed_scale 16 | self.fixed_weights = np.zeros(self._capacity) if fixed_weights is None \ 17 | else fixed_weights 18 | self.tree = np.zeros(self._tree_size) 19 | self._initialized = False 20 | self.initialize(init_weight) 21 | 22 | def initialize(self, init_weight): 23 | """Initialize the tree.""" 24 | 25 | # Rescale the fixed_weights if it is not zero 26 | if self.fixed_weights.sum() > 0 and init_weight > 0: 27 | self.fixed_weights *= self.fixed_scale * init_weight * self.capacity \ 28 | / self.fixed_weights.sum() 29 | print('FixedWeights: {}'.format(self.fixed_weights.sum())) 30 | 31 | self.update_whole(init_weight + self.fixed_weights) 32 | self._initialized = True 33 | 34 | def reset_fixed_weights(self, fixed_weights, rescale=False): 35 | """ Reset the manually designed weights and 36 | update the whole tree accordingly. 37 | 38 | @rescale: rescale the fixed_weights such that 39 | fixed_weights.sum() = self.fixed_scale * adaptive_weights.sum() 40 | """ 41 | 42 | adaptive_weights = self.get_adaptive_weights() 43 | fixed_sum = fixed_weights.sum() 44 | if rescale and fixed_sum > 0: 45 | scale = self.fixed_scale * adaptive_weights.sum() / fixed_sum 46 | self.fixed_weights = fixed_weights * scale 47 | else: 48 | self.fixed_weights = fixed_weights 49 | self.update_whole(self.fixed_weights + adaptive_weights) 50 | 51 | def update_whole(self, total_weights): 52 | """ Update the whole tree based on per-example sampling weights """ 53 | lefti = self.pointer_to_treeidx(0) 54 | righti = self.pointer_to_treeidx(self.capacity-1) 55 | self.tree[lefti:righti+1] = total_weights 56 | 57 | # Iteratively find a parent layer 58 | while lefti != 0 and righti != 0: 59 | lefti = (lefti - 1) // 2 if lefti != 0 else 0 60 | righti = (righti - 1) // 2 if righti != 0 else 0 61 | 62 | # Assign paraent weights from right to left 63 | for i in range(righti, lefti-1, -1): 64 | self.tree[i] = self.tree[2*i+1] + self.tree[2*i+2] 65 | 66 | def get_adaptive_weights(self): 67 | """ Get the instance-aware weights, that are not mannually designed""" 68 | return self.get_total_weights() - self.fixed_weights 69 | 70 | def get_total_weights(self): 71 | """ Get the per-example sampling weights 72 | return shape: [capacity] 73 | """ 74 | lefti = self.pointer_to_treeidx(0) 75 | righti = self.pointer_to_treeidx(self.capacity-1) 76 | return self.tree[lefti:righti+1] 77 | 78 | @property 79 | def size(self): 80 | return self._tree_size 81 | 82 | @property 83 | def capacity(self): 84 | return self._capacity 85 | 86 | def __len__(self): 87 | return self.capacity 88 | 89 | def pointer_to_treeidx(self, pointer): 90 | assert pointer < self.capacity 91 | return int(pointer + self.capacity - 1) 92 | 93 | def update(self, pointer, priority): 94 | assert pointer < self.capacity 95 | tree_idx = self.pointer_to_treeidx(pointer) 96 | priority += self.fixed_weights[pointer] 97 | delta = priority - self.tree[tree_idx] 98 | self.tree[tree_idx] = priority 99 | while tree_idx != 0: 100 | tree_idx = (tree_idx - 1) // 2 101 | self.tree[tree_idx] += delta 102 | 103 | def get_leaf(self, value): 104 | assert self._initialized, 'PriorityTree not initialized!!!!' 105 | assert self.total > 0, 'No priority weights setted!!' 106 | parent = 0 107 | while True: 108 | left_child = 2 * parent + 1 109 | right_child = 2 * parent + 2 110 | if left_child >= len(self.tree): 111 | tgt_leaf = parent 112 | break 113 | if value < self.tree[left_child]: 114 | parent = left_child 115 | else: 116 | value -= self.tree[left_child] 117 | parent = right_child 118 | data_idx = tgt_leaf - self.capacity + 1 119 | return data_idx, self.tree[tgt_leaf] # data idx, priority 120 | 121 | @property 122 | def total(self): 123 | assert self._initialized, 'PriorityTree not initialized!!!!' 124 | return self.tree[0] 125 | 126 | @property 127 | def max(self): 128 | return np.max(self.tree[-self.capacity:]) 129 | 130 | @property 131 | def min(self): 132 | assert self._initialized, 'PriorityTree not initialized!!!!' 133 | return np.min(self.tree[-self.capacity:]) 134 | 135 | def get_weights(self): 136 | return {'fixed_weights': self.fixed_weights, 137 | 'total_weights': self.get_total_weights()} 138 | 139 | 140 | class MixedPrioritizedSampler(Sampler): 141 | """ 142 | A sampler combining manually designed sampling strategy and prioritized 143 | sampling strategy. 144 | 145 | Manually disigned strategy contains two parts: 146 | 147 | $$ manual_weights = lam * balanced_weights + (1-lam) uniform_weights 148 | 149 | Here we use a generalized version of balanced weights as follows, 150 | when n limits to infinity, balanced_weights = real_balanced_weights 151 | 152 | $$ balanced_weights = uniform_weights ^ (1/n) 153 | 154 | Then the balanced weights are scaled such that 155 | 156 | $$ balanced_weights.sum() = balance_scale * uniform_weights.sum() 157 | 158 | Note: above weights are per-class weights 159 | 160 | Overall sampling weights are given as 161 | $$ sampling_weights = manual_weights * fixed_scale + priority_weights 162 | 163 | Arguments: 164 | @dataset: A dataset 165 | @balance_scale: The scale of balanced_weights 166 | @lam: A weight to combine balanced weights and uniform weights 167 | - None for shifting sampling 168 | - 0 for uniform sampling 169 | - 1 for balanced sampling 170 | @fixed_scale: The scale of manually designed weights 171 | @cycle: shifting strategy 172 | - 0 for linear shifting: 3 -> 2 - > 1 173 | - 1 for periodic shifting: 174 | 3 -> 2 - > 1 -> 3 -> 2 - > 1 -> 3 -> 2 - > 1 175 | - 2 for cosine-like periodic shifting: 176 | 3 -> 2 - > 1 -> 1 -> 2 - > 3 -> 3 -> 2 - > 1 177 | @nroot: 178 | - None for truly balanced weights 179 | - >= 2 for pseudo-balanced weights 180 | @rescale: whether to rebalance the manual weights and priority weights 181 | every epoch 182 | @root_decay: 183 | - 'exp': for exponential decay 184 | - 'linear': for linear decay 185 | """ 186 | def __init__(self, dataset, balance_scale=1.0, fixed_scale=1.0, 187 | lam=None, epochs=90, cycle=0, nroot=None, manual_only=False, 188 | rescale=False, root_decay=None, decay_gap=30, ptype='score', 189 | alpha=1.0): 190 | """ 191 | """ 192 | self.dataset = dataset 193 | self.balance_scale = balance_scale 194 | self.fixed_scale = fixed_scale 195 | self.epochs = epochs 196 | self.lam = lam 197 | self.cycle = cycle 198 | self.nroot = nroot 199 | self.rescale = rescale 200 | self.manual_only = manual_only 201 | self.root_decay = root_decay 202 | self.decay_gap = decay_gap 203 | self.ptype = ptype 204 | self.num_samples = len(dataset) 205 | self.alpha = alpha 206 | 207 | # If using root_decay, reset relevent parameters 208 | if self.root_decay in ['exp', 'linear', 'autoexp']: 209 | self.lam = 1 210 | self.manual_only = True 211 | self.nroot = 1 212 | if self.root_decay == 'autoexp': 213 | self.decay_gap = 1 214 | self.decay_factor = np.power(nroot, 1/(self.epochs-1)) 215 | else: 216 | assert self.root_decay is None 217 | assert self.nroot is None or self.nroot >= 2 218 | print("====> Decay GAP: {}".format(self.decay_gap)) 219 | 220 | # Take care of lambdas 221 | if self.lam is None: 222 | self.freeze = False 223 | if cycle == 0: 224 | self.lams = np.linspace(0, 1, epochs) 225 | elif cycle == 1: 226 | self.lams = np.concatenate([np.linspace(0,1,epochs//3)] * 3) 227 | elif cycle == 2: 228 | self.lams = np.concatenate([np.linspace(0,1,epochs//3), 229 | np.linspace(0,1,epochs//3)[::-1], 230 | np.linspace(0,1,epochs//3)]) 231 | else: 232 | raise NotImplementedError( 233 | 'cycle = {} not implemented'.format(cycle)) 234 | else: 235 | self.lams = [self.lam] 236 | self.freeze = True 237 | 238 | # Get num of samples per class 239 | self.cls_cnts = [] 240 | self.labels = labels = np.array(self.dataset.labels) 241 | for l in np.unique(labels): 242 | self.cls_cnts.append(np.sum(labels==l)) 243 | self.num_classes = len(self.cls_cnts) 244 | self.cnts = np.array(self.cls_cnts).astype(float) 245 | 246 | # Get per-class image indexes 247 | self.cls_idxs = [[] for _ in range(self.num_classes)] 248 | for i, label in enumerate(self.dataset.labels): 249 | self.cls_idxs[label].append(i) 250 | for ci in range(self.num_classes): 251 | self.cls_idxs[ci] = np.array(self.cls_idxs[ci]) 252 | 253 | # Build balanced weights based on class counts 254 | self.balanced_weights = self.get_balanced_weights(self.nroot) 255 | self.manual_weights = self.get_manual_weights(self.lams[0]) 256 | 257 | # Setup priority tree 258 | if self.ptype == 'score': 259 | self.init_weight = 1. 260 | elif self.ptype in ['CE', 'entropy']: 261 | self.init_weight = 6.9 262 | else: 263 | raise NotImplementedError('ptype {} not implemented'.format(self.ptype)) 264 | if self.manual_only: 265 | self.init_weight = 0. 266 | self.init_weight = np.power(self.init_weight, self.alpha) 267 | self.ptree = PriorityTree(self.num_samples, self.manual_weights, 268 | fixed_scale=self.fixed_scale, 269 | init_weight=self.init_weight) 270 | 271 | def get_manual_weights(self, lam): 272 | # Merge balanced weights and uniform weights 273 | if lam == 1: 274 | manual_weights = self.balanced_weights 275 | elif lam == 0: 276 | manual_weights = np.ones(len(self.balanced_weights)) 277 | else: 278 | manual_weights = self.balanced_weights * lam + (1-lam) 279 | return manual_weights 280 | 281 | def get_balanced_weights(self, nroot): 282 | """ Calculate normalized generalized balanced weights """ 283 | 284 | cnts = self.cnts 285 | if nroot is None: 286 | # Real balanced sampling weights 287 | cls_ws = cnts.min() / cnts 288 | elif nroot >= 1: 289 | # Generalized balanced weights 290 | cls_ws = cnts / cnts.sum() 291 | cls_ws = np.power(cls_ws, 1./nroot) * cnts.sum() 292 | cls_ws = cls_ws / cnts 293 | else: 294 | raise NotImplementedError('root:{} not implemented'.format(nroot)) 295 | 296 | # Get un-normalized weights 297 | balanced_weights = np.zeros(self.num_samples) 298 | for ci in range(self.num_classes): 299 | balanced_weights[self.cls_idxs[ci]] = cls_ws[ci] 300 | 301 | # Normalization and rescale 302 | balanced_weights *= self.num_samples / balanced_weights.sum() * \ 303 | self.balance_scale 304 | return balanced_weights 305 | 306 | def __iter__(self): 307 | for _ in range(self.num_samples): 308 | w = random.random() * self.ptree.total 309 | i, pri = self.ptree.get_leaf(w) 310 | yield i 311 | 312 | def __len__(self): 313 | return self.num_samples 314 | 315 | def reset_weights(self, epoch): 316 | if not self.freeze and self.fixed_scale > 0: 317 | if epoch >= self.epochs: 318 | e = self.epochs - 1 319 | elif epoch < 1: 320 | e = 0 321 | else: 322 | e = epoch 323 | self.manual_weights = self.get_manual_weights(self.lams[e]) 324 | self.ptree.reset_fixed_weights(self.manual_weights, self.rescale) 325 | 326 | if self.root_decay in ['exp', 'linear', 'autoexp'] and epoch % self.decay_gap == 0: 327 | if self.root_decay == 'exp': 328 | self.nroot *= 2 329 | elif self.root_decay == 'linear': 330 | self.nroot += 1 331 | elif self.root_decay == 'autoexp': 332 | # self.nroot *= self.decay_factor 333 | self.nroot = np.power(self.decay_factor, epoch) 334 | 335 | bw = self.get_balanced_weights(self.nroot) 336 | self.ptree.reset_fixed_weights(bw) 337 | 338 | def update_weights(self, inds, weights): 339 | """ Update priority weights """ 340 | if not self.manual_only: 341 | weights = np.clip(weights, 0, self.init_weight) 342 | weights = np.power(weights, self.alpha) 343 | for i, w in zip(inds, weights): 344 | self.ptree.update(i, w) 345 | 346 | def get_weights(self): 347 | return self.ptree.get_weights() 348 | 349 | 350 | def get_sampler(): 351 | return MixedPrioritizedSampler 352 | -------------------------------------------------------------------------------- /classes.py: -------------------------------------------------------------------------------- 1 | CLASSES = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 2 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 3 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 4 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 5 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 6 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 7 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 8 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 9 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 10 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 11 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 12 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 13 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 14 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 15 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 16 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 17 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 18 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 19 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 20 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 21 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 22 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 23 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 24 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 25 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 26 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 27 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 28 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 29 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 30 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 31 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 32 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 33 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 34 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 35 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 36 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 37 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 38 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 39 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 40 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 41 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 42 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 43 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 44 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 45 | "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 46 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 47 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 48 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 49 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 50 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 51 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 52 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 53 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 54 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 55 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 56 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 57 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 58 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 59 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 60 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 61 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 62 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 63 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 64 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 65 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 66 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 67 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 68 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 69 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 70 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 71 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 72 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 73 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 74 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 75 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 76 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 77 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 78 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 79 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 80 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 81 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 82 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", 83 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 84 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 85 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 86 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 87 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 88 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 89 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 90 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 91 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 92 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 93 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 94 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 95 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 96 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 97 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 98 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 99 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 100 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 101 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 102 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 103 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 104 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 105 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 106 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 107 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 108 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 109 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 110 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 111 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 112 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 113 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 114 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 115 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 116 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 117 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 118 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 119 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 120 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 121 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 122 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 123 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 124 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 125 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 126 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 127 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 128 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 129 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 130 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 131 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 132 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 133 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 134 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 135 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 136 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 137 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 138 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 139 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 140 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 141 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 142 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 143 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 144 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 145 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 146 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 147 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 148 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 149 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 150 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 151 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 152 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 153 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 154 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 155 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 156 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", 157 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 158 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 159 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", 160 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 161 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 162 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 163 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 164 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 165 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 166 | 167 | CUSTOM_TEMPLATES = { 168 | 'ImageNet': 'a photo of a {}.' 169 | } 170 | -------------------------------------------------------------------------------- /data/ClassPrioritySampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class RandomCycleIter: 7 | 8 | def __init__ (self, data, test_mode=False): 9 | self.data_list = list(data) 10 | self.length = len(self.data_list) 11 | self.i = self.length - 1 12 | self.test_mode = test_mode 13 | 14 | def __iter__ (self): 15 | return self 16 | 17 | def __next__ (self): 18 | self.i += 1 19 | 20 | if self.i == self.length: 21 | self.i = 0 22 | if not self.test_mode: 23 | random.shuffle(self.data_list) 24 | 25 | return self.data_list[self.i] 26 | 27 | 28 | class PriorityTree(object): 29 | def __init__(self, capacity, init_weights, fixed_weights=None, fixed_scale=1.0, 30 | alpha=1.0): 31 | """ 32 | fixed_weights: weights that wont be updated by self.update() 33 | """ 34 | assert fixed_weights is None or len(fixed_weights) == capacity 35 | assert len(init_weights) == capacity 36 | self.alpha = alpha 37 | self._capacity = capacity 38 | self._tree_size = 2 * capacity - 1 39 | self.fixed_scale = fixed_scale 40 | self.fixed_weights = np.zeros(self._capacity) if fixed_weights is None \ 41 | else fixed_weights 42 | self.tree = np.zeros(self._tree_size) 43 | self._initialized = False 44 | self.initialize(init_weights) 45 | 46 | def initialize(self, init_weights): 47 | """Initialize the tree.""" 48 | 49 | # Rescale the fixed_weights if it is not zero 50 | self.fixed_scale_init = self.fixed_scale 51 | if self.fixed_weights.sum() > 0 and init_weights.sum() > 0: 52 | self.fixed_scale_init *= init_weights.sum() / self.fixed_weights.sum() 53 | self.fixed_weights *= self.fixed_scale * init_weights.sum() \ 54 | / self.fixed_weights.sum() 55 | print('FixedWeights: {}'.format(self.fixed_weights.sum())) 56 | 57 | self.update_whole(init_weights + self.fixed_weights) 58 | self._initialized = True 59 | 60 | def reset_adaptive_weights(self, adaptive_weights): 61 | self.update_whole(self.fixed_weights + adaptive_weights) 62 | 63 | def reset_fixed_weights(self, fixed_weights, rescale=False): 64 | """ Reset the manually designed weights and 65 | update the whole tree accordingly. 66 | 67 | @rescale: rescale the fixed_weights such that 68 | fixed_weights.sum() = self.fixed_scale * adaptive_weights.sum() 69 | """ 70 | 71 | adaptive_weights = self.get_adaptive_weights() 72 | fixed_sum = fixed_weights.sum() 73 | if rescale and fixed_sum > 0: 74 | # Rescale fixedweight based on adaptive weights 75 | scale = self.fixed_scale * adaptive_weights.sum() / fixed_sum 76 | else: 77 | # Rescale fixedweight based on previous fixedweight 78 | scale = self.fixed_weights.sum() / fixed_sum 79 | self.fixed_weights = fixed_weights * scale 80 | self.update_whole(self.fixed_weights + adaptive_weights) 81 | 82 | def update_whole(self, total_weights): 83 | """ Update the whole tree based on per-example sampling weights """ 84 | if self.alpha != 1: 85 | total_weights = np.power(total_weights, self.alpha) 86 | lefti = self.pointer_to_treeidx(0) 87 | righti = self.pointer_to_treeidx(self.capacity-1) 88 | self.tree[lefti:righti+1] = total_weights 89 | 90 | # Iteratively find a parent layer 91 | while lefti != 0 and righti != 0: 92 | lefti = (lefti - 1) // 2 if lefti != 0 else 0 93 | righti = (righti - 1) // 2 if righti != 0 else 0 94 | 95 | # Assign paraent weights from right to left 96 | for i in range(righti, lefti-1, -1): 97 | self.tree[i] = self.tree[2*i+1] + self.tree[2*i+2] 98 | 99 | def get_adaptive_weights(self): 100 | """ Get the instance-aware weights, that are not mannually designed""" 101 | if self.alpha == 1: 102 | return self.get_total_weights() - self.fixed_weights 103 | else: 104 | return self.get_raw_total_weights() - self.fixed_weights 105 | 106 | def get_total_weights(self): 107 | """ Get the per-example sampling weights 108 | return shape: [capacity] 109 | """ 110 | lefti = self.pointer_to_treeidx(0) 111 | righti = self.pointer_to_treeidx(self.capacity-1) 112 | return self.tree[lefti:righti+1] 113 | 114 | def get_raw_total_weights(self): 115 | """ Get the per-example sampling weights 116 | return shape: [capacity] 117 | """ 118 | lefti = self.pointer_to_treeidx(0) 119 | righti = self.pointer_to_treeidx(self.capacity-1) 120 | return np.power(self.tree[lefti:righti+1], 1/self.alpha) 121 | 122 | @property 123 | def size(self): 124 | return self._tree_size 125 | 126 | @property 127 | def capacity(self): 128 | return self._capacity 129 | 130 | def __len__(self): 131 | return self.capacity 132 | 133 | def pointer_to_treeidx(self, pointer): 134 | assert pointer < self.capacity 135 | return int(pointer + self.capacity - 1) 136 | 137 | def update(self, pointer, priority): 138 | assert pointer < self.capacity 139 | tree_idx = self.pointer_to_treeidx(pointer) 140 | priority += self.fixed_weights[pointer] 141 | if self.alpha != 1: 142 | priority = np.power(priority, self.alpha) 143 | delta = priority - self.tree[tree_idx] 144 | self.tree[tree_idx] = priority 145 | while tree_idx != 0: 146 | tree_idx = (tree_idx - 1) // 2 147 | self.tree[tree_idx] += delta 148 | 149 | def update_delta(self, pointer, delta): 150 | assert pointer < self.capacity 151 | tree_idx = self.pointer_to_treeidx(pointer) 152 | ratio = 1- self.fixed_weights[pointer] / self.tree[tree_idx] 153 | # delta *= ratio 154 | if self.alpha != 1: 155 | # Update delta 156 | if self.tree[tree_idx] < 0 or \ 157 | np.power(self.tree[tree_idx], 1/self.alpha) + delta < 0: 158 | import pdb; pdb.set_trace() 159 | delta = np.power(np.power(self.tree[tree_idx], 1/self.alpha) + delta, 160 | self.alpha) \ 161 | - self.tree[tree_idx] 162 | self.tree[tree_idx] += delta 163 | while tree_idx != 0: 164 | tree_idx = (tree_idx - 1) // 2 165 | self.tree[tree_idx] += delta 166 | 167 | def get_leaf(self, value): 168 | assert self._initialized, 'PriorityTree not initialized!!!!' 169 | assert self.total > 0, 'No priority weights setted!!' 170 | parent = 0 171 | while True: 172 | left_child = 2 * parent + 1 173 | right_child = 2 * parent + 2 174 | if left_child >= len(self.tree): 175 | tgt_leaf = parent 176 | break 177 | if value < self.tree[left_child]: 178 | parent = left_child 179 | else: 180 | value -= self.tree[left_child] 181 | parent = right_child 182 | data_idx = tgt_leaf - self.capacity + 1 183 | return data_idx, self.tree[tgt_leaf] # data idx, priority 184 | 185 | @property 186 | def total(self): 187 | assert self._initialized, 'PriorityTree not initialized!!!!' 188 | return self.tree[0] 189 | 190 | @property 191 | def max(self): 192 | return np.max(self.tree[-self.capacity:]) 193 | 194 | @property 195 | def min(self): 196 | assert self._initialized, 'PriorityTree not initialized!!!!' 197 | return np.min(self.tree[-self.capacity:]) 198 | 199 | def get_weights(self): 200 | wdict = {'fixed_weights': self.fixed_weights, 201 | 'total_weights': self.get_total_weights()} 202 | if self.alpha != 1: 203 | wdict.update({'raw_total_weights': self.get_raw_total_weights(), 204 | 'alpha': self.alpha}) 205 | 206 | return wdict 207 | 208 | class ClassPrioritySampler(Sampler): 209 | """ 210 | A sampler combining manually designed sampling strategy and prioritized 211 | sampling strategy. 212 | 213 | Manually disigned strategy contains two parts: 214 | 215 | $$ manual_weights = lam * balanced_weights + (1-lam) uniform_weights 216 | 217 | Here we use a generalized version of balanced weights as follows, 218 | when n limits to infinity, balanced_weights = real_balanced_weights 219 | 220 | $$ balanced_weights = uniform_weights ^ (1/n) 221 | 222 | Then the balanced weights are scaled such that 223 | 224 | $$ balanced_weights.sum() = balance_scale * uniform_weights.sum() 225 | 226 | Note: above weights are per-class weights 227 | 228 | Overall sampling weights are given as 229 | $$ sampling_weights = manual_weights * fixed_scale + priority_weights 230 | 231 | Arguments: 232 | @dataset: A dataset 233 | @balance_scale: The scale of balanced_weights 234 | @lam: A weight to combine balanced weights and uniform weights 235 | - None for shifting sampling 236 | - 0 for uniform sampling 237 | - 1 for balanced sampling 238 | @fixed_scale: The scale of manually designed weights 239 | - fixed_scale < 0 means, the manually designed distribution will 240 | be used as the backend distribution of priorities. 241 | @cycle: shifting strategy 242 | - 0 for linear shifting: 3 -> 2 - > 1 243 | - 1 for periodic shifting: 244 | 3 -> 2 - > 1 -> 3 -> 2 - > 1 -> 3 -> 2 - > 1 245 | - 2 for cosine-like periodic shifting: 246 | 3 -> 2 - > 1 -> 1 -> 2 - > 3 -> 3 -> 2 - > 1 247 | @nroot: 248 | - None for truly balanced weights 249 | - >= 2 for pseudo-balanced weights 250 | @rescale: whether to rebalance the manual weights and priority weights 251 | every epoch 252 | @root_decay: 253 | - 'exp': for exponential decay 254 | - 'linear': for linear decay 255 | """ 256 | def __init__(self, dataset, balance_scale=1.0, fixed_scale=1.0, 257 | lam=None, epochs=90, cycle=0, nroot=None, manual_only=False, 258 | rescale=False, root_decay=None, decay_gap=30, ptype='score', 259 | pri_mode='train', momentum=0., alpha=1.0): 260 | """ 261 | """ 262 | self.dataset = dataset 263 | self.balance_scale = balance_scale 264 | self.fixed_scale = fixed_scale 265 | self.epochs = epochs 266 | self.lam = lam 267 | self.cycle = cycle 268 | self.nroot = nroot 269 | self.rescale = rescale 270 | self.manual_only = manual_only 271 | self.root_decay = root_decay 272 | self.decay_gap = decay_gap 273 | self.ptype = ptype 274 | self.pri_mode = pri_mode 275 | self.num_samples = len(dataset) 276 | self.manual_as_backend = False 277 | self.momentum = momentum 278 | self.alpha = alpha 279 | 280 | assert 0. <= self.momentum <= 1.0 281 | assert 0. <= self.alpha 282 | 283 | # Change the backend distribution of priority if needed 284 | if self.fixed_scale < 0: 285 | self.fixed_scale = 0 286 | self.manual_as_backend = True 287 | 288 | # If using root_decay, reset relevent parameters 289 | if self.root_decay in ['exp', 'linear', 'autoexp']: 290 | self.lam = 1 291 | self.manual_only = True 292 | self.nroot = 1 293 | if self.root_decay == 'autoexp': 294 | self.decay_gap = 1 295 | self.decay_factor = np.power(nroot, 1/(self.epochs-1)) 296 | else: 297 | assert self.root_decay is None 298 | assert self.nroot is None or self.nroot > 1 299 | print("====> Decay GAP: {}".format(self.decay_gap)) 300 | 301 | # Take care of lambdas 302 | self.freeze = True 303 | if self.lam is None: 304 | self.freeze = False 305 | if cycle == 0: 306 | self.lams = np.linspace(0, 1, epochs) 307 | elif cycle == 1: 308 | self.lams = np.concatenate([np.linspace(0,1,epochs//3)] * 3) 309 | elif cycle == 2: 310 | self.lams = np.concatenate([np.linspace(0,1,epochs//3), 311 | np.linspace(0,1,epochs//3)[::-1], 312 | np.linspace(0,1,epochs//3)]) 313 | else: 314 | raise NotImplementedError( 315 | 'cycle = {} not implemented'.format(cycle)) 316 | else: 317 | self.lams = [self.lam] 318 | 319 | # Get num of samples per class 320 | self.cls_cnts = [] 321 | self.labels = labels = np.array(self.dataset.labels) 322 | for l in np.unique(labels): 323 | self.cls_cnts.append(np.sum(labels==l)) 324 | self.num_classes = len(self.cls_cnts) 325 | self.cnts = np.array(self.cls_cnts).astype(float) 326 | 327 | # Get per-class image indexes 328 | self.cls_idxs = [[] for _ in range(self.num_classes)] 329 | for i, label in enumerate(self.dataset.labels): 330 | self.cls_idxs[label].append(i) 331 | self.data_iter_list = [RandomCycleIter(x) for x in self.cls_idxs] 332 | for ci in range(self.num_classes): 333 | self.cls_idxs[ci] = np.array(self.cls_idxs[ci]) 334 | 335 | # Build balanced weights based on class counts 336 | self.balanced_weights = self.get_balanced_weights(self.nroot) 337 | self.uniform_weights = self.get_uniform_weights() 338 | self.manual_weights = self.get_manual_weights(self.lams[0]) 339 | 340 | # back_weights = self.get_balanced_weights(1.5) 341 | back_weights = self.uniform_weights 342 | 343 | # Calculate priority ratios that reshape priority into target distribution 344 | self.per_cls_ratios = self.get_cls_ratios( 345 | self.manual_weights if self.manual_as_backend else back_weights) 346 | self.per_example_ratios = self.broadcast(self.per_cls_ratios) 347 | 348 | # Setup priority tree 349 | if self.ptype == 'score': 350 | self.init_weight = 1. 351 | elif self.ptype in ['CE', 'entropy']: 352 | self.init_weight = 6.9 353 | else: 354 | raise NotImplementedError('ptype {} not implemented'.format(self.ptype)) 355 | if self.manual_only: 356 | self.init_weight = 0. 357 | self.per_example_uni_weights = np.ones(self.num_samples) * self.init_weight 358 | self.per_example_velocities = np.zeros(self.num_samples) 359 | # init_priorities = np.power(self.init_weight, self.alpha) \ 360 | # * self.uniform_weights * self.per_cls_ratios 361 | init_priorities = self.init_weight * self.uniform_weights * self.per_cls_ratios 362 | self.ptree = PriorityTree(self.num_classes, init_priorities, 363 | self.manual_weights.copy(), fixed_scale=self.fixed_scale, 364 | alpha=self.alpha) 365 | 366 | def get_cls_ratios(self, tgt_weights): 367 | if tgt_weights is self.uniform_weights: 368 | return np.ones_like(self.uniform_weights) 369 | per_cls_ratios = tgt_weights / self.uniform_weights 370 | per_cls_ratios *= self.uniform_weights.sum() / tgt_weights.sum() 371 | return per_cls_ratios 372 | 373 | def get_cls_weights(self): 374 | ratioed_ws = self.per_example_uni_weights * self.per_example_ratios 375 | return self.debroadcast_sum(ratioed_ws) 376 | 377 | def broadcast(self, per_cls_info): 378 | per_exmaple_info = np.zeros(self.num_samples) 379 | # Braodcast per-cls info to each example 380 | for ci in range(self.num_classes): 381 | per_exmaple_info[self.cls_idxs[ci]] = per_cls_info[ci] 382 | return per_exmaple_info 383 | 384 | def debroadcast_sum(self, per_example_info): 385 | per_cls_info = np.zeros(self.num_classes) 386 | # DeBraodcast per-example info to each cls by summation 387 | for ci in range(self.num_classes): 388 | per_cls_info[ci] = per_example_info[self.cls_idxs[ci]].sum() 389 | return per_cls_info 390 | 391 | def get_manual_weights(self, lam): 392 | # Merge balanced weights and uniform weights 393 | if lam == 1: 394 | manual_weights = self.balanced_weights.copy() 395 | elif lam == 0: 396 | manual_weights = self.uniform_weights.copy() 397 | else: 398 | manual_weights = self.balanced_weights * lam + (1-lam) * self.uniform_weights 399 | return manual_weights 400 | 401 | def get_uniform_weights(self): 402 | return self.cnts.copy() 403 | 404 | def get_balanced_weights(self, nroot): 405 | """ Calculate normalized generalized balanced weights """ 406 | 407 | cnts = self.cnts 408 | if nroot is None: 409 | # Real balanced sampling weights, each class has the same weights 410 | # Un-normalized !!! 411 | cls_ws = np.ones(len(cnts)) 412 | elif nroot >= 1: 413 | # Generalized balanced weights 414 | # Un-normalized !!! 415 | cls_ws = cnts / cnts.sum() 416 | cls_ws = np.power(cls_ws, 1./nroot) * cnts.sum() 417 | cls_ws = cls_ws 418 | else: 419 | raise NotImplementedError('root:{} not implemented'.format(nroot)) 420 | 421 | # Get un-normalized weights 422 | balanced_weights = cls_ws 423 | 424 | # Normalization and rescale 425 | balanced_weights *= self.num_samples / balanced_weights.sum() * \ 426 | self.balance_scale 427 | return balanced_weights 428 | 429 | def __iter__(self): 430 | for _ in range(self.num_samples): 431 | w = random.random() * self.ptree.total 432 | ci, pri = self.ptree.get_leaf(w) 433 | yield next(self.data_iter_list[ci]) 434 | 435 | def __len__(self): 436 | return self.num_samples 437 | 438 | def reset_weights(self, epoch): 439 | # If it is linear shifting 440 | if not self.freeze: 441 | e = np.clip(epoch, 0, self.epochs-1) 442 | self.manual_weights = self.get_manual_weights(self.lams[e]) 443 | # make sure 'self.fixed_scale > 0' and 'self.manual_as_backend = True' are 444 | # mutually exclusive 445 | if self.fixed_scale > 0: 446 | self.ptree.reset_fixed_weights(self.manual_weights, self.rescale) 447 | if self.manual_as_backend: 448 | self.update_backend_distribution(self.manual_weights) 449 | 450 | # If it is root decay 451 | if self.root_decay in ['exp', 'linear', 'autoexp'] and epoch % self.decay_gap == 0: 452 | if self.root_decay == 'exp': 453 | self.nroot *= 2 454 | elif self.root_decay == 'linear': 455 | self.nroot += 1 456 | elif self.root_decay == 'autoexp': 457 | # self.nroot *= self.decay_factor 458 | self.nroot = np.power(self.decay_factor, epoch) 459 | 460 | bw = self.get_balanced_weights(self.nroot) 461 | if self.manual_as_backend: 462 | self.update_backend_distribution(bw) 463 | else: 464 | self.ptree.reset_fixed_weights(bw) 465 | 466 | def update_backend_distribution(self, tgt_weights): 467 | # Recalculate the cls ratios based on the given target distribution 468 | self.per_cls_ratios = self.get_cls_ratios(tgt_weights) 469 | self.per_example_ratios = self.broadcast(self.per_cls_ratios) 470 | 471 | # Recalculate the new per-class weights based on the new ratios 472 | # new_backend_weights = self.init_weight * self.uniform_weights * self.per_cls_ratios 473 | new_cls_weights = self.get_cls_weights() 474 | self.ptree.reset_adaptive_weights(new_cls_weights) 475 | 476 | def update_weights(self, inds, weights, labels): 477 | """ Update priority weights """ 478 | if not self.manual_only and self.pri_mode == 'train': 479 | weights = np.clip(weights, 0, self.init_weight) 480 | 481 | # Iterate over all classes in the batch 482 | for l in np.unique(labels): 483 | # Calculate per-class delta weights 484 | example_inds = inds[labels==l] 485 | last_weights = self.per_example_uni_weights[example_inds] 486 | # delta = np.power(weights[labels==l], self.alpha) - \ 487 | # np.power(last_weights, self.alpha) 488 | delta = weights[labels==l] - last_weights 489 | delta = self.momentum * self.per_example_velocities[example_inds] + \ 490 | (1-self.momentum) * delta 491 | 492 | # Update velocities 493 | self.per_example_velocities[example_inds] = delta 494 | # Update per-example weights 495 | # self.per_example_uni_weights[example_inds] = weights[labels==l] 496 | self.per_example_uni_weights[example_inds] += delta 497 | 498 | # Sacle the delta 499 | # (ie, the per-example weights both before and after update) 500 | delta *= self.per_example_ratios[example_inds] 501 | 502 | # Update tree 503 | if self.alpha == 1: 504 | self.ptree.update_delta(l, delta.sum()) 505 | else: 506 | self.ptree.update(l, self.per_example_uni_weights[self.cls_idxs[l]].sum()) 507 | 508 | 509 | def reset_priority(self, weights, labels): 510 | if self.pri_mode == 'valid': 511 | assert len(np.unique(labels)) == self.num_classes 512 | weights = np.clip(weights, 0, self.init_weight) 513 | cls_weights = np.zeros(self.num_classes) 514 | for c in np.unique(labels): 515 | cls_weights[c] = weights[labels==c].mean() 516 | cls_weights *= self.cnts 517 | cls_weights *= self.per_cls_ratios 518 | self.ptree.reset_adaptive_weights(cls_weights) 519 | 520 | def get_weights(self): 521 | return self.ptree.get_weights() 522 | 523 | 524 | def get_sampler(): 525 | return ClassPrioritySampler 526 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import os 3 | import copy 4 | import pickle 5 | from re import template 6 | from numpy.core.fromnumeric import cumprod 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | from utils import * 13 | from logger import Logger 14 | import time 15 | import numpy as np 16 | import warnings 17 | import pdb 18 | # import clip 19 | from clip import clip 20 | from classes import CLASSES, CUSTOM_TEMPLATES 21 | 22 | def load_clip_to_cpu(visual_backbone): 23 | backbone_name = visual_backbone 24 | url = clip._MODELS[backbone_name] 25 | model_path = clip._download(url, os.path.expanduser("~/.cache/clip")) 26 | 27 | try: 28 | # loading JIT archive 29 | model = torch.jit.load(model_path, map_location="cpu").eval() 30 | state_dict = None 31 | 32 | except RuntimeError: 33 | state_dict = torch.load(model_path, map_location="cpu") 34 | 35 | model = clip.build_model(state_dict or model.state_dict()) 36 | 37 | return model 38 | 39 | class TextEncoder(nn.Module): 40 | def __init__(self, clip_model): 41 | super().__init__() 42 | self.transformer = clip_model.transformer 43 | self.positional_embedding = clip_model.positional_embedding 44 | self.ln_final = clip_model.ln_final 45 | self.text_projection = clip_model.text_projection 46 | self.dtype = clip_model.dtype 47 | self.token_embedding = clip_model.token_embedding 48 | 49 | def forward(self, text): 50 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 51 | 52 | x = x + self.positional_embedding.type(self.dtype) 53 | x = x.permute(1, 0, 2) # NLD -> LND 54 | x = self.transformer(x) 55 | x = x.permute(1, 0, 2) # LND -> NLD 56 | x = self.ln_final(x).type(self.dtype) 57 | 58 | # x.shape = [batch_size, n_ctx, transformer.width] 59 | # take features from the eot embedding (eot_token is the highest number in each sequence) 60 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 61 | 62 | return x 63 | 64 | class model (): 65 | 66 | def __init__(self, config, data, test=False): 67 | 68 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 69 | self.config = config 70 | self.training_opt = self.config['training_opt'] 71 | self.model_opt = self.config['model'] 72 | self.data = data 73 | self.test_mode = test 74 | self.num_gpus = torch.cuda.device_count() 75 | self.do_shuffle = config['shuffle'] if 'shuffle' in config else False 76 | self.clip_model = load_clip_to_cpu(self.model_opt['clip']['params']['visual_backbone']) 77 | 78 | # Setup logger 79 | self.logger = Logger(self.training_opt['log_dir']) 80 | 81 | # Initialize model 82 | self.init_models() 83 | 84 | # Under training mode, initialize training steps, optimizers, schedulers, criterions, and centroids 85 | if not self.test_mode: 86 | 87 | print('Using steps for training.') 88 | self.training_data_num = len(self.data['train'].dataset) 89 | self.epoch_steps = int(self.training_data_num \ 90 | / self.training_opt['batch_size']) 91 | 92 | # Initialize model optimizer and scheduler 93 | print('Initializing model optimizer.') 94 | self.scheduler_params = self.training_opt['scheduler_params'] 95 | self.model_optimizer, \ 96 | self.model_optimizer_scheduler = self.init_optimizers(self.model_optim_params_list) 97 | self.init_criterions() 98 | 99 | # Set up log file 100 | self.log_file = os.path.join(self.training_opt['log_dir'], 'log.txt') 101 | if os.path.isfile(self.log_file): 102 | os.remove(self.log_file) 103 | self.logger.log_cfg(self.config) 104 | else: 105 | self.log_file = None 106 | 107 | def init_models(self, optimizer=True): 108 | self.model_optim_params_list = [] 109 | 110 | print("Using", torch.cuda.device_count(), "GPUs.") 111 | 112 | self.visual_model = torch.nn.DataParallel(self.clip_model.visual).cuda() 113 | text_model = TextEncoder(self.clip_model) 114 | self.text_model = torch.nn.DataParallel(text_model).cuda() 115 | 116 | feat_dim = self.model_opt['adapter']['params']['feat_dim'] 117 | # self.load_model(self.config['model_dir']) 118 | self.adapter = torch.nn.DataParallel(nn.Linear(feat_dim, feat_dim, bias=False)).cuda() 119 | 120 | if self.training_opt['phaseA'] is not True: 121 | self.load_model(self.config['model_dir']) 122 | for param_name, param in self.visual_model.named_parameters(): 123 | param.requires_grad = False 124 | 125 | for param_name, param in self.text_model.named_parameters(): 126 | param.requires_grad = False 127 | 128 | optim_params_adapter = self.model_opt['adapter']['optim_params'] 129 | self.model_optim_params_list.append({'params': self.adapter.parameters(), 130 | 'lr': optim_params_adapter['lr'], 131 | 'momentum': optim_params_adapter['momentum'], 132 | 'weight_decay': optim_params_adapter['weight_decay']}) 133 | 134 | 135 | optim_params_clip = self.model_opt['clip']['optim_params'] 136 | self.model_optim_params_list.append({'params': self.visual_model.parameters(), 137 | 'lr': optim_params_clip['lr'], 138 | 'momentum': optim_params_clip['momentum'], 139 | 'weight_decay': optim_params_clip['weight_decay']}) 140 | 141 | self.model_optim_params_list.append({'params': self.text_model.parameters(), 142 | 'lr': optim_params_clip['lr'], 143 | 'momentum': optim_params_clip['momentum'], 144 | 'weight_decay': optim_params_clip['weight_decay']}) 145 | 146 | 147 | 148 | def init_criterions(self): 149 | criterion_defs = self.config['criterions'] 150 | self.criterions = {} 151 | self.criterion_weights = {} 152 | 153 | for key, val in criterion_defs.items(): 154 | def_file = val['def_file'] 155 | loss_args = list(val['loss_params'].values()) 156 | 157 | self.criterions[key] = source_import(def_file).create_loss(*loss_args).cuda() 158 | self.criterion_weights[key] = val['weight'] 159 | 160 | if val['optim_params']: 161 | print('Initializing criterion optimizer.') 162 | optim_params = val['optim_params'] 163 | optim_params = [{'params': self.criterions[key].parameters(), 164 | 'lr': optim_params['lr'], 165 | 'momentum': optim_params['momentum'], 166 | 'weight_decay': optim_params['weight_decay']}] 167 | # Initialize criterion optimizer and scheduler 168 | self.criterion_optimizer, \ 169 | self.criterion_optimizer_scheduler = self.init_optimizers(optim_params) 170 | else: 171 | self.criterion_optimizer = None 172 | 173 | def init_optimizers(self, optim_params): 174 | optimizer = optim.SGD(optim_params) 175 | if self.config['coslr']: 176 | print("===> Using coslr eta_min={}".format(self.config['endlr'])) 177 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 178 | optimizer, self.training_opt['num_epochs'], eta_min=self.config['endlr']) 179 | else: 180 | scheduler = optim.lr_scheduler.StepLR(optimizer, 181 | step_size=self.scheduler_params['step_size'], 182 | gamma=self.scheduler_params['gamma']) 183 | return optimizer, scheduler 184 | 185 | def batch_forward(self, inputs, phase='train'): 186 | ''' 187 | This is a general single batch running function. 188 | ''' 189 | classnames = CLASSES 190 | templates = CUSTOM_TEMPLATES['ImageNet'] 191 | 192 | #with torch.no_grad(): 193 | texts = torch.cat([clip.tokenize(templates.format(c)) for c in classnames]) 194 | texts = texts.cuda() 195 | zeroshot_weights = self.text_model(texts).float() 196 | zeroshot_weights = zeroshot_weights / zeroshot_weights.norm(dim=-1, keepdim=True) 197 | 198 | image_features = self.visual_model(inputs).float() 199 | 200 | x = image_features 201 | 202 | if self.training_opt['phaseA'] is not True: 203 | x = self.adapter(image_features) 204 | ratio = 0.2 205 | x = ratio * x + (1-ratio) * image_features 206 | 207 | x = x/x.norm(dim=-1, keepdim=True) 208 | logits = 100. * x @ zeroshot_weights.t() 209 | 210 | self.logits = logits 211 | 212 | 213 | def batch_backward(self): 214 | # Zero out optimizer gradients 215 | self.model_optimizer.zero_grad() 216 | if self.criterion_optimizer: 217 | self.criterion_optimizer.zero_grad() 218 | # Back-propagation from loss outputs 219 | self.loss.backward() 220 | # Step optimizers 221 | self.model_optimizer.step() 222 | if self.criterion_optimizer: 223 | self.criterion_optimizer.step() 224 | 225 | def batch_loss(self, labels): 226 | self.loss = 0 227 | 228 | # First, apply performance loss 229 | if 'PerformanceLoss' in self.criterions.keys(): 230 | self.loss_perf = self.criterions['PerformanceLoss'](self.logits, labels) 231 | self.loss_perf *= self.criterion_weights['PerformanceLoss'] 232 | self.loss += self.loss_perf 233 | 234 | # Apply loss on features if set up 235 | if 'FeatureLoss' in self.criterions.keys(): 236 | self.loss_feat = self.criterions['FeatureLoss'](self.features, labels) 237 | self.loss_feat = self.loss_feat * self.criterion_weights['FeatureLoss'] 238 | # Add feature loss to total loss 239 | self.loss += self.loss_feat 240 | 241 | def shuffle_batch(self, x, y): 242 | index = torch.randperm(x.size(0)) 243 | x = x[index] 244 | y = y[index] 245 | return x, y 246 | 247 | def train(self): 248 | # When training the network 249 | print_str = ['Phase: train'] 250 | print_write(print_str, self.log_file) 251 | time.sleep(0.25) 252 | 253 | print_write(['Do shuffle??? --- ', self.do_shuffle], self.log_file) 254 | 255 | # Initialize best model 256 | best_model_weights = {} 257 | best_model_weights['visual_model'] = copy.deepcopy(self.visual_model.state_dict()) 258 | best_model_weights['text_model'] = copy.deepcopy(self.text_model.state_dict()) 259 | if self.training_opt['phaseA'] is not True: 260 | best_model_weights['classifier'] = copy.deepcopy(self.adapter.state_dict()) 261 | best_acc = 0.0 262 | best_epoch = 0 263 | # best_centroids = self.centroids 264 | 265 | end_epoch = self.training_opt['num_epochs'] 266 | 267 | # Loop over epochs 268 | for epoch in range(1, end_epoch + 1): 269 | 270 | torch.cuda.empty_cache() 271 | 272 | # Set model modes and set scheduler 273 | # In training, step optimizer scheduler and set model to train() 274 | self.model_optimizer_scheduler.step() 275 | if self.criterion_optimizer: 276 | self.criterion_optimizer_scheduler.step() 277 | 278 | # Iterate over dataset 279 | total_preds = [] 280 | total_labels = [] 281 | 282 | for step, (inputs, labels, indexes) in enumerate(self.data['train']): 283 | # Break when step equal to epoch step 284 | if step == self.epoch_steps: 285 | break 286 | if self.do_shuffle: 287 | inputs, labels = self.shuffle_batch(inputs, labels) 288 | inputs, labels = inputs.cuda(), labels.cuda() 289 | 290 | # If on training phase, enable gradients 291 | with torch.set_grad_enabled(True): 292 | 293 | # If training, forward with loss, and no top 5 accuracy calculation 294 | self.batch_forward(inputs, 295 | phase='train') 296 | self.batch_loss(labels) 297 | self.batch_backward() 298 | 299 | # Tracking predictions 300 | _, preds = torch.max(self.logits, 1) 301 | total_preds.append(torch2numpy(preds)) 302 | total_labels.append(torch2numpy(labels)) 303 | 304 | # Output minibatch training results 305 | if step % self.training_opt['display_step'] == 0: 306 | 307 | minibatch_loss_feat = self.loss_feat.item() \ 308 | if 'FeatureLoss' in self.criterions.keys() else None 309 | minibatch_loss_perf = self.loss_perf.item() \ 310 | if 'PerformanceLoss' in self.criterions else None 311 | minibatch_loss_total = self.loss.item() 312 | minibatch_acc = mic_acc_cal(preds, labels) 313 | 314 | print_str = ['Epoch: [%d/%d]' 315 | % (epoch, self.training_opt['num_epochs']), 316 | 'Step: %5d' 317 | % (step), 318 | 'Minibatch_loss_feature: %.3f' 319 | % (minibatch_loss_feat) if minibatch_loss_feat else '', 320 | 'Minibatch_loss_performance: %.3f' 321 | % (minibatch_loss_perf) if minibatch_loss_perf else '', 322 | 'Minibatch_accuracy_micro: %.3f' 323 | % (minibatch_acc)] 324 | print_write(print_str, self.log_file) 325 | 326 | loss_info = { 327 | 'Epoch': epoch, 328 | 'Step': step, 329 | 'Total': minibatch_loss_total, 330 | 'CE': minibatch_loss_perf, 331 | 'feat': minibatch_loss_feat 332 | } 333 | 334 | self.logger.log_loss(loss_info) 335 | 336 | # Update priority weights if using PrioritizedSampler 337 | # if self.training_opt['sampler'] and \ 338 | # self.training_opt['sampler']['type'] == 'PrioritizedSampler': 339 | if hasattr(self.data['train'].sampler, 'update_weights'): 340 | if hasattr(self.data['train'].sampler, 'ptype'): 341 | ptype = self.data['train'].sampler.ptype 342 | else: 343 | ptype = 'score' 344 | ws = get_priority(ptype, self.logits.detach(), labels) 345 | # ws = logits2score(self.logits.detach(), labels) 346 | inlist = [indexes.cpu().numpy(), ws] 347 | if self.training_opt['sampler']['type'] == 'ClassPrioritySampler': 348 | inlist.append(labels.cpu().numpy()) 349 | self.data['train'].sampler.update_weights(*inlist) 350 | # self.data['train'].sampler.update_weights(indexes.cpu().numpy(), ws) 351 | 352 | if hasattr(self.data['train'].sampler, 'get_weights'): 353 | self.logger.log_ws(epoch, self.data['train'].sampler.get_weights()) 354 | if hasattr(self.data['train'].sampler, 'reset_weights'): 355 | self.data['train'].sampler.reset_weights(epoch) 356 | 357 | # After every epoch, validation 358 | rsls = {'epoch': epoch} 359 | rsls_train = self.eval_with_preds(total_preds, total_labels) 360 | rsls_eval = self.eval(phase='val') 361 | rsls.update(rsls_train) 362 | rsls.update(rsls_eval) 363 | 364 | # Reset class weights for sampling if pri_mode is valid 365 | if hasattr(self.data['train'].sampler, 'reset_priority'): 366 | ws = get_priority(self.data['train'].sampler.ptype, 367 | self.total_logits.detach(), 368 | self.total_labels) 369 | self.data['train'].sampler.reset_priority(ws, self.total_labels.cpu().numpy()) 370 | 371 | # Log results 372 | self.logger.log_acc(rsls) 373 | 374 | # Under validation, the best model need to be updated 375 | if self.eval_acc_mic_top1 > best_acc: 376 | best_epoch = epoch 377 | best_acc = self.eval_acc_mic_top1 378 | #best_centroids = self.centroids 379 | best_model_weights['visual_model'] = copy.deepcopy(self.visual_model.state_dict()) 380 | best_model_weights['text_model'] = copy.deepcopy(self.text_model.state_dict()) 381 | if self.training_opt['phaseA'] is not True: 382 | best_model_weights['classifier'] = copy.deepcopy(self.adapter.state_dict()) 383 | 384 | print('===> Saving checkpoint') 385 | self.save_latest(epoch) 386 | 387 | print() 388 | print('Training Complete.') 389 | 390 | print_str = ['Best validation accuracy is %.3f at epoch %d' % (best_acc, best_epoch)] 391 | print_write(print_str, self.log_file) 392 | # Save the best model and best centroids if calculated 393 | self.save_model(epoch, best_epoch, best_model_weights, best_acc) 394 | 395 | # Test on the test set 396 | # self.reset_model(best_model_weights) 397 | self.eval('test' if 'test' in self.data else 'val') 398 | print('Done') 399 | 400 | 401 | def eval_with_preds(self, preds, labels): 402 | # Count the number of examples 403 | n_total = sum([len(p) for p in preds]) 404 | 405 | # Split the examples into normal and mixup 406 | normal_preds, normal_labels = [], [] 407 | mixup_preds, mixup_labels1, mixup_labels2, mixup_ws = [], [], [], [] 408 | for p, l in zip(preds, labels): 409 | if isinstance(l, tuple): 410 | mixup_preds.append(p) 411 | mixup_labels1.append(l[0]) 412 | mixup_labels2.append(l[1]) 413 | mixup_ws.append(l[2] * np.ones_like(l[0])) 414 | else: 415 | normal_preds.append(p) 416 | normal_labels.append(l) 417 | 418 | # Calculate normal prediction accuracy 419 | rsl = {'train_all':0., 'train_many':0., 'train_median':0., 'train_low': 0.} 420 | if len(normal_preds) > 0: 421 | normal_preds, normal_labels = list(map(np.concatenate, [normal_preds, normal_labels])) 422 | n_top1 = mic_acc_cal(normal_preds, normal_labels) 423 | n_top1_many, \ 424 | n_top1_median, \ 425 | n_top1_low, = shot_acc(normal_preds, normal_labels, self.data['train']) 426 | rsl['train_all'] += len(normal_preds) / n_total * n_top1 427 | rsl['train_many'] += len(normal_preds) / n_total * n_top1_many 428 | rsl['train_median'] += len(normal_preds) / n_total * n_top1_median 429 | rsl['train_low'] += len(normal_preds) / n_total * n_top1_low 430 | 431 | # Calculate mixup prediction accuracy 432 | if len(mixup_preds) > 0: 433 | mixup_preds, mixup_labels, mixup_ws = \ 434 | list(map(np.concatenate, [mixup_preds*2, mixup_labels1+mixup_labels2, mixup_ws])) 435 | mixup_ws = np.concatenate([mixup_ws, 1-mixup_ws]) 436 | n_top1 = weighted_mic_acc_cal(mixup_preds, mixup_labels, mixup_ws) 437 | n_top1_many, \ 438 | n_top1_median, \ 439 | n_top1_low, = weighted_shot_acc(mixup_preds, mixup_labels, mixup_ws, self.data['train']) 440 | rsl['train_all'] += len(mixup_preds) / 2 / n_total * n_top1 441 | rsl['train_many'] += len(mixup_preds) / 2 / n_total * n_top1_many 442 | rsl['train_median'] += len(mixup_preds) / 2 / n_total * n_top1_median 443 | rsl['train_low'] += len(mixup_preds) / 2 / n_total * n_top1_low 444 | 445 | # Top-1 accuracy and additional string 446 | print_str = ['\n Training acc Top1: %.3f \n' % (rsl['train_all']), 447 | 'Many_top1: %.3f' % (rsl['train_many']), 448 | 'Median_top1: %.3f' % (rsl['train_median']), 449 | 'Low_top1: %.3f' % (rsl['train_low']), 450 | '\n'] 451 | print_write(print_str, self.log_file) 452 | 453 | return rsl 454 | 455 | def eval(self, phase='val', openset=False, save_feat=False): 456 | 457 | print_str = ['Phase: %s' % (phase)] 458 | print_write(print_str, self.log_file) 459 | time.sleep(0.25) 460 | 461 | if openset: 462 | print('Under openset test mode. Open threshold is %.1f' 463 | % self.training_opt['open_threshold']) 464 | 465 | torch.cuda.empty_cache() 466 | 467 | self.total_logits = torch.empty((0, self.training_opt['num_classes'])).cuda() 468 | self.total_labels = torch.empty(0, dtype=torch.long).cuda() 469 | self.total_paths = np.empty(0) 470 | 471 | get_feat_only = save_feat 472 | feats_all, labels_all, idxs_all, logits_all = [], [], [], [] 473 | featmaps_all = [] 474 | # Iterate over dataset 475 | for inputs, labels, paths in tqdm(self.data[phase]): 476 | inputs, labels = inputs.cuda(), labels.cuda() 477 | 478 | # If on training phase, enable gradients 479 | with torch.set_grad_enabled(False): 480 | 481 | # In validation or testing 482 | self.batch_forward(inputs, phase=phase) 483 | if not get_feat_only: 484 | self.total_logits = torch.cat((self.total_logits, self.logits)) 485 | self.total_labels = torch.cat((self.total_labels, labels)) 486 | self.total_paths = np.concatenate((self.total_paths, paths)) 487 | 488 | if get_feat_only: 489 | logits_all.append(self.logits.cpu().numpy()) 490 | feats_all.append(self.features.cpu().numpy()) 491 | labels_all.append(labels.cpu().numpy()) 492 | idxs_all.append(paths.numpy()) 493 | 494 | if get_feat_only: 495 | typ = 'feat' 496 | if phase == 'train_plain': 497 | name = 'train{}_all.pkl'.format(typ) 498 | elif phase == 'test': 499 | name = 'test{}_all.pkl'.format(typ) 500 | elif phase == 'val': 501 | name = 'val{}_all.pkl'.format(typ) 502 | 503 | fname = os.path.join(self.training_opt['log_dir'], name) 504 | print('===> Saving feats to ' + fname) 505 | with open(fname, 'wb') as f: 506 | pickle.dump({ 507 | 'feats': np.concatenate(feats_all), 508 | 'labels': np.concatenate(labels_all), 509 | 'idxs': np.concatenate(idxs_all), 510 | }, 511 | f, protocol=4) 512 | return 513 | probs, preds = F.softmax(self.total_logits.detach(), dim=1).max(dim=1) 514 | 515 | if openset: 516 | preds[probs < self.training_opt['open_threshold']] = -1 517 | self.openset_acc = mic_acc_cal(preds[self.total_labels == -1], 518 | self.total_labels[self.total_labels == -1]) 519 | print('\n\nOpenset Accuracy: %.3f' % self.openset_acc) 520 | 521 | # Calculate the overall accuracy and F measurement 522 | self.eval_acc_mic_top1= mic_acc_cal(preds[self.total_labels != -1], 523 | self.total_labels[self.total_labels != -1]) 524 | self.eval_f_measure = F_measure(preds, self.total_labels, openset=openset, 525 | theta=self.training_opt['open_threshold']) 526 | self.many_acc_top1, \ 527 | self.median_acc_top1, \ 528 | self.low_acc_top1, \ 529 | self.cls_accs = shot_acc(preds[self.total_labels != -1], 530 | self.total_labels[self.total_labels != -1], 531 | self.data['train'], 532 | acc_per_cls=True) 533 | # Top-1 accuracy and additional string 534 | print_str = ['\n\n', 535 | 'Phase: %s' 536 | % (phase), 537 | '\n\n', 538 | 'Evaluation_accuracy_micro_top1: %.3f' 539 | % (self.eval_acc_mic_top1), 540 | '\n', 541 | 'Averaged F-measure: %.3f' 542 | % (self.eval_f_measure), 543 | '\n', 544 | 'Many_shot_accuracy_top1: %.3f' 545 | % (self.many_acc_top1), 546 | 'Median_shot_accuracy_top1: %.3f' 547 | % (self.median_acc_top1), 548 | 'Low_shot_accuracy_top1: %.3f' 549 | % (self.low_acc_top1), 550 | '\n'] 551 | 552 | rsl = {phase + '_all': self.eval_acc_mic_top1, 553 | phase + '_many': self.many_acc_top1, 554 | phase + '_median': self.median_acc_top1, 555 | phase + '_low': self.low_acc_top1, 556 | phase + '_fscore': self.eval_f_measure} 557 | 558 | if phase == 'val': 559 | print_write(print_str, self.log_file) 560 | else: 561 | acc_str = ["{:.1f} \t {:.1f} \t {:.1f} \t {:.1f}".format( 562 | self.many_acc_top1 * 100, 563 | self.median_acc_top1 * 100, 564 | self.low_acc_top1 * 100, 565 | self.eval_acc_mic_top1 * 100)] 566 | if self.log_file is not None and os.path.exists(self.log_file): 567 | print_write(print_str, self.log_file) 568 | print_write(acc_str, self.log_file) 569 | else: 570 | print(*print_str) 571 | print(*acc_str) 572 | 573 | if phase == 'test': 574 | with open(os.path.join(self.training_opt['log_dir'], 'cls_accs.pkl'), 'wb') as f: 575 | pickle.dump(self.cls_accs, f) 576 | return rsl 577 | 578 | def load_model(self, model_dir=None): 579 | model_dir = self.training_opt['log_dir'] if model_dir is None else model_dir 580 | if not model_dir.endswith('.pth'): 581 | print('No pretrained Phase A model') 582 | 583 | print('Validation on the best model.') 584 | print('Loading model from %s' % (model_dir)) 585 | 586 | checkpoint = torch.load(model_dir, map_location='cpu') 587 | model_state = checkpoint['state_dict_best'] 588 | 589 | self.visual_model.load_state_dict(model_state['visual_model']) 590 | self.text_model.load_state_dict(model_state['text_model']) 591 | 592 | if self.test_mode is True: 593 | self.adapter.load_state_dict(model_state['classifier']) 594 | 595 | def save_latest(self, epoch): 596 | model_weights = {} 597 | model_weights['visual_model'] = copy.deepcopy(self.visual_model.state_dict()) 598 | model_weights['text_model'] = copy.deepcopy(self.text_model.state_dict()) 599 | if self.training_opt['phaseA'] is not True: 600 | model_weights['classifier'] = copy.deepcopy(self.adapter.state_dict()) 601 | 602 | model_states = { 603 | 'epoch': epoch, 604 | 'state_dict': model_weights 605 | } 606 | 607 | model_dir = os.path.join(self.training_opt['log_dir'], 608 | 'latest_model_checkpoint.pth') 609 | torch.save(model_states, model_dir) 610 | 611 | def save_model(self, epoch, best_epoch, best_model_weights, best_acc, centroids=None): 612 | 613 | model_states = {'epoch': epoch, 614 | 'best_epoch': best_epoch, 615 | 'state_dict_best': best_model_weights, 616 | 'best_acc': best_acc, 617 | 'centroids': centroids} 618 | 619 | model_dir = os.path.join(self.training_opt['log_dir'], 620 | 'final_model_checkpoint.pth') 621 | 622 | torch.save(model_states, model_dir) 623 | 624 | def output_logits(self, openset=False): 625 | filename = os.path.join(self.training_opt['log_dir'], 626 | 'logits_%s'%('open' if openset else 'close')) 627 | print("Saving total logits to: %s.npz" % filename) 628 | np.savez(filename, 629 | logits=self.total_logits.detach().cpu().numpy(), 630 | labels=self.total_labels.detach().cpu().numpy(), 631 | paths=self.total_paths) 632 | 633 | --------------------------------------------------------------------------------