├── illustration.png ├── finalized_model.sav ├── data ├── __init__.py ├── __pycache__ │ ├── datamgr.cpython-35.pyc │ ├── datamgr.cpython-36.pyc │ ├── dataset.cpython-35.pyc │ ├── dataset.cpython-36.pyc │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── feature_loader.cpython-35.pyc │ ├── additional_transforms.cpython-35.pyc │ └── additional_transforms.cpython-36.pyc ├── additional_transforms.py ├── dataset.py └── datamgr.py ├── __pycache__ └── FSLTask.cpython-37.pyc ├── requirements.txt ├── configs.py ├── filelists ├── CUB │ ├── download_CUB.sh │ └── write_CUB_filelist.py └── miniImagenet │ ├── download_miniImagenet.sh │ └── write_miniImagenet_filelist.py ├── .idea ├── vcs.xml ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml └── Few_Shot_Distribution_Calibration.iml ├── io_utils.py ├── save_plk.py ├── README.md ├── wrn_model.py ├── FSLTask.py └── evaluate_DC.py /illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/illustration.png -------------------------------------------------------------------------------- /finalized_model.sav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/finalized_model.sav -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datamgr 2 | from . import dataset 3 | from . import additional_transforms 4 | -------------------------------------------------------------------------------- /__pycache__/FSLTask.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/__pycache__/FSLTask.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/datamgr.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/datamgr.cpython-35.pyc -------------------------------------------------------------------------------- /data/__pycache__/datamgr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/datamgr.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataset.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/dataset.cpython-35.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.17.2 2 | matplotlib==3.1.1 3 | tqdm==4.36.1 4 | torchvision==0.6.0 5 | torch==1.5.0 6 | Pillow==7.1.2 7 | -------------------------------------------------------------------------------- 
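The `data` package exported above (`datamgr`, `dataset`, `additional_transforms`) is the image/filelist loading path used later by `save_plk.py`. A minimal usage sketch — assuming the miniImagenet filelists have already been generated by the scripts under `filelists/`; the loop body is only illustrative — looks like this:

```python
# Hedged sketch: mirrors how save_plk.py builds its loaders; the filelist path is an assumption.
from data.datamgr import SimpleDataManager

datamgr = SimpleDataManager(84, batch_size=256)  # 84x84 inputs, as in save_plk.py
base_loader = datamgr.get_data_loader('./filelists/miniImagenet/base.json', aug=False)

for images, labels in base_loader:
    # images: (B, 3, 84, 84) float tensors; labels: class indices from the filelist
    print(images.shape, labels.shape)
    break
```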
/data/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/feature_loader.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/feature_loader.cpython-35.pyc -------------------------------------------------------------------------------- /data/__pycache__/additional_transforms.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/additional_transforms.cpython-35.pyc -------------------------------------------------------------------------------- /data/__pycache__/additional_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/additional_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | save_dir = '.' 2 | data_dir = {} 3 | data_dir['CUB'] = './filelists/CUB/' 4 | data_dir['miniImagenet'] = './filelists/miniImagenet/' 5 | -------------------------------------------------------------------------------- /filelists/CUB/download_CUB.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz 3 | tar -zxvf CUB_200_2011.tgz 4 | python write_CUB_filelist.py 5 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /filelists/miniImagenet/download_miniImagenet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget 
https://raw.githubusercontent.com/twitter/meta-learning-lstm/master/data/miniImagenet/train.csv 3 | wget https://raw.githubusercontent.com/twitter/meta-learning-lstm/master/data/miniImagenet/val.csv 4 | wget https://raw.githubusercontent.com/twitter/meta-learning-lstm/master/data/miniImagenet/test.csv 5 | wget http://image-net.org/image/ILSVRC2015/ILSVRC2015_CLS-LOC.tar.gz 6 | tar -zxvf ILSVRC2015_CLS-LOC.tar.gz 7 | python write_miniImagenet_filelist.py 8 | -------------------------------------------------------------------------------- /.idea/Few_Shot_Distribution_Calibration.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /data/additional_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | from PIL import ImageEnhance 10 | 11 | transformtypedict=dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast, Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color) 12 | 13 | 14 | 15 | class ImageJitter(object): 16 | def __init__(self, transformdict): 17 | self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict] 18 | 19 | 20 | def __call__(self, img): 21 | out = img 22 | randtensor = torch.rand(len(self.transforms)) 23 | 24 | for i, (transformer, alpha) in enumerate(self.transforms): 25 | r = alpha*(randtensor[i]*2.0 -1.0) + 1 26 | out = transformer(out).enhance(r).convert('RGB') 27 | 28 | return out 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /filelists/CUB/write_CUB_filelist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os import listdir 3 | from os.path import isfile, isdir, join 4 | import os 5 | import json 6 | import random 7 | 8 | cwd = os.getcwd() 9 | data_path = join(cwd,'CUB_200_2011/images') 10 | savedir = './' 11 | dataset_list = ['base','val','novel'] 12 | 13 | #if not os.path.exists(savedir): 14 | # os.makedirs(savedir) 15 | 16 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] 17 | folder_list.sort() 18 | label_dict = dict(zip(folder_list,range(0,len(folder_list)))) 19 | 20 | classfile_list_all = [] 21 | 22 | for i, folder in enumerate(folder_list): 23 | folder_path = join(data_path, folder) 24 | classfile_list_all.append( [ join(folder_path, cf) for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')]) 25 | random.shuffle(classfile_list_all[i]) 26 | 27 | 28 | for dataset in dataset_list: 29 | file_list = [] 30 | label_list = [] 31 | for i, classfile_list in enumerate(classfile_list_all): 32 | if 'base' in dataset: 33 | if (i%2 == 0): 34 | file_list = file_list + classfile_list 35 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 36 | if 'val' in dataset: 37 | if (i%4 == 1): 38 | file_list = file_list + classfile_list 39 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 40 | if 'novel' in dataset: 41 | if (i%4 == 3): 42 | file_list = file_list + classfile_list 43 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 44 | 45 | fo = open(savedir + dataset + ".json", 
"w") 46 | fo.write('{"label_names": [') 47 | fo.writelines(['"%s",' % item for item in folder_list]) 48 | fo.seek(0, os.SEEK_END) 49 | fo.seek(fo.tell()-1, os.SEEK_SET) 50 | fo.write('],') 51 | 52 | fo.write('"image_names": [') 53 | fo.writelines(['"%s",' % item for item in file_list]) 54 | fo.seek(0, os.SEEK_END) 55 | fo.seek(fo.tell()-1, os.SEEK_SET) 56 | fo.write('],') 57 | 58 | fo.write('"image_labels": [') 59 | fo.writelines(['%d,' % item for item in label_list]) 60 | fo.seek(0, os.SEEK_END) 61 | fo.seek(fo.tell()-1, os.SEEK_SET) 62 | fo.write(']}') 63 | 64 | fo.close() 65 | print("%s -OK" %dataset) 66 | -------------------------------------------------------------------------------- /filelists/miniImagenet/write_miniImagenet_filelist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os import listdir 3 | from os.path import isfile, isdir, join 4 | import os 5 | import json 6 | import random 7 | import re 8 | 9 | cwd = os.getcwd() 10 | data_path = join(cwd,'ILSVRC2015/Data/CLS-LOC/train') 11 | #data_path = join('/home/yuqing/phd/code/miniimagenet/images') 12 | savedir = './' 13 | dataset_list = ['base', 'val', 'novel'] 14 | 15 | #if not os.path.exists(savedir): 16 | # os.makedirs(savedir) 17 | 18 | cl = -1 19 | folderlist = [] 20 | 21 | datasetmap = {'base':'train','val':'val','novel':'test'}; 22 | filelists = {'base':{},'val':{},'novel':{} } 23 | filelists_flat = {'base':[],'val':[],'novel':[] } 24 | labellists_flat = {'base':[],'val':[],'novel':[] } 25 | 26 | for dataset in dataset_list: 27 | with open(datasetmap[dataset] + ".csv", "r") as lines: 28 | for i, line in enumerate(lines): 29 | if i == 0: 30 | continue 31 | fid, _ , label = re.split(',|\.', line) 32 | label = label.replace('\n','') 33 | if not label in filelists[dataset]: 34 | folderlist.append(label) 35 | filelists[dataset][label] = [] 36 | fnames = listdir( join(data_path, label) ) 37 | fname_number = [ int(re.split('_|\.', fname)[1]) for fname in fnames] 38 | sorted_fnames = list(zip( *sorted( zip(fnames, fname_number), key = lambda f_tuple: f_tuple[1] )))[0] 39 | 40 | fid = int(fid[-5:])-1 41 | fname = join( data_path,label, sorted_fnames[fid] ) 42 | filelists[dataset][label].append(fname) 43 | 44 | for key, filelist in filelists[dataset].items(): 45 | cl += 1 46 | random.shuffle(filelist) 47 | filelists_flat[dataset] += filelist 48 | labellists_flat[dataset] += np.repeat(cl, len(filelist)).tolist() 49 | 50 | for dataset in dataset_list: 51 | fo = open(savedir + dataset + ".json", "w") 52 | fo.write('{"label_names": [') 53 | fo.writelines(['"%s",' % item for item in folderlist]) 54 | fo.seek(0, os.SEEK_END) 55 | fo.seek(fo.tell()-1, os.SEEK_SET) 56 | fo.write('],') 57 | 58 | fo.write('"image_names": [') 59 | fo.writelines(['"%s",' % item for item in filelists_flat[dataset]]) 60 | fo.seek(0, os.SEEK_END) 61 | fo.seek(fo.tell()-1, os.SEEK_SET) 62 | fo.write('],') 63 | 64 | fo.write('"image_labels": [') 65 | fo.writelines(['%d,' % item for item in labellists_flat[dataset]]) 66 | fo.seek(0, os.SEEK_END) 67 | fo.seek(fo.tell()-1, os.SEEK_SET) 68 | fo.write(']}') 69 | 70 | fo.close() 71 | print("%s -OK" %dataset) 72 | -------------------------------------------------------------------------------- /io_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import glob 4 | import argparse 5 | 6 | import numpy as np 7 | import os 8 | import glob 9 | import argparse 10 | 11 | 12 | def 
parse_args(script): 13 | parser = argparse.ArgumentParser(description='few-shot script %s' % (script)) 14 | parser.add_argument('--dataset', default='miniImagenet', help='CUB/miniImagenet') 15 | parser.add_argument('--model', default='WideResNet28_10', help='model: WideResNet28_10/ResNet{18}') 16 | parser.add_argument('--method', default='S2M2_R', help='rotation/S2M2_R') 17 | parser.add_argument('--train_aug', default='True', 18 | help='perform data augmentation or not during training ') # still required for save_features.py and test.py to find the model path correctly 19 | 20 | if script == 'train': 21 | parser.add_argument('--num_classes', default=200, type=int, 22 | help='total number of classes') # make it larger than the maximum label value in base class 23 | parser.add_argument('--save_freq', default=10, type=int, help='Save frequency') 24 | parser.add_argument('--start_epoch', default=0, type=int, help='Starting epoch') 25 | parser.add_argument('--stop_epoch', default=400, type=int, 26 | help='Stopping epoch') # for meta-learning methods, each epoch contains 100 episodes. The default epoch number is dataset dependent. See train.py 27 | parser.add_argument('--resume', action='store_true', 28 | help='continue from previous trained model with largest epoch') 29 | parser.add_argument('--lr', default=0.001, type=int, help='learning rate') 30 | parser.add_argument('--batch_size', default=16, type=int, help='batch size ') 31 | parser.add_argument('--test_batch_size', default=2, type=int, help='batch size ') 32 | parser.add_argument('--alpha', default=2.0, type=int, help='for S2M2 training ') 33 | elif script == 'test': 34 | parser.add_argument('--num_classes', default=200, type=int, help='total number of classes') 35 | 36 | return parser.parse_args() 37 | 38 | 39 | def get_assigned_file(checkpoint_dir, num): 40 | assign_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(num)) 41 | return assign_file 42 | 43 | 44 | def get_resume_file(checkpoint_dir): 45 | filelist = glob.glob(os.path.join(checkpoint_dir, '*.tar')) 46 | if len(filelist) == 0: 47 | return None 48 | 49 | filelist = [x for x in filelist if os.path.basename(x) != 'best.tar'] 50 | epochs = np.array([int(os.path.splitext(os.path.basename(x))[0]) for x in filelist]) 51 | max_epoch = np.max(epochs) 52 | resume_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(max_epoch)) 53 | return resume_file 54 | 55 | 56 | def get_best_file(checkpoint_dir): 57 | best_file = os.path.join(checkpoint_dir, 'best.tar') 58 | if os.path.isfile(best_file): 59 | return best_file 60 | else: 61 | return get_resume_file(checkpoint_dir) -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | from PIL import Image 5 | import json 6 | import numpy as np 7 | import torchvision.transforms as transforms 8 | import os 9 | identity = lambda x:x 10 | class SimpleDataset: 11 | def __init__(self, data_file, transform, target_transform=identity): 12 | with open(data_file, 'r') as f: 13 | self.meta = json.load(f) 14 | self.transform = transform 15 | self.target_transform = target_transform 16 | 17 | 18 | def __getitem__(self,i): 19 | image_path = os.path.join(self.meta['image_names'][i]) 20 | img = Image.open(image_path).convert('RGB') 21 | img = self.transform(img) 22 | target = 
self.target_transform(self.meta['image_labels'][i]) 23 | return img, target 24 | 25 | def __len__(self): 26 | return len(self.meta['image_names']) 27 | 28 | 29 | class SetDataset: 30 | def __init__(self, data_file, batch_size, transform): 31 | with open(data_file, 'r') as f: 32 | self.meta = json.load(f) 33 | 34 | self.cl_list = np.unique(self.meta['image_labels']).tolist() 35 | 36 | self.sub_meta = {} 37 | for cl in self.cl_list: 38 | self.sub_meta[cl] = [] 39 | 40 | for x,y in zip(self.meta['image_names'],self.meta['image_labels']): 41 | self.sub_meta[y].append(x) 42 | 43 | self.sub_dataloader = [] 44 | sub_data_loader_params = dict(batch_size = batch_size, 45 | shuffle = True, 46 | num_workers = 0, #use main thread only or may receive multiple batches 47 | pin_memory = False) 48 | for cl in self.cl_list: 49 | sub_dataset = SubDataset(self.sub_meta[cl], cl, transform = transform ) 50 | self.sub_dataloader.append( torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params) ) 51 | 52 | def __getitem__(self,i): 53 | return next(iter(self.sub_dataloader[i])) 54 | 55 | def __len__(self): 56 | return len(self.cl_list) 57 | 58 | class SubDataset: 59 | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity): 60 | self.sub_meta = sub_meta 61 | self.cl = cl 62 | self.transform = transform 63 | self.target_transform = target_transform 64 | 65 | def __getitem__(self,i): 66 | #print( '%d -%d' %(self.cl,i)) 67 | image_path = os.path.join( self.sub_meta[i]) 68 | img = Image.open(image_path).convert('RGB') 69 | img = self.transform(img) 70 | target = self.target_transform(self.cl) 71 | return img, target 72 | 73 | def __len__(self): 74 | return len(self.sub_meta) 75 | 76 | class EpisodicBatchSampler(object): 77 | def __init__(self, n_classes, n_way, n_episodes): 78 | self.n_classes = n_classes 79 | self.n_way = n_way 80 | self.n_episodes = n_episodes 81 | 82 | def __len__(self): 83 | return self.n_episodes 84 | 85 | def __iter__(self): 86 | for i in range(self.n_episodes): 87 | yield torch.randperm(self.n_classes)[:self.n_way] 88 | -------------------------------------------------------------------------------- /data/datamgr.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | from PIL import Image 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | import data.additional_transforms as add_transforms 8 | from data.dataset import SimpleDataset, SetDataset, EpisodicBatchSampler 9 | from abc import abstractmethod 10 | 11 | class TransformLoader: 12 | def __init__(self, image_size, 13 | normalize_param = dict(mean= [0.485, 0.456, 0.406] , std=[0.229, 0.224, 0.225]), 14 | jitter_param = dict(Brightness=0.4, Contrast=0.4, Color=0.4)): 15 | self.image_size = image_size 16 | self.normalize_param = normalize_param 17 | self.jitter_param = jitter_param 18 | 19 | def parse_transform(self, transform_type): 20 | if transform_type=='ImageJitter': 21 | method = add_transforms.ImageJitter( self.jitter_param ) 22 | return method 23 | method = getattr(transforms, transform_type) 24 | if transform_type=='RandomSizedCrop': 25 | return method(self.image_size) 26 | elif transform_type=='CenterCrop': 27 | return method(self.image_size) 28 | elif transform_type=='Scale': 29 | return method([int(self.image_size*1.15), int(self.image_size*1.15)]) 30 | elif transform_type=='Normalize': 31 | return 
method(**self.normalize_param ) 32 | else: 33 | return method() 34 | 35 | def get_composed_transform(self, aug = False): 36 | if aug: 37 | transform_list = ['RandomSizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize'] 38 | else: 39 | transform_list = ['Scale','CenterCrop', 'ToTensor', 'Normalize'] 40 | 41 | transform_funcs = [ self.parse_transform(x) for x in transform_list] 42 | transform = transforms.Compose(transform_funcs) 43 | return transform 44 | 45 | class DataManager: 46 | @abstractmethod 47 | def get_data_loader(self, data_file, aug): 48 | pass 49 | 50 | 51 | class SimpleDataManager(DataManager): 52 | def __init__(self, image_size, batch_size): 53 | super(SimpleDataManager, self).__init__() 54 | self.batch_size = batch_size 55 | self.trans_loader = TransformLoader(image_size) 56 | 57 | def get_data_loader(self, data_file, aug): #parameters that would change on train/val set 58 | transform = self.trans_loader.get_composed_transform(aug) 59 | dataset = SimpleDataset(data_file, transform) 60 | data_loader_params = dict(batch_size = self.batch_size, shuffle = True, num_workers = 12, pin_memory = True) 61 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 62 | 63 | return data_loader 64 | 65 | class SetDataManager(DataManager): 66 | def __init__(self, image_size, n_way, n_support, n_query, n_eposide =100): 67 | super(SetDataManager, self).__init__() 68 | self.image_size = image_size 69 | self.n_way = n_way 70 | self.batch_size = n_support + n_query 71 | self.n_eposide = n_eposide 72 | 73 | self.trans_loader = TransformLoader(image_size) 74 | 75 | def get_data_loader(self, data_file, aug): #parameters that would change on train/val set 76 | transform = self.trans_loader.get_composed_transform(aug) 77 | dataset = SetDataset( data_file , self.batch_size, transform ) 78 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide ) 79 | data_loader_params = dict(batch_sampler = sampler, num_workers = 12, pin_memory = True) 80 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 81 | return data_loader 82 | 83 | 84 | -------------------------------------------------------------------------------- /save_plk.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import csv 5 | import os 6 | import collections 7 | import pickle 8 | import random 9 | 10 | import numpy as np 11 | import torch 12 | from torch.autograd import Variable 13 | import torch.backends.cudnn as cudnn 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | from io_utils import parse_args 19 | from data.datamgr import SimpleDataManager , SetDataManager 20 | import configs 21 | 22 | import wrn_model 23 | 24 | import torch.nn.functional as F 25 | 26 | from io_utils import parse_args, get_resume_file ,get_assigned_file 27 | from os import path 28 | 29 | use_gpu = torch.cuda.is_available() 30 | 31 | class WrappedModel(nn.Module): 32 | def __init__(self, module): 33 | super(WrappedModel, self).__init__() 34 | self.module = module 35 | def forward(self, x): 36 | return self.module(x) 37 | 38 | def save_pickle(file, data): 39 | with open(file, 'wb') as f: 40 | pickle.dump(data, f) 41 | 42 | def load_pickle(file): 43 | with open(file, 'rb') as f: 44 | return pickle.load(f) 45 | 46 | def extract_feature(val_loader, model, checkpoint_dir, 
tag='last',set='base'): 47 | save_dir = '{}/{}'.format(checkpoint_dir, tag) 48 | if os.path.isfile(save_dir + '/%s_features.plk'%set): 49 | data = load_pickle(save_dir + '/%s_features.plk'%set) 50 | return data 51 | else: 52 | if not os.path.isdir(save_dir): 53 | os.makedirs(save_dir) 54 | 55 | #model.eval() 56 | with torch.no_grad(): 57 | 58 | output_dict = collections.defaultdict(list) 59 | 60 | for i, (inputs, labels) in enumerate(val_loader): 61 | # compute output 62 | inputs = inputs.cuda() 63 | labels = labels.cuda() 64 | outputs,_ = model(inputs) 65 | outputs = outputs.cpu().data.numpy() 66 | 67 | for out, label in zip(outputs, labels): 68 | output_dict[label.item()].append(out) 69 | 70 | all_info = output_dict 71 | save_pickle(save_dir + '/%s_features.plk'%set, all_info) 72 | return all_info 73 | 74 | if __name__ == '__main__': 75 | params = parse_args('test') 76 | params.model = 'WideResNet28_10' 77 | params.method = 'S2M2_R' 78 | 79 | loadfile_base = configs.data_dir[params.dataset] + 'base.json' 80 | loadfile_novel = configs.data_dir[params.dataset] + 'novel.json' 81 | if params.dataset == 'miniImagenet' or params.dataset == 'CUB': 82 | datamgr = SimpleDataManager(84, batch_size = 256) 83 | base_loader = datamgr.get_data_loader(loadfile_base, aug=False) 84 | novel_loader = datamgr.get_data_loader(loadfile_novel, aug = False) 85 | 86 | checkpoint_dir = '%s/checkpoints/%s/%s_%s' %(configs.save_dir, params.dataset, params.model, params.method) 87 | modelfile = get_resume_file(checkpoint_dir) 88 | 89 | if params.model == 'WideResNet28_10': 90 | model = wrn_model.wrn28_10(num_classes=params.num_classes) 91 | 92 | 93 | model = model.cuda() 94 | cudnn.benchmark = True 95 | 96 | checkpoint = torch.load(modelfile) 97 | state = checkpoint['state'] 98 | state_keys = list(state.keys()) 99 | 100 | callwrap = False 101 | if 'module' in state_keys[0]: 102 | callwrap = True 103 | if callwrap: 104 | model = WrappedModel(model) 105 | model_dict_load = model.state_dict() 106 | model_dict_load.update(state) 107 | model.load_state_dict(model_dict_load) 108 | model.eval() 109 | output_dict_base = extract_feature(base_loader, model, checkpoint_dir, tag='last', set='base') 110 | print("base set features saved!") 111 | output_dict_novel=extract_feature(novel_loader, model, checkpoint_dir, tag='last',set='novel') 112 | print("novel features saved!") 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransDARC (IROS 2022, ReadME in progress) 2 | 3 | ## News 4 | **TransDARC** [[**PDF**](https://arxiv.org/pdf/2203.00927.pdf)] is accepted to **IROS2022** for an **Oral** presentation. 5 | 6 | ## Extract and save features 7 | Please first train the video feature extraction backbone from video swin transformer in mmaction2 repo using drive and act dataset, 8 | 9 | In our paper, we use the same configuration of the video swin transformer for swin-base as mentioned in https://github.com/SwinTransformer/Video-Swin-Transformer.git while using the ImageNet pretraining. 
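After training, the per-clip backbone outputs have to be pooled into a single feature vector per sample before they are pickled for `FSLTask.py` (the per-split pickle paths are hard-coded in `_datasetFeaturesFiles` there). The export script itself is not part of this repository; the snippet below is only a minimal sketch of the clip-averaging step, assuming the backbone returns one feature vector per sampled clip (the downstream code in `evaluate_DC.py` expects 1024-dimensional features). The clip sampling settings appear in the configuration shown below.

```python
import pickle
import torch

def average_clip_features(clip_features: torch.Tensor) -> torch.Tensor:
    """Average per-clip backbone outputs into one vector per video chunk.

    clip_features: (num_clips, feat_dim), e.g. (2, 1024) for num_clips=2.
    """
    return clip_features.mean(dim=0)

if __name__ == "__main__":
    # Dummy stand-in for the Swin-Base output of one 90-frame chunk (two clips).
    per_clip = torch.randn(2, 1024)
    pooled = average_clip_features(per_clip)

    # Hypothetical file name and key; the real pickles are keyed by annotation
    # strings and stored under the paths listed in FSLTask.py.
    with open("swin_base_features_example.pkl", "wb") as f:
        pickle.dump({"example_annotation": pooled.numpy()}, f)
```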
10 | 
11 | The corresponding training configuration for swin_base_patch244_window877_kinetics400_22k.py is:
12 | 
13 | ```python
14 | train_pipeline = [
15 |     dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=2),
16 |     dict(type='Resize', scale=(224, 224), keep_ratio=False),
17 |     dict(type='FormatShape', input_format='NCTHW'),
18 |     dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
19 |     dict(type='ToTensor', keys=['imgs', 'label'])
20 | ]
21 | 
22 | val_pipeline = [
23 |     dict(
24 |         type='SampleFrames',
25 |         clip_len=32,
26 |         frame_interval=2,
27 |         num_clips=2,
28 |         test_mode=True),
29 |     dict(type='FormatShape', input_format='NCTHW'),
30 |     dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
31 |     dict(type='ToTensor', keys=['imgs', 'label'])
32 | ]
33 | 
34 | test_pipeline = [
35 |     dict(
36 |         type='SampleFrames',
37 |         clip_len=32,
38 |         frame_interval=2,
39 |         num_clips=2,
40 |         test_mode=True),
41 |     dict(type='FormatShape', input_format='NCTHW'),
42 |     dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
43 |     dict(type='ToTensor', keys=['imgs', 'label'])
44 | ]
45 | ```
46 | 
47 | Please average the per-clip outputs at the top of the Video Swin Transformer before saving the features.
48 | 
49 | ## Evaluate our distribution calibration
50 | 
51 | To evaluate our distribution calibration method, run:
52 | 
53 | ```eval
54 | python evaluate_DC.py
55 | ```
56 | 
57 | ## Verification of our code
58 | 
59 | The logits of Video Swin Base are available at https://drive.google.com/drive/folders/1MJY8toH3PSV--pA2EvL8qjPIjcaTdmvr?usp=sharing; they correspond to fine-grained-level split 0 of driver activity recognition. (Note that the result reported in the paper is the average over the three splits, i.e., split0, split1, and split2.)
60 | 
61 | Please note that this performance is evaluated with unbalanced mean Top-1 accuracy. Under balanced accuracy, our model still achieves 71.0 accuracy for fine-grained activity recognition on the test set. If you want to compare with TransDARC, either evaluate in the same way as our paper, or contact me by email for additional results on the other tasks under balanced-accuracy evaluation. Thanks!
62 | 
63 | ## Please consider citing our paper if you find it useful. [[**PDF**](https://arxiv.org/pdf/2203.00927.pdf)]
64 | 
65 | ```
66 | @article{peng2022transdarc,
67 |   title={TransDARC: Transformer-based Driver Activity Recognition with Latent Space Feature Calibration},
68 |   author={Peng, Kunyu and Roitberg, Alina and Yang, Kailun and Zhang, Jiaming and Stiefelhagen, Rainer},
69 |   journal={2022 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
70 |   year={2022}
71 | }
72 | ```
73 | 
-------------------------------------------------------------------------------- /wrn_model.py: -------------------------------------------------------------------------------- 1 | ### dropout has been removed in this code.
original code had dropout##### 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | import sys, os 9 | import numpy as np 10 | import random 11 | act = torch.nn.ReLU() 12 | 13 | 14 | import math 15 | from torch.nn.utils.weight_norm import WeightNorm 16 | 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 21 | super(BasicBlock, self).__init__() 22 | self.bn1 = nn.BatchNorm2d(in_planes) 23 | self.relu1 = nn.ReLU(inplace=True) 24 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(out_planes) 27 | self.relu2 = nn.ReLU(inplace=True) 28 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 29 | padding=1, bias=False) 30 | self.droprate = dropRate 31 | self.equalInOut = (in_planes == out_planes) 32 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 33 | padding=0, bias=False) or None 34 | def forward(self, x): 35 | if not self.equalInOut: 36 | x = self.relu1(self.bn1(x)) 37 | else: 38 | out = self.relu1(self.bn1(x)) 39 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 40 | if self.droprate > 0: 41 | out = F.dropout(out, p=self.droprate, training=self.training) 42 | out = self.conv2(out) 43 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 44 | 45 | class distLinear(nn.Module): 46 | def __init__(self, indim, outdim): 47 | super(distLinear, self).__init__() 48 | self.L = nn.Linear( indim, outdim, bias = False) 49 | self.class_wise_learnable_norm = True #See the issue#4&8 in the github 50 | if self.class_wise_learnable_norm: 51 | WeightNorm.apply(self.L, 'weight', dim=0) #split the weight update component to direction and norm 52 | 53 | if outdim <=200: 54 | self.scale_factor = 2; #a fixed scale factor to scale the output of cos value into a reasonably large input for softmax 55 | else: 56 | self.scale_factor = 10; #in omniglot, a larger scale factor is required to handle >1000 output classes. 
57 | 58 | def forward(self, x): 59 | x_norm = torch.norm(x, p=2, dim =1).unsqueeze(1).expand_as(x) 60 | x_normalized = x.div(x_norm+ 0.00001) 61 | if not self.class_wise_learnable_norm: 62 | L_norm = torch.norm(self.L.weight.data, p=2, dim =1).unsqueeze(1).expand_as(self.L.weight.data) 63 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001) 64 | cos_dist = self.L(x_normalized) #matrix product by forward function, but when using WeightNorm, this also multiply the cosine distance by a class-wise learnable norm, see the issue#4&8 in the github 65 | scores = self.scale_factor* (cos_dist) 66 | 67 | return scores 68 | 69 | class NetworkBlock(nn.Module): 70 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 71 | super(NetworkBlock, self).__init__() 72 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 73 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 74 | layers = [] 75 | for i in range(int(nb_layers)): 76 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 77 | return nn.Sequential(*layers) 78 | def forward(self, x): 79 | return self.layer(x) 80 | 81 | 82 | def to_one_hot(inp,num_classes): 83 | 84 | y_onehot = torch.FloatTensor(inp.size(0), num_classes) 85 | if torch.cuda.is_available(): 86 | y_onehot = y_onehot.cuda() 87 | 88 | y_onehot.zero_() 89 | x = inp.type(torch.LongTensor) 90 | if torch.cuda.is_available(): 91 | x = x.cuda() 92 | 93 | x = torch.unsqueeze(x , 1) 94 | y_onehot.scatter_(1, x , 1) 95 | 96 | return Variable(y_onehot,requires_grad=False) 97 | # return y_onehot 98 | 99 | 100 | def mixup_data(x, y, lam): 101 | 102 | '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda''' 103 | 104 | batch_size = x.size()[0] 105 | index = torch.randperm(batch_size) 106 | if torch.cuda.is_available(): 107 | index = index.cuda() 108 | mixed_x = lam * x + (1 - lam) * x[index,:] 109 | y_a, y_b = y, y[index] 110 | 111 | return mixed_x, y_a, y_b, lam 112 | 113 | 114 | class WideResNet(nn.Module): 115 | def __init__(self, depth=28, widen_factor=10, num_classes= 200 , loss_type = 'dist', per_img_std = False, stride = 1 ): 116 | dropRate = 0.5 117 | flatten = True 118 | super(WideResNet, self).__init__() 119 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 120 | assert((depth - 4) % 6 == 0) 121 | n = (depth - 4) / 6 122 | block = BasicBlock 123 | # 1st conv before any network block 124 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 125 | padding=1, bias=False) 126 | # 1st block 127 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, stride, dropRate) 128 | # 2nd block 129 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 130 | # 3rd block 131 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 132 | # global average pooling and linear 133 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 134 | self.relu = nn.ReLU(inplace=True) 135 | self.nChannels = nChannels[3] 136 | 137 | if loss_type == 'softmax': 138 | self.linear = nn.Linear(nChannels[3], int(num_classes)) 139 | self.linear.bias.data.fill_(0) 140 | else: 141 | self.linear = distLinear(nChannels[3], int(num_classes)) 142 | 143 | self.num_classes = num_classes 144 | if flatten: 145 | self.final_feat_dim = 640 146 | for m in self.modules(): 147 | if isinstance(m, nn.Conv2d): 148 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 149 | 
m.weight.data.normal_(0, math.sqrt(2. / n)) 150 | elif isinstance(m, nn.BatchNorm2d): 151 | m.weight.data.fill_(1) 152 | m.bias.data.zero_() 153 | 154 | 155 | def forward(self, x, target= None, mixup=False, mixup_hidden=True, mixup_alpha=None , lam = 0.4): 156 | if target is not None: 157 | if mixup_hidden: 158 | layer_mix = random.randint(0,3) 159 | elif mixup: 160 | layer_mix = 0 161 | else: 162 | layer_mix = None 163 | 164 | out = x 165 | 166 | target_a = target_b = target 167 | 168 | if layer_mix == 0: 169 | out, target_a , target_b , lam = mixup_data(out, target, lam=lam) 170 | 171 | out = self.conv1(out) 172 | out = self.block1(out) 173 | 174 | 175 | if layer_mix == 1: 176 | out, target_a , target_b , lam = mixup_data(out, target, lam=lam) 177 | 178 | out = self.block2(out) 179 | 180 | if layer_mix == 2: 181 | out, target_a , target_b , lam = mixup_data(out, target, lam=lam) 182 | 183 | 184 | out = self.block3(out) 185 | if layer_mix == 3: 186 | out, target_a , target_b , lam = mixup_data(out, target, lam=lam) 187 | 188 | out = self.relu(self.bn1(out)) 189 | out = F.avg_pool2d(out, out.size()[2:]) 190 | out = out.view(out.size(0), -1) 191 | out1 = self.linear(out) 192 | 193 | return out , out1 , target_a , target_b 194 | else: 195 | out = x 196 | out = self.conv1(out) 197 | out = self.block1(out) 198 | out = self.block2(out) 199 | out = self.block3(out) 200 | out = self.relu(self.bn1(out)) 201 | out = F.avg_pool2d(out, out.size()[2:]) 202 | out = out.view(out.size(0), -1) 203 | out1 = self.linear(out) 204 | return out, out1 205 | 206 | 207 | 208 | def wrn28_10(num_classes=200 , loss_type = 'dist'): 209 | model = WideResNet(depth=28, widen_factor=10, num_classes=num_classes, loss_type = loss_type , per_img_std = False, stride = 1 ) 210 | return model 211 | 212 | -------------------------------------------------------------------------------- /FSLTask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import torch 5 | import sys 6 | from pprint import pprint 7 | # from tqdm import tqdm 8 | nc = 34 9 | # ======================================================== 10 | # Usefull paths 11 | _datasetFeaturesFiles = {"train": "/cvhci/data/activity/kpeng/logits_split0_chunk90_swin_base_last_logits768_last22train.pkl", 12 | #"train": "/cvhci/data/activity/kpeng/logits_split1_chunk90_swin_base_last_logits768_last_augmentedtrain.pkl", 13 | #"train": "/cvhci/data/activity/kpeng/logits_split2_chunk90_swin_base_last_logits768_lasttrain.pkl", 14 | #"train": "/cvhci/data/activity/kpeng/task_level_logits_split0_chunk90_swin_base_last_logits768_lasttrain.pkl", 15 | #"train": "/cvhci/data/activity/kpeng/object_level_logits_split2_chunk90_swin_base_last_logits768_lasttrain.pkl", 16 | #"train": "/cvhci/data/activity/kpeng/location_level_logits_split1_chunk90_swin_base_last_logits768_lasttrain.pkl", 17 | "train_aug": "/cvhci/data/activity/kpeng/logits_split0_chunk90_swin_base_last_logits768_last22augmentedtrain", 18 | #"train_aug": "/cvhci/data/activity/kpeng/logits_split1_chunk90_swin_base_last_logits768_lastaugmentedrealtrain.pkl", 19 | #"train_aug": "/cvhci/data/activity/kpeng/logits_split2_chunk90_swin_base_last_logits768_last_augmentedtrain.pkl", 20 | #"train_aug": "/cvhci/data/activity/kpeng/task_level_logits_split0_chunk90_swin_base_last_logits768_lastaugmentedtrain.pkl", 21 | #"train_aug": "/cvhci/data/activity/kpeng/object_level_logits_split2_chunk90_swin_base_last_logits768_lastaugmentedtrain.pkl", 22 | 
#"train_aug": "/cvhci/data/activity/kpeng/location_level_logits_split1_chunk90_swin_base_last_logits768_lastaugmentedtrain.pkl", 23 | "eval": "/cvhci/data/activity/kpeng/logits_split0_chunk90_swin_base_last_logits768_last22val.pkl", 24 | #"eval": "/cvhci/data/activity/kpeng/logits_split2_chunk90_swin_base_last_logits768_lastval.pkl", 25 | #"eval": "/cvhci/data/activity/kpeng/logits_split1_chunk90_swin_base_last_logits768_lastaugmentedval.pkl", 26 | #"eval": "/cvhci/data/activity/kpeng/task_level_logits_split0_chunk90_swin_base_last_logits768_lastval.pkl", 27 | #"test": "/cvhci/data/activity/kpeng/logits_split0_chunk90_swin_base_last_logits768_last_ids2test.pkl", 28 | #"eval": "/cvhci/data/activity/kpeng/object_level_logits_split2_chunk90_swin_base_last_logits768_lastval.pkl", 29 | #"eval": "/cvhci/data/activity/kpeng/location_level_logits_split1_chunk90_swin_base_last_logits768_lastval.pkl", 30 | "test": "/cvhci/data/activity/kpeng/logits_split0_chunk90_swin_base_last_logits768_last22test.pkl" 31 | #"test": "/cvhci/data/activity/kpeng/logits_split1_chunk90_swin_base_last_logits768_lastaugmentedtest.pkl" 32 | #"test": "/cvhci/data/activity/kpeng/task_level_logits_split0_chunk90_swin_base_last_logits768_lasttest.pkl" 33 | #"test": "/cvhci/data/activity/kpeng/object_level_logits_split2_chunk90_swin_base_last_logits768_lasttest.pkl", 34 | #"test": "/cvhci/data/activity/kpeng/location_level_logits_split1_chunk90_swin_base_last_logits768_lasttest.pkl", 35 | } 36 | _cacheDir = "./cache" 37 | _maxRuns = 10000 38 | _min_examples = -1 39 | 40 | # ======================================================== 41 | # Module internal functions and variables 42 | 43 | _randStates = None 44 | _rsCfg = None 45 | 46 | 47 | 48 | def load_label_feature(item): 49 | #f = open("/cvhci/data/activity/Drive&Act/kunyu/annotation_list.pkl", 'rb') 50 | f = open("/cvhci/data/activity/Drive&Act/kunyu/annotation_list.pkl", 'rb') 51 | annotation = [] 52 | class_index = pickle.load(f) 53 | print(class_index) 54 | f.close() 55 | infos = item.keys() 56 | class_label = [] 57 | for info in infos: 58 | info = ''.join([item[0] for item in list(info)]) 59 | #print(info) 60 | activity = info.split(',')[-2] 61 | if activity not in class_label: 62 | class_label.append(activity) 63 | #print(activity) 64 | label = class_index.index(activity) 65 | annotation.append(label) 66 | features = item.values() 67 | #print(features) 68 | print(class_label) 69 | feature = [term for term in features] 70 | #print(feature) 71 | return feature, annotation 72 | 73 | def _load_pickle(file_train,file_train_aug,file_eval,file_test): 74 | dataset = dict() 75 | with open(file_train, 'rb') as f: 76 | data = pickle.load(f) 77 | feature, annotation = load_label_feature(data) 78 | #labels = [np.full(shape=len(data[key]), fill_value=key) 79 | # for key in data] 80 | #data = [features for key in data for features in data[key]] 81 | dataset['data_train'] = torch.FloatTensor(np.stack(feature, axis=0)) 82 | #print(dataset['data_train'].size()) 83 | dataset['labels_train'] = torch.LongTensor(np.stack(annotation, axis=0)) 84 | with open(file_train_aug, 'rb') as f: 85 | data = pickle.load(f) 86 | feature, annotation = load_label_feature(data) 87 | # labels = [np.full(shape=len(data[key]), fill_value=key) 88 | # for key in data] 89 | # data = [features for key in data for features in data[key]] 90 | dataset['data_train_aug'] = torch.FloatTensor(np.stack(feature, axis=0)) 91 | # print(dataset['data_train'].size()) 92 | dataset['labels_train_aug'] = 
torch.LongTensor(np.stack(annotation, axis=0)) 93 | with open(file_eval, 'rb') as f: 94 | data = pickle.load(f) 95 | feature, annotation = load_label_feature(data) 96 | #labels = [np.full(shape=len(data[key]), fill_value=key) 97 | # for key in data] 98 | #data = [features for key in data for features in data[key]] 99 | 100 | dataset['data_eval'] = torch.FloatTensor(np.stack(feature, axis=0)) 101 | dataset['labels_eval'] = torch.LongTensor(np.stack(annotation, axis=0)) 102 | #print(dataset['labels_eval']) 103 | #sys.exit() 104 | with open(file_test, 'rb') as f: 105 | data = pickle.load(f) 106 | feature, annotation = load_label_feature(data) 107 | #labels = [np.full(shape=len(data[key]), fill_value=key) 108 | # for key in data] 109 | #data = [features for key in data for features in data[key]] 110 | dataset['data_test'] = torch.FloatTensor(np.stack(feature, axis=0)) 111 | dataset['labels_test'] = torch.LongTensor(np.stack(annotation,axis=0)) 112 | return dataset 113 | 114 | def calculate_samples_per_class(annotation): 115 | class_index_max = torch.max(annotation)+1 116 | number_per_class = torch.zeros(class_index_max) 117 | for i in range(class_index_max): 118 | number_per_class[i] = (annotation == i).sum() 119 | return number_per_class 120 | def arrange_dataset(data, annotation): 121 | class_index_max = torch.max(annotation)+1 122 | arranged_dataset = [] 123 | for i in range(class_index_max): 124 | mask = annotation == i 125 | arranged_dataset.append(data[mask,:,:]) 126 | return arranged_dataset 127 | def rare_class_selection(data, sam_number): 128 | #print(data.size()) 129 | rare_class_threshold = 100 130 | annotation = torch.arange(nc) 131 | print(sam_number) 132 | rare_mask = (sam_number < rare_class_threshold).bool() 133 | rich_mask = sam_number >= rare_class_threshold 134 | rare_mask_list = [int(item) for item in annotation[rare_mask].tolist()] 135 | #print(rare_mask_list) 136 | #print(data) 137 | rare_data = [data[i].squeeze() for i in rare_mask_list] 138 | rare_classes = annotation[rare_mask] 139 | #rich_data = data[rich_mask, :,:] 140 | rich_classes = annotation[rich_mask] 141 | rich_mask_list = [int(item) for item in annotation[rich_mask].tolist()] 142 | #print(rich_mask_list) 143 | # print(data) 144 | rich_data = [data[i].squeeze() for i in rich_mask_list] 145 | return rare_data, rare_classes, rich_data, rich_classes 146 | def load_dataset_driveact(): 147 | train_path = _datasetFeaturesFiles['train'] 148 | train_aug_path = _datasetFeaturesFiles['train_aug'] 149 | 150 | val_path= _datasetFeaturesFiles['eval'] 151 | test_path = _datasetFeaturesFiles['test'] 152 | #load_train_dataset 153 | 154 | dataset = _load_pickle(train_path,train_aug_path, val_path, test_path) 155 | feature_train = dataset['data_train'] 156 | annotation_train = dataset['labels_train'] 157 | feature_val, annotation_val = dataset['data_eval'], dataset['labels_eval'] 158 | feature_test, annotation_test = dataset['data_test'], dataset['labels_test'] 159 | feature_train_aug, annotation_train_aug = dataset['data_train_aug'], dataset['labels_train_aug'] 160 | #print(feature_train-feature_train_aug) 161 | #sys.exit() 162 | sample_per_class_train = calculate_samples_per_class(annotation_train_aug) 163 | sample_per_class_train_aug = calculate_samples_per_class(annotation_train) 164 | arrange_dataset_train = arrange_dataset(feature_train, annotation_train) 165 | arrange_dataset_train_aug = arrange_dataset(feature_train_aug, annotation_train_aug) 166 | rare_data, rare_classes, rich_data, rich_classes = 
rare_class_selection(arrange_dataset_train, sample_per_class_train) 167 | rare_data_aug, rare_classes_aug, rich_data_aug, rich_classes_aug = rare_class_selection(arrange_dataset_train_aug, sample_per_class_train_aug) 168 | 169 | return rare_data, rare_classes, rich_data, rich_classes, feature_val,annotation_val,feature_test, annotation_test, sample_per_class_train,rare_data_aug, rare_classes_aug, rich_data_aug, rich_classes_aug,sample_per_class_train_aug 170 | 171 | 172 | # ========================================================= 173 | # Callable variables and functions from outside the module 174 | 175 | data = None 176 | labels = None 177 | dsName = None 178 | 179 | 180 | def loadDataSet(dsname): 181 | if dsname not in _datasetFeaturesFiles: 182 | raise NameError('Unknwown dataset: {}'.format(dsname)) 183 | 184 | global dsName, data, labels, _randStates, _rsCfg, _min_examples 185 | dsName = dsname 186 | _randStates = None 187 | _rsCfg = None 188 | 189 | # Loading data from files on computer 190 | # home = expanduser("~") 191 | dataset = _load_pickle(_datasetFeaturesFiles[dsname]) 192 | 193 | # Computing the number of items per class in the dataset 194 | _min_examples = dataset["labels"].shape[0] 195 | for i in range(dataset["labels"].shape[0]): 196 | if torch.where(dataset["labels"] == dataset["labels"][i])[0].shape[0] > 0: 197 | _min_examples = min(_min_examples, torch.where( 198 | dataset["labels"] == dataset["labels"][i])[0].shape[0]) 199 | print("Guaranteed number of items per class: {:d}\n".format(_min_examples)) 200 | 201 | # Generating data tensors 202 | data = torch.zeros((0, _min_examples, dataset["data"].shape[1])) 203 | labels = dataset["labels"].clone() 204 | while labels.shape[0] > 0: 205 | indices = torch.where(dataset["labels"] == labels[0])[0] 206 | data = torch.cat([data, dataset["data"][indices, :] 207 | [:_min_examples].view(1, _min_examples, -1)], dim=0) 208 | indices = torch.where(labels != labels[0])[0] 209 | labels = labels[indices] 210 | print("Total of {:d} classes, {:d} elements each, with dimension {:d}\n".format( 211 | data.shape[0], data.shape[1], data.shape[2])) 212 | 213 | 214 | def GenerateRun(iRun, cfg, regenRState=False, generate=True): 215 | global _randStates, data, _min_examples 216 | if not regenRState: 217 | np.random.set_state(_randStates[iRun]) 218 | 219 | classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]] 220 | shuffle_indices = np.arange(_min_examples) 221 | dataset = None 222 | if generate: 223 | dataset = torch.zeros( 224 | (cfg['ways'], cfg['shot']+cfg['queries'], data.shape[2])) 225 | for i in range(cfg['ways']): 226 | shuffle_indices = np.random.permutation(shuffle_indices) 227 | if generate: 228 | dataset[i] = data[classes[i], shuffle_indices, 229 | :][:cfg['shot']+cfg['queries']] 230 | 231 | return dataset 232 | 233 | 234 | def ClassesInRun(iRun, cfg): 235 | global _randStates, data 236 | np.random.set_state(_randStates[iRun]) 237 | classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]] 238 | return classes 239 | 240 | 241 | def setRandomStates(cfg): 242 | global _randStates, _maxRuns, _rsCfg 243 | if _rsCfg == cfg: 244 | return 245 | rsFile = os.path.join(_cacheDir, "RandStates_{}_s{}_q{}_w{}".format( 246 | dsName, cfg['shot'], cfg['queries'], cfg['ways'])) 247 | if not os.path.exists(rsFile): 248 | print("{} does not exist, regenerating it...".format(rsFile)) 249 | np.random.seed(0) 250 | _randStates = [] 251 | for iRun in range(_maxRuns): 252 | _randStates.append(np.random.get_state()) 253 | 
GenerateRun(iRun, cfg, regenRState=True, generate=False) 254 | torch.save(_randStates, rsFile) 255 | else: 256 | print("reloading random states from file....") 257 | _randStates = torch.load(rsFile) 258 | _rsCfg = cfg 259 | 260 | 261 | def GenerateRunSet(start=None, end=None, cfg=None): 262 | global dataset, _maxRuns 263 | if start is None: 264 | start = 0 265 | if end is None: 266 | end = _maxRuns 267 | if cfg is None: 268 | cfg = {"shot": 1, "ways": 5, "queries": 15} 269 | 270 | setRandomStates(cfg) 271 | print("generating task from {} to {}".format(start, end)) 272 | 273 | dataset = torch.zeros( 274 | (end-start, cfg['ways'], cfg['shot']+cfg['queries'], data.shape[2])) 275 | for iRun in range(end-start): 276 | dataset[iRun] = GenerateRun(start+iRun, cfg) 277 | 278 | return dataset 279 | 280 | 281 | # define a main code to test this module 282 | if __name__ == "__main__": 283 | 284 | print("Testing Task loader for Few Shot Learning") 285 | loadDataSet('miniimagenet') 286 | 287 | cfg = {"shot": 1, "ways": 5, "queries": 15} 288 | setRandomStates(cfg) 289 | 290 | run10 = GenerateRun(10, cfg) 291 | print("First call:", run10[:2, :2, :2]) 292 | 293 | run10 = GenerateRun(10, cfg) 294 | print("Second call:", run10[:2, :2, :2]) 295 | 296 | ds = GenerateRunSet(start=2, end=12, cfg=cfg) 297 | print("Third call:", ds[8, :2, :2, :2]) 298 | print(ds.size()) 299 | -------------------------------------------------------------------------------- /evaluate_DC.py: -------------------------------------------------------------------------------- 1 | import torch.optim.lr_scheduler 2 | import pickle 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch 6 | from sklearn.linear_model import LogisticRegression 7 | from tqdm import tqdm 8 | from FSLTask import load_dataset_driveact 9 | use_gpu = torch.cuda.is_available() 10 | import sklearn 11 | import sys 12 | from torch.utils.data import Dataset 13 | from torchvision.transforms import ToTensor 14 | from torchvision import datasets 15 | from torch.utils.data import DataLoader 16 | import random 17 | SEED = 42 18 | 19 | random.seed(SEED) 20 | torch.manual_seed(SEED) 21 | torch.cuda.manual_seed(SEED) 22 | torch.cuda.manual_seed_all(SEED) 23 | np.random.seed(SEED) 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = False 26 | from sklearn.metrics import confusion_matrix 27 | 28 | from einops import rearrange, repeat 29 | from einops.layers.torch import Rearrange 30 | global nc 31 | nc = 34 32 | # helpers 33 | 34 | def pair(t): 35 | return t if isinstance(t, tuple) else (t, t) 36 | 37 | # classes 38 | 39 | class PreNorm(nn.Module): 40 | def __init__(self, dim, fn): 41 | super().__init__() 42 | self.norm = nn.LayerNorm(dim) 43 | self.fn = fn 44 | def forward(self, x, **kwargs): 45 | return self.fn(self.norm(x), **kwargs) 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, hidden_dim, dropout = 0.): 49 | super().__init__() 50 | self.net = nn.Sequential( 51 | nn.Linear(dim, hidden_dim), 52 | nn.GELU(), 53 | nn.Dropout(dropout), 54 | nn.Linear(hidden_dim, dim), 55 | nn.Dropout(dropout) 56 | ) 57 | def forward(self, x): 58 | return self.net(x) 59 | 60 | class Attention(nn.Module): 61 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 62 | super().__init__() 63 | inner_dim = dim_head * heads 64 | project_out = not (heads == 1 and dim_head == dim) 65 | 66 | self.heads = heads 67 | self.scale = dim_head ** -0.5 68 | 69 | self.attend = nn.Softmax(dim = -1) 70 | self.to_qkv = nn.Linear(dim, 
inner_dim * 3, bias = False) 71 | 72 | self.to_out = nn.Sequential( 73 | nn.Linear(inner_dim, dim), 74 | nn.Dropout(dropout) 75 | ) if project_out else nn.Identity() 76 | 77 | def forward(self, x): 78 | qkv = self.to_qkv(x).chunk(3, dim = -1) 79 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 80 | 81 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 82 | 83 | attn = self.attend(dots) 84 | 85 | out = torch.matmul(attn, v) 86 | out = rearrange(out, 'b h n d -> b n (h d)') 87 | return self.to_out(out) 88 | 89 | 90 | def random_feature_interpolation(selected_mean, query, k, num): 91 | num_q = query.shape[0] 92 | #print(num_q) 93 | num_base = selected_mean.shape[0] 94 | origin = np.random.choice(num_q, num) 95 | target = np.random.choice(num_base, num) 96 | alpha = np.stack([np.random.rand(num)]*1024, axis=-1)*0.07 97 | #print(query[origin].shape,selected_mean[target].shape) 98 | generated_feature = query[origin,:] + alpha*selected_mean[target,:] 99 | #print(generated_feature.shape) 100 | #sys.exit() 101 | return generated_feature 102 | 103 | def distribution_calibration(query, base_means, base_cov, k,alpha, num): 104 | query = query.numpy() 105 | dist = [] 106 | k=1 107 | alpha=0.21 108 | #print(query.shape) 109 | for i in range(len(base_means)): 110 | dist.append(np.linalg.norm(query-base_means[i])) 111 | index = np.argpartition(dist, k)[:k] 112 | #mean_basics = np.array(base_means)[index] 113 | #print(mean_basics.shape) 114 | #sys.exit() 115 | selected_mean = np.array(base_means)[index] 116 | mean = np.concatenate([np.array(base_means)[index], np.squeeze(query[np.newaxis, :])]) 117 | #mean = np.squeeze(query[np.newaxis, :]) 118 | calibrated_mean = np.mean(mean, axis=0) 119 | calibrated_cov = np.mean(np.array(base_cov)[index], axis=0)+alpha 120 | samples = random_feature_interpolation(selected_mean,query,k, num) 121 | #print(calibrated_mean) 122 | #print(calibrated_cov) 123 | #feature interpolation based feature augmentation 124 | 125 | return calibrated_mean, calibrated_cov, samples 126 | 127 | class CustomImageDataset(Dataset): 128 | def __init__(self, feature, annotation, transform=None, target_transform=None): 129 | 130 | self.feature = feature #.view(-1,34) 131 | #print(self.feature.shape) 132 | #sys-exit() 133 | self.annotations = annotation 134 | def __len__(self): 135 | return len(self.annotations) 136 | def __getitem__(self, idx): 137 | return self.feature[idx], self.annotations[idx] 138 | 139 | class Net(nn.Module): 140 | def __init__(self): 141 | super(Net, self).__init__() 142 | self.fc1 = nn.Linear(1024, 256) 143 | self.attention = nn.Linear(256,128) 144 | self.attention2 = nn.Linear(128, 256) 145 | #self.pool = nn.MaxPool1d(128) 146 | self.relu = nn.ReLU() 147 | #self.bn = torch.nn.BatchNorm1d(128) 148 | self.dropout = torch.nn.Dropout(p=0.5, inplace=False) 149 | self.softmax = nn.Sigmoid() 150 | #self.conv2 = nn.Conv2d(6, 16, 5) 151 | # self.fc1 = nn.Linear(16 * 5 * 5, 120) 152 | # self.fc2 = nn.Linear(120, 84) 153 | self.fc2 = nn.Linear(256,nc) 154 | self.fc3=nn.Linear(1024,nc) 155 | 156 | def forward(self, x,y): 157 | y = y.float() 158 | x=self.fc1(x.float()) 159 | att = self.softmax(self.attention2(self.relu(self.attention(x)))) 160 | return self.fc2(self.relu(self.dropout(x+att*x))),self.fc3(y), 161 | 162 | def calculate_samples_per_class(annotation): 163 | class_index_max = nc 164 | #print(annotation) 165 | number_per_class = torch.zeros(int(class_index_max)) 166 | for i in range(class_index_max): 167 | number_per_class[i] = 
(annotation == i).sum() 168 | return number_per_class 169 | def calculate_weights(annotation, weights): 170 | class_index_max = nc 171 | #number_per_class = torch.zeros(class_index_max) 172 | sampler_weight = torch.zeros_like(annotation) 173 | for i in range(class_index_max): 174 | mask= annotation == i 175 | sampler_weight[mask] = weights[i] 176 | return sampler_weight 177 | def generate_train(rare_data, rare_classes, rich_data, rich_classes,sample_num_per_class): 178 | base_means = [] 179 | base_cov = [] 180 | for key in range(len(rich_data)): 181 | feature = np.array(rich_data[key]) 182 | print(feature.shape) 183 | mean = np.mean(feature, axis=0) 184 | cov = np.cov(feature.T) 185 | base_means.append(mean) 186 | base_cov.append(cov) 187 | 188 | # ---- classification for each task 189 | acc_list = [] 190 | print('Start classification for %d tasks...'%(n_runs)) 191 | 192 | #support_data = ndatas[i][:n_lsamples].numpy() 193 | #support_label = labels[i][:n_lsamples].numpy() 194 | #query_data = ndatas[i][n_lsamples:].numpy() 195 | #query_label = labels[i][n_lsamples:].numpy() 196 | # ---- Tukey's transform 197 | beta = 0.5 198 | #support_data = np.power(support_data[:, ] ,beta) 199 | #query_data = np.power(query_data[:, ] ,beta) 200 | # ---- cross distribution calibration for rare classes 201 | sampled_data = [] 202 | sampled_label = [] 203 | count = 0 204 | np.set_printoptions(threshold=sys.maxsize) 205 | #for i in range(len(rare_classes)): 206 | # print(rare_data[i].shape) 207 | for i in range(len(rare_classes)): 208 | print(sample_num_per_class[i]) 209 | 210 | #if sample_num_per_class[rare_classes[i]] == 0: 211 | # continue 212 | num_sampled = 1000 #(int(torch.max(sample_num_per_class) - sample_num_per_class[rare_classes[i]])) 213 | count += num_sampled 214 | #print(rare_data[i].shape) 215 | mean, cov, samples = distribution_calibration(rare_data[i], base_means, base_cov, 2, 0.21, 1000) 216 | #print(num_sampled) 217 | #print(samples.shape) 218 | #print(cov) 219 | sampled_data.append(samples) 220 | #sampled_data.append(np.random.multivariate_normal(list(mean), list(cov), num_sampled, 'warn')) 221 | #print(np.mean(sampled_data[i], axis=0)) 222 | #print(np.max(sampled_data[i], axis=0)) 223 | #val_data = feature_val[annotation_val==rare_classes[i]].numpy() 224 | #print(val_data) 225 | #print(np.mean(val_data, axis=0)-np.mean(sampled_data[i], axis=0)) 226 | sampled_label.extend([rare_classes[i]]*num_sampled) 227 | #sampled_label.extend([rare_classes[i]] * int(sample_num_per_class[rare_classes[i]])) 228 | #sys.exit() 229 | #sys.exit() 230 | sampled_data = np.concatenate(sampled_data).reshape(count, 1024) 231 | rare_data = np.concatenate(rare_data, axis=0) 232 | rare_label = []#torch.zeros(rare_data.shape[0]) 233 | for i in range(len(rare_classes)): 234 | rare_label.extend([rare_classes[i]] * int(sample_num_per_class[rare_classes[i]])) 235 | rare_label = np.array(rare_label) 236 | X_aug_1 = sampled_data #np.concatenate([rare_data, sampled_data]) 237 | Y_aug_1 = sampled_label #np.concatenate([rare_label,sampled_label]) 238 | #print(X_aug_1.shape) 239 | #print(Y_aug_1.shape) 240 | 241 | # ---- self distribution calibration for rich classes 242 | #num_sampled = int(750 / n_shot) 243 | sampled_data = [] 244 | sampled_label = [] 245 | count = 0 246 | 247 | for i in range(len(rich_classes)): 248 | num_sampled = 1000 #int(torch.max(sample_num_per_class) - sample_num_per_class[rich_classes[i]]) 249 | #if sample_num_per_class>500: 250 | # continue 251 | count += num_sampled 252 | #print(rich_classes[i]) 
253 | #mean, conv, samples = base_means[i], base_cov[i] 254 | #print(samples.shape) 255 | mean, cov, samples = distribution_calibration(rich_data[i], base_means, base_cov, 2, 0.21, 1000) 256 | #sampled_data.append(np.random.multivariate_normal(mean=mean, cov=cov, size=num_sampled)) 257 | sampled_data.append(samples) 258 | sampled_label.extend([rich_classes[i]] * num_sampled) 259 | #sampled_label.extend([rich_classes[i]] * int(sample_num_per_class[rich_classes[i]])) 260 | sampled_data = np.concatenate(sampled_data).reshape(count, 1024) 261 | rich_label = []#torch.zeros(rare_data.shape[0]) 262 | for i in range(len(rich_classes)): 263 | rich_label.extend([rich_classes[i]] * int(sample_num_per_class[rich_classes[i]])) 264 | rare_label = np.array(rich_label) 265 | rich_data = np.concatenate(rich_data, axis=0) 266 | X_aug_2 = sampled_data #rich_data #sampled_data #np.concatenate([rich_data, sampled_data]) 267 | Y_aug_2 = sampled_label #rich_label#sampled_label #np.concatenate([rich_label,sampled_label]) 268 | X_aug = np.concatenate([X_aug_1, X_aug_2]) 269 | Y_aug = np.concatenate([Y_aug_1, Y_aug_2]) 270 | #X_aug += np.random.normal(0, .1, X_aug.shape) 271 | return X_aug, Y_aug 272 | 273 | if __name__ == '__main__': 274 | # ---- data loading 275 | n_runs = 10000 276 | import FSLTask 277 | import torch.optim as optim 278 | 279 | 280 | 281 | rare_data, rare_classes, rich_data, rich_classes, feature_val, annotation_val, feature_test, annotation_test, sample_num_per_class,rare_data_aug, rare_classes_aug, rich_data_aug, rich_classes_aug,sample_num_per_class_aug = load_dataset_driveact() 282 | #acc = sklearn.metrics.top_k_accuracy_score(annotation_test, np.squeeze(feature_test), k=1) 283 | #print(acc) 284 | #noise_mean = np.zeros(34, 1024) 285 | 286 | #print(len(rare_data)+len(rich_data)) 287 | #print(len(annotation_test)) 288 | #sys.exit() 289 | #print(len(annotation_test)) 290 | #length= rare_classes.shape[0]+rich_data.shape[] 291 | #cfg = {'shot': n_shot, 'ways': n_ways, 'queries': n_queries} 292 | #FSLTask.loadDataSet(dataset) 293 | #FSLTask.setRandomStates(cfg) 294 | #rich_datas, rare_data = FSLTask.GenerateRunSet(end=n_runs, cfg=cfg) 295 | #ndatas = ndatas.permute(0, 2, 1, 3).reshape(n_runs, n_samples, -1) 296 | #labels = torch.arange(n_ways).view(1, 1, n_ways).expand(n_runs, n_shot + n_queries, 5).clone().view(n_runs, n_samples) 297 | # ---- Base class statistics 298 | 299 | #base_features_path = "./checkpoints/%s/WideResNet28_10_S2M2_R/last/base_features.plk"%dataset 300 | 301 | X_aug, Y_aug = generate_train(rare_data,rare_classes,rich_data,rich_classes,sample_num_per_class) 302 | X_aug2, Y_aug2 = generate_train(rare_data_aug,rare_classes_aug,rich_data_aug,rich_classes_aug,sample_num_per_class_aug) 303 | #print(torch.Tensor(Y_aug)) 304 | X_aug = np.concatenate([X_aug, X_aug2]) 305 | Y_aug = np.concatenate([Y_aug, Y_aug2]) 306 | sample_number = calculate_samples_per_class(torch.Tensor(Y_aug)) 307 | #print(sample_number) 308 | weights = 1/sample_number 309 | #print(weights) 310 | sampler_weight = calculate_weights(torch.Tensor(Y_aug), weights) 311 | 312 | # ---- train classifier 313 | #print(X_aug.shape) 314 | #if mode == 'train': 315 | # samples_weight = torch.from_numpy(np.array([weight[t] for t in dataset.gt_labels])) 316 | #print(sampler_weight) 317 | sampler = torch.utils.data.sampler.WeightedRandomSampler(sampler_weight,3000) 318 | 319 | 320 | dataset_train = CustomImageDataset(X_aug, Y_aug) 321 | #train_GAN(DataLoader(dataset_train, batch_size=256, sampler=sampler)) 322 | dataset_val = 
CustomImageDataset(np.squeeze(feature_val), annotation_val) 323 | dataset_test = CustomImageDataset(np.squeeze(feature_test), annotation_test) 324 | 325 | 326 | 327 | train_dataloader = DataLoader(dataset_train, batch_size=256, sampler=sampler) 328 | 329 | 330 | infer_dataloader = DataLoader(dataset_train, batch_size=256, shuffle=False) 331 | test_dataloader = DataLoader(dataset_test, batch_size=64, shuffle=False) 332 | val_dataloader = DataLoader(dataset_val, batch_size=64, shuffle=False) 333 | 334 | #model = LogisticRegression(max_iter=10000,verbose=10).fit(X=X_aug, y=Y_aug) 335 | model = Net() 336 | #resume = '/cvhci/temp/kpeng/driveact/models_swin_base/best_top1_acc_epoch_24.pth' 337 | #checkpoint = torch.load(resume) 338 | #print(checkpoint['state_dict']['cls_head.fc_cls.weight']) 339 | #print(checkpoint['state_dict']['cls_head.fc_cls.bias']) 340 | #model.fc1.weight.data = checkpoint['state_dict']['cls_head.fc_cls.weight'] 341 | #model.fc1.bias.data = checkpoint['state_dict']['cls_head.fc_cls.bias'] 342 | #sys.exit() 343 | model=model.cuda() 344 | criterion = nn.CrossEntropyLoss(reduce='None',reduction='mean') 345 | criterion2 = nn.CrossEntropyLoss(reduce=False,reduction='none') 346 | #optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9) 347 | #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) 348 | 349 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.000993, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False) 350 | #optimizer = torch.optim.SGD(model.parameters(), 0.001, 351 | # momentum=0.9, 352 | # weight_decay=0.01) 353 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=0, last_epoch=- 1, verbose=False) 354 | 355 | #print(np.min(Y_aug)) 356 | #sys.exit() 357 | #criterion2 = torch.nn.CosineEmbeddingLoss(margin=0.0, size_average=None, reduce=False,) 358 | for epoch in range(1500): 359 | hard_samples = [] 360 | model.train() 361 | for step, (data,label) in enumerate(train_dataloader): 362 | optimizer.zero_grad() 363 | predicts,y = model(data.cuda(), data.cuda()) 364 | 365 | loss = criterion(predicts, label.cuda()) #+ criterion2(predicts,y, torch.ones(y.size()[0]).cuda()) 366 | loss.backward() 367 | optimizer.step() 368 | scheduler.step() 369 | if (epoch > 50) and (epoch%30 == 0): 370 | model.eval() 371 | for index, (data, label) in enumerate(infer_dataloader): 372 | with torch.no_grad(): 373 | #label = torch.Tensor(label).cuda().double() 374 | predicts,y = model(data.cuda(), data.cuda()) 375 | #print(predicts.size()) 376 | difficulty = criterion2(predicts, label.cuda()) 377 | #print(difficulty) 378 | hard_samples.append(difficulty.data) 379 | #print(hard_samples) 380 | difficulty = torch.cat(hard_samples, dim=0) 381 | threshold = 1.2* torch.mean(difficulty) 382 | mask = (difficulty>threshold).cpu().numpy() 383 | hard_set = X_aug[mask] 384 | hard_label = Y_aug[mask] 385 | hard_dataset = CustomImageDataset(hard_set, hard_label) 386 | 387 | hard_train_dataloader = DataLoader(hard_dataset, batch_size=256, shuffle=True) 388 | for sub_epoch in range(1): 389 | model.train() 390 | for step, (data,label) in enumerate(hard_train_dataloader): 391 | optimizer.zero_grad() 392 | predicts,y = model(data.cuda(), data.cuda()) 393 | loss = 3*criterion(predicts, label.cuda()) #+ criterion2(predicts,y, torch.ones(y.size()[0]).cuda()) 394 | loss.backward() 395 | optimizer.step() 396 | #scheduler.step() 397 | print(epoch, 'hard_loss', loss) 398 | print(epoch, 'loss', loss) 399 | 400 | val_predict = [] 401 | 
model.eval() 402 | for step, (data,label) in enumerate(val_dataloader): 403 | with torch.no_grad(): 404 | #data = torch.nn.functional.normalize(data, dim=-1) 405 | 406 | predicts,y = model(data.cuda(), data.cuda()) 407 | val_predict.append(predicts.cpu()) 408 | val_predict = torch.cat(val_predict, dim=0).cpu().numpy() 409 | test_predict = [] 410 | #val_predict = np.argmax(val_predict, axis=-1) 411 | #print(predicts) 412 | acc = sklearn.metrics.top_k_accuracy_score(annotation_val, val_predict, k=1) 413 | f = open('/cvhci/data/activity/kpeng/ts_val_midlevel_predict_split0.pkl', 'wb') 414 | pickle.dump(val_predict,f) 415 | f.close() 416 | f = open('/cvhci/data/activity/kpeng/ts_val_midlevel_label_split0.pkl', 'wb') 417 | pickle.dump(annotation_val,f) 418 | f.close() 419 | print('two-stage calibration eval ACC : %f'%acc) 420 | #predicts = model.predict(np.squeeze(feature_test)) 421 | cm = confusion_matrix(annotation_val, np.argmax(val_predict, axis=-1)) 422 | f = open("/cvhci/data/activity/Drive&Act/kunyu/annotation_list.pkl", 'rb') 423 | annotation = [] 424 | class_index = pickle.load(f) 425 | # We will store the results in a dictionary for easy access later 426 | per_class_accuracies = {} 427 | # Calculate the accuracy for each one of our classes 428 | for idx, cls in enumerate(range(nc)): 429 | # True negatives are all the samples that are not our current GT class (not the current row) 430 | # and were not predicted as the current class (not the current column) 431 | true_negatives = np.sum(np.delete(np.delete(cm, idx, axis=0), idx, axis=1)) 432 | # True positives are all the samples of our current GT class that were predicted as such 433 | true_positives = cm[idx, idx] 434 | # The accuracy for the current class is ratio between correct predictions to all predictions 435 | per_class_accuracies[cls] = (true_positives + true_negatives) / np.sum(cm) 436 | #print(class_index[idx], 'val_accuracy', per_class_accuracies[cls]) 437 | model.eval() 438 | 439 | 440 | for step, (data,label) in enumerate(test_dataloader): 441 | with torch.no_grad(): 442 | data = torch.nn.functional.normalize(data, dim=-1) 443 | predicts,y = model(data.cuda(), data.cuda()) 444 | test_predict.append(predicts.cpu()) 445 | test_predict = torch.cat(test_predict, dim=0).cpu().numpy() 446 | #print(np.squeeze(feature_test).shape) 447 | #sys.exit() 448 | #predicts = model.predict(np.squeeze(feature_val)) 449 | #predicts = np.argmax(predicts, axis=-1) 450 | #test_predict = np.argmax(test_predict, axis=-1) 451 | acc = sklearn.metrics.top_k_accuracy_score(annotation_test, test_predict, k=1) 452 | f = open('/cvhci/data/activity/kpeng/ts_test_midlevel_predict_split0.pkl', 'wb') 453 | pickle.dump(test_predict,f) 454 | f.close() 455 | f = open('/cvhci/data/activity/kpeng/ts_test_midlevel_label_split0.pkl', 'wb') 456 | pickle.dump(annotation_test,f) 457 | f.close() 458 | print('two-stage calibration test ACC : %f' % acc) 459 | #for i in range(34): 460 | # mask = annotation_test == i 461 | # acc = sklearn.metrics.top_k_accuracy_score(torch.argmax(annotation_test[mask], dim=-1), test_predict[mask], k=1) 462 | # print('class', class_index[i], 'accuracy:', acc) 463 | #filename = 'finalized_model.sav' 464 | #pickle.dump(model, open(filename, 'wb')) 465 | # Get the confusion matrix 466 | cm = confusion_matrix(annotation_test, np.argmax(test_predict, axis=-1)) 467 | # We will store the results in a dictionary for easy access later 468 | per_class_accuracies = {} 469 | # Calculate the accuracy for each one of our classes 470 | for idx, cls in 
enumerate(range(nc)): 471 | # True negatives are all the samples that are not our current GT class (not the current row) 472 | # and were not predicted as the current class (not the current column) 473 | true_negatives = np.sum(np.delete(np.delete(cm, idx, axis=0), idx, axis=1)) 474 | # True positives are all the samples of our current GT class that were predicted as such 475 | true_positives = cm[idx, idx] 476 | # The accuracy for the current class is ratio between correct predictions to all predictions 477 | per_class_accuracies[cls] = (true_positives + true_negatives) / np.sum(cm) 478 | #print(class_index[idx], 'test_accuracy', per_class_accuracies[cls]) 479 | cm = confusion_matrix(annotation_test, np.argmax(feature_test, axis=-1)) 480 | # We will store the results in a dictionary for easy access later 481 | per_class_accuracies = {} 482 | # Calculate the accuracy for each one of our classes 483 | for idx, cls in enumerate(range(nc)): 484 | # True negatives are all the samples that are not our current GT class (not the current row) 485 | # and were not predicted as the current class (not the current column) 486 | true_negatives = np.sum(np.delete(np.delete(cm, idx, axis=0), idx, axis=1)) 487 | # True positives are all the samples of our current GT class that were predicted as such 488 | true_positives = cm[idx, idx] 489 | # The accuracy for the current class is ratio between correct predictions to all predictions 490 | per_class_accuracies[cls] = (true_positives + true_negatives) / np.sum(cm) 491 | #print(class_index[idx], 'test_accuracy', per_class_accuracies[cls]) 492 | 493 | 494 | 495 | --------------------------------------------------------------------------------
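
Note (editor's addition, not part of the repository): the core step of evaluate_DC.py is the distribution calibration performed in distribution_calibration(), which borrows the feature statistics of the nearest data-rich ("base") classes to augment a data-scarce class before training the classifier. Below is a minimal, self-contained NumPy sketch of that step, simplified to a single query feature vector; evaluate_DC.py applies the same idea to all features of a class at once and additionally mixes in random feature interpolation. Function and variable names such as calibrate, base_feats, and the toy dimensions are illustrative assumptions, not identifiers from the repo.

import numpy as np

def calibrate(query, base_means, base_cov, k=1, alpha=0.21):
    """Transfer statistics from the k nearest base classes to a query feature.

    query:      (D,) feature vector of a data-scarce class
    base_means: list of (D,) per-class mean vectors of data-rich classes
    base_cov:   list of (D, D) per-class covariance matrices
    """
    # Distance from the query to every base-class mean
    dist = [np.linalg.norm(query - m) for m in base_means]
    index = np.argpartition(dist, k)[:k]            # indices of the k closest base classes
    # Calibrated mean: average of the query and the selected base means
    stacked = np.concatenate([np.array(base_means)[index], query[np.newaxis, :]])
    calibrated_mean = stacked.mean(axis=0)
    # Calibrated covariance: average of the selected base covariances plus a spread term
    calibrated_cov = np.array(base_cov)[index].mean(axis=0) + alpha
    return calibrated_mean, calibrated_cov

# Usage on random toy data (dimensions chosen arbitrarily for the example)
rng = np.random.default_rng(0)
D, n_base = 16, 5
base_feats = [rng.normal(size=(100, D)) + i for i in range(n_base)]
base_means = [f.mean(axis=0) for f in base_feats]
base_cov = [np.cov(f.T) for f in base_feats]
q = rng.normal(size=D)
mu, cov = calibrate(q, base_means, base_cov, k=2)
samples = rng.multivariate_normal(mu, cov, size=10)  # synthetic features for the rare class
print(samples.shape)                                 # (10, 16)

The additive alpha on the covariance increases the dispersion of the calibrated distribution, so the features sampled for a rare class do not collapse onto the statistics of its nearest base classes.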