├── illustration.png
├── finalized_model.sav
├── data
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── datamgr.cpython-35.pyc
│   │   ├── datamgr.cpython-36.pyc
│   │   ├── dataset.cpython-35.pyc
│   │   ├── dataset.cpython-36.pyc
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── feature_loader.cpython-35.pyc
│   │   ├── additional_transforms.cpython-35.pyc
│   │   └── additional_transforms.cpython-36.pyc
│   ├── additional_transforms.py
│   ├── dataset.py
│   └── datamgr.py
├── __pycache__
│   └── FSLTask.cpython-37.pyc
├── requirements.txt
├── configs.py
├── filelists
│   ├── CUB
│   │   ├── download_CUB.sh
│   │   └── write_CUB_filelist.py
│   └── miniImagenet
│       ├── download_miniImagenet.sh
│       └── write_miniImagenet_filelist.py
├── .idea
│   ├── vcs.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── Few_Shot_Distribution_Calibration.iml
├── io_utils.py
├── save_plk.py
├── README.md
├── wrn_model.py
├── FSLTask.py
└── evaluate_DC.py
/illustration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/illustration.png
--------------------------------------------------------------------------------
/finalized_model.sav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/finalized_model.sav
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import datamgr
2 | from . import dataset
3 | from . import additional_transforms
4 |
--------------------------------------------------------------------------------
/__pycache__/FSLTask.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/__pycache__/FSLTask.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/datamgr.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/datamgr.cpython-35.pyc
--------------------------------------------------------------------------------
/data/__pycache__/datamgr.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/datamgr.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataset.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/dataset.cpython-35.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.17.2
2 | matplotlib==3.1.1
3 | tqdm==4.36.1
4 | torchvision==0.6.0
5 | torch==1.5.0
6 | Pillow==7.1.2
7 |
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/feature_loader.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/feature_loader.cpython-35.pyc
--------------------------------------------------------------------------------
/data/__pycache__/additional_transforms.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/additional_transforms.cpython-35.pyc
--------------------------------------------------------------------------------
/data/__pycache__/additional_transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KPeng9510/TransDARC/HEAD/data/__pycache__/additional_transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | save_dir = '.'
2 | data_dir = {}
3 | data_dir['CUB'] = './filelists/CUB/'
4 | data_dir['miniImagenet'] = './filelists/miniImagenet/'
5 |
--------------------------------------------------------------------------------
/filelists/CUB/download_CUB.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz
3 | tar -zxvf CUB_200_2011.tgz
4 | python write_CUB_filelist.py
5 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/filelists/miniImagenet/download_miniImagenet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget https://raw.githubusercontent.com/twitter/meta-learning-lstm/master/data/miniImagenet/train.csv
3 | wget https://raw.githubusercontent.com/twitter/meta-learning-lstm/master/data/miniImagenet/val.csv
4 | wget https://raw.githubusercontent.com/twitter/meta-learning-lstm/master/data/miniImagenet/test.csv
5 | wget http://image-net.org/image/ILSVRC2015/ILSVRC2015_CLS-LOC.tar.gz
6 | tar -zxvf ILSVRC2015_CLS-LOC.tar.gz
7 | python write_miniImagenet_filelist.py
8 |
--------------------------------------------------------------------------------
/.idea/Few_Shot_Distribution_Calibration.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/data/additional_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import torch
9 | from PIL import ImageEnhance
10 |
11 | transformtypedict=dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast, Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color)
12 |
13 |
14 |
15 | class ImageJitter(object):
16 | def __init__(self, transformdict):
17 | self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict]
18 |
19 |
20 | def __call__(self, img):
21 | out = img
22 | randtensor = torch.rand(len(self.transforms))
23 |
24 | for i, (transformer, alpha) in enumerate(self.transforms):
25 | r = alpha*(randtensor[i]*2.0 -1.0) + 1
26 | out = transformer(out).enhance(r).convert('RGB')
27 |
28 | return out
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/filelists/CUB/write_CUB_filelist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os import listdir
3 | from os.path import isfile, isdir, join
4 | import os
5 | import json
6 | import random
7 |
8 | cwd = os.getcwd()
9 | data_path = join(cwd,'CUB_200_2011/images')
10 | savedir = './'
11 | dataset_list = ['base','val','novel']
12 |
13 | #if not os.path.exists(savedir):
14 | # os.makedirs(savedir)
15 |
16 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))]
17 | folder_list.sort()
18 | label_dict = dict(zip(folder_list,range(0,len(folder_list))))
19 |
20 | classfile_list_all = []
21 |
22 | for i, folder in enumerate(folder_list):
23 | folder_path = join(data_path, folder)
24 | classfile_list_all.append( [ join(folder_path, cf) for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')])
25 | random.shuffle(classfile_list_all[i])
26 |
27 |
28 | for dataset in dataset_list:
29 | file_list = []
30 | label_list = []
31 | for i, classfile_list in enumerate(classfile_list_all):
32 | if 'base' in dataset:
33 | if (i%2 == 0):
34 | file_list = file_list + classfile_list
35 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
36 | if 'val' in dataset:
37 | if (i%4 == 1):
38 | file_list = file_list + classfile_list
39 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
40 | if 'novel' in dataset:
41 | if (i%4 == 3):
42 | file_list = file_list + classfile_list
43 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
44 |
45 | fo = open(savedir + dataset + ".json", "w")
46 | fo.write('{"label_names": [')
47 | fo.writelines(['"%s",' % item for item in folder_list])
48 | fo.seek(0, os.SEEK_END)
49 | fo.seek(fo.tell()-1, os.SEEK_SET)
50 | fo.write('],')
51 |
52 | fo.write('"image_names": [')
53 | fo.writelines(['"%s",' % item for item in file_list])
54 | fo.seek(0, os.SEEK_END)
55 | fo.seek(fo.tell()-1, os.SEEK_SET)
56 | fo.write('],')
57 |
58 | fo.write('"image_labels": [')
59 | fo.writelines(['%d,' % item for item in label_list])
60 | fo.seek(0, os.SEEK_END)
61 | fo.seek(fo.tell()-1, os.SEEK_SET)
62 | fo.write(']}')
63 |
64 | fo.close()
65 | print("%s -OK" %dataset)
66 |
--------------------------------------------------------------------------------
/filelists/miniImagenet/write_miniImagenet_filelist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os import listdir
3 | from os.path import isfile, isdir, join
4 | import os
5 | import json
6 | import random
7 | import re
8 |
9 | cwd = os.getcwd()
10 | data_path = join(cwd,'ILSVRC2015/Data/CLS-LOC/train')
11 | #data_path = join('/home/yuqing/phd/code/miniimagenet/images')
12 | savedir = './'
13 | dataset_list = ['base', 'val', 'novel']
14 |
15 | #if not os.path.exists(savedir):
16 | # os.makedirs(savedir)
17 |
18 | cl = -1
19 | folderlist = []
20 |
21 | datasetmap = {'base':'train','val':'val','novel':'test'};
22 | filelists = {'base':{},'val':{},'novel':{} }
23 | filelists_flat = {'base':[],'val':[],'novel':[] }
24 | labellists_flat = {'base':[],'val':[],'novel':[] }
25 |
26 | for dataset in dataset_list:
27 | with open(datasetmap[dataset] + ".csv", "r") as lines:
28 | for i, line in enumerate(lines):
29 | if i == 0:
30 | continue
31 | fid, _ , label = re.split(',|\.', line)
32 | label = label.replace('\n','')
33 | if not label in filelists[dataset]:
34 | folderlist.append(label)
35 | filelists[dataset][label] = []
36 | fnames = listdir( join(data_path, label) )
37 | fname_number = [ int(re.split('_|\.', fname)[1]) for fname in fnames]
38 | sorted_fnames = list(zip( *sorted( zip(fnames, fname_number), key = lambda f_tuple: f_tuple[1] )))[0]
39 |
40 | fid = int(fid[-5:])-1
41 | fname = join( data_path,label, sorted_fnames[fid] )
42 | filelists[dataset][label].append(fname)
43 |
44 | for key, filelist in filelists[dataset].items():
45 | cl += 1
46 | random.shuffle(filelist)
47 | filelists_flat[dataset] += filelist
48 | labellists_flat[dataset] += np.repeat(cl, len(filelist)).tolist()
49 |
50 | for dataset in dataset_list:
51 | fo = open(savedir + dataset + ".json", "w")
52 | fo.write('{"label_names": [')
53 | fo.writelines(['"%s",' % item for item in folderlist])
54 | fo.seek(0, os.SEEK_END)
55 | fo.seek(fo.tell()-1, os.SEEK_SET)
56 | fo.write('],')
57 |
58 | fo.write('"image_names": [')
59 | fo.writelines(['"%s",' % item for item in filelists_flat[dataset]])
60 | fo.seek(0, os.SEEK_END)
61 | fo.seek(fo.tell()-1, os.SEEK_SET)
62 | fo.write('],')
63 |
64 | fo.write('"image_labels": [')
65 | fo.writelines(['%d,' % item for item in labellists_flat[dataset]])
66 | fo.seek(0, os.SEEK_END)
67 | fo.seek(fo.tell()-1, os.SEEK_SET)
68 | fo.write(']}')
69 |
70 | fo.close()
71 | print("%s -OK" %dataset)
72 |
--------------------------------------------------------------------------------
/io_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | import argparse
5 |
11 |
12 | def parse_args(script):
13 | parser = argparse.ArgumentParser(description='few-shot script %s' % (script))
14 | parser.add_argument('--dataset', default='miniImagenet', help='CUB/miniImagenet')
15 | parser.add_argument('--model', default='WideResNet28_10', help='model: WideResNet28_10/ResNet{18}')
16 | parser.add_argument('--method', default='S2M2_R', help='rotation/S2M2_R')
17 | parser.add_argument('--train_aug', default='True',
18 | help='perform data augmentation or not during training ') # still required for save_features.py and test.py to find the model path correctly
19 |
20 | if script == 'train':
21 | parser.add_argument('--num_classes', default=200, type=int,
22 | help='total number of classes') # make it larger than the maximum label value in base class
23 | parser.add_argument('--save_freq', default=10, type=int, help='Save frequency')
24 | parser.add_argument('--start_epoch', default=0, type=int, help='Starting epoch')
25 | parser.add_argument('--stop_epoch', default=400, type=int,
26 | help='Stopping epoch') # for meta-learning methods, each epoch contains 100 episodes. The default epoch number is dataset dependent. See train.py
27 | parser.add_argument('--resume', action='store_true',
28 | help='continue from previous trained model with largest epoch')
29 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
30 | parser.add_argument('--batch_size', default=16, type=int, help='batch size ')
31 | parser.add_argument('--test_batch_size', default=2, type=int, help='batch size ')
32 | parser.add_argument('--alpha', default=2.0, type=float, help='for S2M2 training ')
33 | elif script == 'test':
34 | parser.add_argument('--num_classes', default=200, type=int, help='total number of classes')
35 |
36 | return parser.parse_args()
37 |
38 |
39 | def get_assigned_file(checkpoint_dir, num):
40 | assign_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(num))
41 | return assign_file
42 |
43 |
44 | def get_resume_file(checkpoint_dir):
45 | filelist = glob.glob(os.path.join(checkpoint_dir, '*.tar'))
46 | if len(filelist) == 0:
47 | return None
48 |
49 | filelist = [x for x in filelist if os.path.basename(x) != 'best.tar']
50 | epochs = np.array([int(os.path.splitext(os.path.basename(x))[0]) for x in filelist])
51 | max_epoch = np.max(epochs)
52 | resume_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(max_epoch))
53 | return resume_file
54 |
55 |
56 | def get_best_file(checkpoint_dir):
57 | best_file = os.path.join(checkpoint_dir, 'best.tar')
58 | if os.path.isfile(best_file):
59 | return best_file
60 | else:
61 | return get_resume_file(checkpoint_dir)
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from PIL import Image
5 | import json
6 | import numpy as np
7 | import torchvision.transforms as transforms
8 | import os
9 | identity = lambda x:x
10 | class SimpleDataset:
11 | def __init__(self, data_file, transform, target_transform=identity):
12 | with open(data_file, 'r') as f:
13 | self.meta = json.load(f)
14 | self.transform = transform
15 | self.target_transform = target_transform
16 |
17 |
18 | def __getitem__(self,i):
19 | image_path = os.path.join(self.meta['image_names'][i])
20 | img = Image.open(image_path).convert('RGB')
21 | img = self.transform(img)
22 | target = self.target_transform(self.meta['image_labels'][i])
23 | return img, target
24 |
25 | def __len__(self):
26 | return len(self.meta['image_names'])
27 |
28 |
29 | class SetDataset:
30 | def __init__(self, data_file, batch_size, transform):
31 | with open(data_file, 'r') as f:
32 | self.meta = json.load(f)
33 |
34 | self.cl_list = np.unique(self.meta['image_labels']).tolist()
35 |
36 | self.sub_meta = {}
37 | for cl in self.cl_list:
38 | self.sub_meta[cl] = []
39 |
40 | for x,y in zip(self.meta['image_names'],self.meta['image_labels']):
41 | self.sub_meta[y].append(x)
42 |
43 | self.sub_dataloader = []
44 | sub_data_loader_params = dict(batch_size = batch_size,
45 | shuffle = True,
46 | num_workers = 0, #use main thread only or may receive multiple batches
47 | pin_memory = False)
48 | for cl in self.cl_list:
49 | sub_dataset = SubDataset(self.sub_meta[cl], cl, transform = transform )
50 | self.sub_dataloader.append( torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params) )
51 |
52 | def __getitem__(self,i):
53 | return next(iter(self.sub_dataloader[i]))
54 |
55 | def __len__(self):
56 | return len(self.cl_list)
57 |
58 | class SubDataset:
59 | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity):
60 | self.sub_meta = sub_meta
61 | self.cl = cl
62 | self.transform = transform
63 | self.target_transform = target_transform
64 |
65 | def __getitem__(self,i):
66 | #print( '%d -%d' %(self.cl,i))
67 | image_path = os.path.join( self.sub_meta[i])
68 | img = Image.open(image_path).convert('RGB')
69 | img = self.transform(img)
70 | target = self.target_transform(self.cl)
71 | return img, target
72 |
73 | def __len__(self):
74 | return len(self.sub_meta)
75 |
76 | class EpisodicBatchSampler(object):
77 | def __init__(self, n_classes, n_way, n_episodes):
78 | self.n_classes = n_classes
79 | self.n_way = n_way
80 | self.n_episodes = n_episodes
81 |
82 | def __len__(self):
83 | return self.n_episodes
84 |
85 | def __iter__(self):
86 | for i in range(self.n_episodes):
87 | yield torch.randperm(self.n_classes)[:self.n_way]
88 |
--------------------------------------------------------------------------------
/data/datamgr.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
2 |
3 | import torch
4 | from PIL import Image
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | import data.additional_transforms as add_transforms
8 | from data.dataset import SimpleDataset, SetDataset, EpisodicBatchSampler
9 | from abc import abstractmethod
10 |
11 | class TransformLoader:
12 | def __init__(self, image_size,
13 | normalize_param = dict(mean= [0.485, 0.456, 0.406] , std=[0.229, 0.224, 0.225]),
14 | jitter_param = dict(Brightness=0.4, Contrast=0.4, Color=0.4)):
15 | self.image_size = image_size
16 | self.normalize_param = normalize_param
17 | self.jitter_param = jitter_param
18 |
19 | def parse_transform(self, transform_type):
20 | if transform_type=='ImageJitter':
21 | method = add_transforms.ImageJitter( self.jitter_param )
22 | return method
23 | method = getattr(transforms, transform_type)
24 | if transform_type=='RandomSizedCrop':
25 | return method(self.image_size)
26 | elif transform_type=='CenterCrop':
27 | return method(self.image_size)
28 | elif transform_type=='Scale':
29 | return method([int(self.image_size*1.15), int(self.image_size*1.15)])
30 | elif transform_type=='Normalize':
31 | return method(**self.normalize_param )
32 | else:
33 | return method()
34 |
35 | def get_composed_transform(self, aug = False):
36 | if aug:
37 | transform_list = ['RandomSizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize']
38 | else:
39 | transform_list = ['Scale','CenterCrop', 'ToTensor', 'Normalize']
40 |
41 | transform_funcs = [ self.parse_transform(x) for x in transform_list]
42 | transform = transforms.Compose(transform_funcs)
43 | return transform
44 |
45 | class DataManager:
46 | @abstractmethod
47 | def get_data_loader(self, data_file, aug):
48 | pass
49 |
50 |
51 | class SimpleDataManager(DataManager):
52 | def __init__(self, image_size, batch_size):
53 | super(SimpleDataManager, self).__init__()
54 | self.batch_size = batch_size
55 | self.trans_loader = TransformLoader(image_size)
56 |
57 | def get_data_loader(self, data_file, aug): #parameters that would change on train/val set
58 | transform = self.trans_loader.get_composed_transform(aug)
59 | dataset = SimpleDataset(data_file, transform)
60 | data_loader_params = dict(batch_size = self.batch_size, shuffle = True, num_workers = 12, pin_memory = True)
61 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
62 |
63 | return data_loader
64 |
65 | class SetDataManager(DataManager):
66 | def __init__(self, image_size, n_way, n_support, n_query, n_eposide =100):
67 | super(SetDataManager, self).__init__()
68 | self.image_size = image_size
69 | self.n_way = n_way
70 | self.batch_size = n_support + n_query
71 | self.n_eposide = n_eposide
72 |
73 | self.trans_loader = TransformLoader(image_size)
74 |
75 | def get_data_loader(self, data_file, aug): #parameters that would change on train/val set
76 | transform = self.trans_loader.get_composed_transform(aug)
77 | dataset = SetDataset( data_file , self.batch_size, transform )
78 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide )
79 | data_loader_params = dict(batch_sampler = sampler, num_workers = 12, pin_memory = True)
80 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
81 | return data_loader
82 |
83 |
84 |
--------------------------------------------------------------------------------
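Not part of the original repository: a minimal usage sketch of the two data managers defined in data/datamgr.py above, assuming a filelist JSON produced by the scripts under filelists/ exists at the given path.

```python
from data.datamgr import SimpleDataManager, SetDataManager

# Plain batched loader, as used by save_plk.py for feature extraction.
simple_mgr = SimpleDataManager(image_size=84, batch_size=256)
base_loader = simple_mgr.get_data_loader('./filelists/miniImagenet/base.json', aug=False)

# Episodic loader: each item is one class with n_support + n_query images.
set_mgr = SetDataManager(image_size=84, n_way=5, n_support=1, n_query=15)
episodic_loader = set_mgr.get_data_loader('./filelists/miniImagenet/base.json', aug=True)

for images, labels in base_loader:
    print(images.shape, labels.shape)  # e.g. torch.Size([256, 3, 84, 84]), torch.Size([256])
    break
```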
/save_plk.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import csv
5 | import os
6 | import collections
7 | import pickle
8 | import random
9 |
10 | import numpy as np
11 | import torch
12 | from torch.autograd import Variable
13 | import torch.backends.cudnn as cudnn
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | import torchvision.transforms as transforms
17 | import torchvision.datasets as datasets
18 | from io_utils import parse_args
19 | from data.datamgr import SimpleDataManager , SetDataManager
20 | import configs
21 |
22 | import wrn_model
23 |
24 | import torch.nn.functional as F
25 |
26 | from io_utils import parse_args, get_resume_file ,get_assigned_file
27 | from os import path
28 |
29 | use_gpu = torch.cuda.is_available()
30 |
31 | class WrappedModel(nn.Module):
32 | def __init__(self, module):
33 | super(WrappedModel, self).__init__()
34 | self.module = module
35 | def forward(self, x):
36 | return self.module(x)
37 |
38 | def save_pickle(file, data):
39 | with open(file, 'wb') as f:
40 | pickle.dump(data, f)
41 |
42 | def load_pickle(file):
43 | with open(file, 'rb') as f:
44 | return pickle.load(f)
45 |
46 | def extract_feature(val_loader, model, checkpoint_dir, tag='last',set='base'):
47 | save_dir = '{}/{}'.format(checkpoint_dir, tag)
48 | if os.path.isfile(save_dir + '/%s_features.plk'%set):
49 | data = load_pickle(save_dir + '/%s_features.plk'%set)
50 | return data
51 | else:
52 | if not os.path.isdir(save_dir):
53 | os.makedirs(save_dir)
54 |
55 | #model.eval()
56 | with torch.no_grad():
57 |
58 | output_dict = collections.defaultdict(list)
59 |
60 | for i, (inputs, labels) in enumerate(val_loader):
61 | # compute output
62 | inputs = inputs.cuda()
63 | labels = labels.cuda()
64 | outputs,_ = model(inputs)
65 | outputs = outputs.cpu().data.numpy()
66 |
67 | for out, label in zip(outputs, labels):
68 | output_dict[label.item()].append(out)
69 |
70 | all_info = output_dict
71 | save_pickle(save_dir + '/%s_features.plk'%set, all_info)
72 | return all_info
73 |
74 | if __name__ == '__main__':
75 | params = parse_args('test')
76 | params.model = 'WideResNet28_10'
77 | params.method = 'S2M2_R'
78 |
79 | loadfile_base = configs.data_dir[params.dataset] + 'base.json'
80 | loadfile_novel = configs.data_dir[params.dataset] + 'novel.json'
81 | if params.dataset == 'miniImagenet' or params.dataset == 'CUB':
82 | datamgr = SimpleDataManager(84, batch_size = 256)
83 | base_loader = datamgr.get_data_loader(loadfile_base, aug=False)
84 | novel_loader = datamgr.get_data_loader(loadfile_novel, aug = False)
85 |
86 | checkpoint_dir = '%s/checkpoints/%s/%s_%s' %(configs.save_dir, params.dataset, params.model, params.method)
87 | modelfile = get_resume_file(checkpoint_dir)
88 |
89 | if params.model == 'WideResNet28_10':
90 | model = wrn_model.wrn28_10(num_classes=params.num_classes)
91 |
92 |
93 | model = model.cuda()
94 | cudnn.benchmark = True
95 |
96 | checkpoint = torch.load(modelfile)
97 | state = checkpoint['state']
98 | state_keys = list(state.keys())
99 |
100 | callwrap = False
101 | if 'module' in state_keys[0]:
102 | callwrap = True
103 | if callwrap:
104 | model = WrappedModel(model)
105 | model_dict_load = model.state_dict()
106 | model_dict_load.update(state)
107 | model.load_state_dict(model_dict_load)
108 | model.eval()
109 | output_dict_base = extract_feature(base_loader, model, checkpoint_dir, tag='last', set='base')
110 | print("base set features saved!")
111 | output_dict_novel=extract_feature(novel_loader, model, checkpoint_dir, tag='last',set='novel')
112 | print("novel features saved!")
113 |
--------------------------------------------------------------------------------
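Not part of the original repository: a short sketch of how the feature pickle written by extract_feature above can be read back. The file is a dict mapping each class label to a list of per-image feature vectors (640-d for WideResNet28_10); the path below only follows the checkpoint_dir pattern used in the script and is illustrative.

```python
import pickle
import numpy as np

# Example path following save_plk.py's '<checkpoint_dir>/last/base_features.plk' pattern.
plk_path = './checkpoints/miniImagenet/WideResNet28_10_S2M2_R/last/base_features.plk'
with open(plk_path, 'rb') as f:
    features = pickle.load(f)  # {class_label: [feature_vector, ...]}

for label, vectors in features.items():
    print(label, np.stack(vectors).shape)  # (num_images_in_class, 640)
    break
```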
/README.md:
--------------------------------------------------------------------------------
1 | # TransDARC (IROS 2022, README in progress)
2 |
3 | ## News
4 | **TransDARC** [[**PDF**](https://arxiv.org/pdf/2203.00927.pdf)] is accepted to **IROS 2022** as an **Oral** presentation.
5 |
6 | ## Extract and save features
7 | Please first train the video feature extraction backbone (Video Swin Transformer) in the mmaction2 repo on the Drive&Act dataset.
8 |
9 | In our paper, we use the same Swin-Base configuration of the Video Swin Transformer as described in https://github.com/SwinTransformer/Video-Swin-Transformer.git, with ImageNet pretraining.
10 |
11 | The corresponding training configuration for swin_base_patch244_window877_kinetics400_22k.py is:
12 |
15 | train_pipeline = [
18 |     dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=2),
21 |     dict(type='Resize', scale=(224, 224), keep_ratio=False),
24 |     dict(type='FormatShape', input_format='NCTHW'),
27 |     dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
30 |     dict(type='ToTensor', keys=['imgs', 'label'])
33 | ]
34 |
36 | val_pipeline = [
39 |     dict(
42 |         type='SampleFrames',
45 |         clip_len=32,
48 |         frame_interval=2,
51 |         num_clips=2,
54 |         test_mode=True),
57 |     dict(type='FormatShape', input_format='NCTHW'),
60 |     dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
63 |     dict(type='ToTensor', keys=['imgs', 'label'])
66 | ]
67 |
69 | test_pipeline = [
72 |     dict(
75 |         type='SampleFrames',
78 |         clip_len=32,
81 |         frame_interval=2,
84 |         num_clips=2,
87 |         test_mode=True),
90 |     dict(type='FormatShape', input_format='NCTHW'),
93 |     dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
96 |     dict(type='ToTensor', keys=['imgs', 'label'])
99 | ]
100 |
101 | Please average the clip outputs at the top of the Video Swin Transformer, as sketched below.
102 |
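A minimal sketch of this clip averaging, assuming the backbone returns per-clip outputs stacked along a leading clip dimension (the tensor names and shapes here are illustrative, not taken from the mmaction2 code):

```python
import torch

# Illustrative only: `clip_logits` holds num_clips clip-level outputs for one video.
num_clips, feat_dim = 2, 768
clip_logits = torch.randn(num_clips, feat_dim)

# Average over the clip dimension to obtain a single video-level feature/logit vector.
video_logits = clip_logits.mean(dim=0)  # shape: (feat_dim,)
```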
103 | ## Evaluate our distribution calibration
104 |
105 | To evaluate our distribution calibration method, run:
106 |
107 | ```eval
108 | python evaluate_DC.py
109 | ```
110 | ## Verification of our code
111 |
112 | The logits of the Video Swin Base backbone are available at https://drive.google.com/drive/folders/1MJY8toH3PSV--pA2EvL8qjPIjcaTdmvr?usp=sharing; they cover fine-grained-level, split-0 driver activity recognition. (Note that the result reported in the paper is the mean over three splits, i.e., split 0, split 1, and split 2.)
113 |
114 | Please note that this evaluation unfortunately uses a different metric: unbalanced (sample-wise) mean Top-1 accuracy. Under balanced accuracy, our model still achieves 71.0 for fine-grained activity recognition on the test set; the sketch below illustrates the difference between the two metrics. If you want to compare with TransDARC, you can either evaluate in the same way as our paper, or contact me by email for further balanced-accuracy results on the other tasks. Thanks!
115 |
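A small sketch of the two metrics mentioned above, using dummy labels rather than our actual predictions (it assumes scikit-learn is installed, which evaluate_DC.py already imports):

```python
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# Dummy ground truth and predictions; replace with real model outputs.
y_true = np.array([0, 0, 0, 0, 1, 1, 2])
y_pred = np.array([0, 0, 0, 0, 1, 0, 2])

# Unbalanced mean Top-1 accuracy: fraction of correctly classified samples.
print(accuracy_score(y_true, y_pred))           # 6/7 ~= 0.857

# Balanced accuracy: mean of per-class recalls, so rare classes count equally.
print(balanced_accuracy_score(y_true, y_pred))  # (1.0 + 0.5 + 1.0) / 3 ~= 0.833
```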
116 | ## Please consider citing our paper if you find it useful. [[**PDF**](https://arxiv.org/pdf/2203.00927.pdf)]
117 |
118 | ```
119 | @article{peng2022transdarc,
120 | title={TransDARC: Transformer-based Driver Activity Recognition with Latent Space Feature Calibration},
121 | author={Peng, Kunyu and Roitberg, Alina and Yang, Kailun and Zhang, Jiaming and Stiefelhagen, Rainer},
122 | journal={2022 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
123 | year={2022}
124 | }
125 | ```
126 |
127 |
--------------------------------------------------------------------------------
/wrn_model.py:
--------------------------------------------------------------------------------
1 | ### dropout has been removed in this code. original code had dropout#####
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 |
8 | import sys, os
9 | import numpy as np
10 | import random
11 | act = torch.nn.ReLU()
12 |
13 |
14 | import math
15 | from torch.nn.utils.weight_norm import WeightNorm
16 |
17 |
18 |
19 | class BasicBlock(nn.Module):
20 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
21 | super(BasicBlock, self).__init__()
22 | self.bn1 = nn.BatchNorm2d(in_planes)
23 | self.relu1 = nn.ReLU(inplace=True)
24 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
25 | padding=1, bias=False)
26 | self.bn2 = nn.BatchNorm2d(out_planes)
27 | self.relu2 = nn.ReLU(inplace=True)
28 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
29 | padding=1, bias=False)
30 | self.droprate = dropRate
31 | self.equalInOut = (in_planes == out_planes)
32 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
33 | padding=0, bias=False) or None
34 | def forward(self, x):
35 | if not self.equalInOut:
36 | x = self.relu1(self.bn1(x))
37 | else:
38 | out = self.relu1(self.bn1(x))
39 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
40 | if self.droprate > 0:
41 | out = F.dropout(out, p=self.droprate, training=self.training)
42 | out = self.conv2(out)
43 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
44 |
45 | class distLinear(nn.Module):
46 | def __init__(self, indim, outdim):
47 | super(distLinear, self).__init__()
48 | self.L = nn.Linear( indim, outdim, bias = False)
49 | self.class_wise_learnable_norm = True #See the issue#4&8 in the github
50 | if self.class_wise_learnable_norm:
51 | WeightNorm.apply(self.L, 'weight', dim=0) #split the weight update component to direction and norm
52 |
53 | if outdim <=200:
54 | self.scale_factor = 2; #a fixed scale factor to scale the output of cos value into a reasonably large input for softmax
55 | else:
56 | self.scale_factor = 10; #in omniglot, a larger scale factor is required to handle >1000 output classes.
57 |
58 | def forward(self, x):
59 | x_norm = torch.norm(x, p=2, dim =1).unsqueeze(1).expand_as(x)
60 | x_normalized = x.div(x_norm+ 0.00001)
61 | if not self.class_wise_learnable_norm:
62 | L_norm = torch.norm(self.L.weight.data, p=2, dim =1).unsqueeze(1).expand_as(self.L.weight.data)
63 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001)
64 | cos_dist = self.L(x_normalized) #matrix product by forward function, but when using WeightNorm, this also multiply the cosine distance by a class-wise learnable norm, see the issue#4&8 in the github
65 | scores = self.scale_factor* (cos_dist)
66 |
67 | return scores
68 |
69 | class NetworkBlock(nn.Module):
70 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
71 | super(NetworkBlock, self).__init__()
72 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
73 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
74 | layers = []
75 | for i in range(int(nb_layers)):
76 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
77 | return nn.Sequential(*layers)
78 | def forward(self, x):
79 | return self.layer(x)
80 |
81 |
82 | def to_one_hot(inp,num_classes):
83 |
84 | y_onehot = torch.FloatTensor(inp.size(0), num_classes)
85 | if torch.cuda.is_available():
86 | y_onehot = y_onehot.cuda()
87 |
88 | y_onehot.zero_()
89 | x = inp.type(torch.LongTensor)
90 | if torch.cuda.is_available():
91 | x = x.cuda()
92 |
93 | x = torch.unsqueeze(x , 1)
94 | y_onehot.scatter_(1, x , 1)
95 |
96 | return Variable(y_onehot,requires_grad=False)
97 | # return y_onehot
98 |
99 |
100 | def mixup_data(x, y, lam):
101 |
102 | '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda'''
103 |
104 | batch_size = x.size()[0]
105 | index = torch.randperm(batch_size)
106 | if torch.cuda.is_available():
107 | index = index.cuda()
108 | mixed_x = lam * x + (1 - lam) * x[index,:]
109 | y_a, y_b = y, y[index]
110 |
111 | return mixed_x, y_a, y_b, lam
112 |
113 |
114 | class WideResNet(nn.Module):
115 | def __init__(self, depth=28, widen_factor=10, num_classes= 200 , loss_type = 'dist', per_img_std = False, stride = 1 ):
116 | dropRate = 0.5
117 | flatten = True
118 | super(WideResNet, self).__init__()
119 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
120 | assert((depth - 4) % 6 == 0)
121 | n = (depth - 4) / 6
122 | block = BasicBlock
123 | # 1st conv before any network block
124 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
125 | padding=1, bias=False)
126 | # 1st block
127 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, stride, dropRate)
128 | # 2nd block
129 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
130 | # 3rd block
131 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
132 | # global average pooling and linear
133 | self.bn1 = nn.BatchNorm2d(nChannels[3])
134 | self.relu = nn.ReLU(inplace=True)
135 | self.nChannels = nChannels[3]
136 |
137 | if loss_type == 'softmax':
138 | self.linear = nn.Linear(nChannels[3], int(num_classes))
139 | self.linear.bias.data.fill_(0)
140 | else:
141 | self.linear = distLinear(nChannels[3], int(num_classes))
142 |
143 | self.num_classes = num_classes
144 | if flatten:
145 | self.final_feat_dim = 640
146 | for m in self.modules():
147 | if isinstance(m, nn.Conv2d):
148 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
149 | m.weight.data.normal_(0, math.sqrt(2. / n))
150 | elif isinstance(m, nn.BatchNorm2d):
151 | m.weight.data.fill_(1)
152 | m.bias.data.zero_()
153 |
154 |
155 | def forward(self, x, target= None, mixup=False, mixup_hidden=True, mixup_alpha=None , lam = 0.4):
156 | if target is not None:
157 | if mixup_hidden:
158 | layer_mix = random.randint(0,3)
159 | elif mixup:
160 | layer_mix = 0
161 | else:
162 | layer_mix = None
163 |
164 | out = x
165 |
166 | target_a = target_b = target
167 |
168 | if layer_mix == 0:
169 | out, target_a , target_b , lam = mixup_data(out, target, lam=lam)
170 |
171 | out = self.conv1(out)
172 | out = self.block1(out)
173 |
174 |
175 | if layer_mix == 1:
176 | out, target_a , target_b , lam = mixup_data(out, target, lam=lam)
177 |
178 | out = self.block2(out)
179 |
180 | if layer_mix == 2:
181 | out, target_a , target_b , lam = mixup_data(out, target, lam=lam)
182 |
183 |
184 | out = self.block3(out)
185 | if layer_mix == 3:
186 | out, target_a , target_b , lam = mixup_data(out, target, lam=lam)
187 |
188 | out = self.relu(self.bn1(out))
189 | out = F.avg_pool2d(out, out.size()[2:])
190 | out = out.view(out.size(0), -1)
191 | out1 = self.linear(out)
192 |
193 | return out , out1 , target_a , target_b
194 | else:
195 | out = x
196 | out = self.conv1(out)
197 | out = self.block1(out)
198 | out = self.block2(out)
199 | out = self.block3(out)
200 | out = self.relu(self.bn1(out))
201 | out = F.avg_pool2d(out, out.size()[2:])
202 | out = out.view(out.size(0), -1)
203 | out1 = self.linear(out)
204 | return out, out1
205 |
206 |
207 |
208 | def wrn28_10(num_classes=200 , loss_type = 'dist'):
209 | model = WideResNet(depth=28, widen_factor=10, num_classes=num_classes, loss_type = loss_type , per_img_std = False, stride = 1 )
210 | return model
211 |
212 |
--------------------------------------------------------------------------------
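Not part of the original repository: a minimal sketch of using the model above as a feature extractor, matching the `outputs, _ = model(inputs)` call in save_plk.py (input size 84x84 as in the mini-ImageNet setting).

```python
import torch
import wrn_model

# WRN-28-10 with the default cosine-classifier ('dist') head.
model = wrn_model.wrn28_10(num_classes=200)
model.eval()

# Without a target, forward() returns (pooled 640-d features, classifier scores).
dummy = torch.randn(2, 3, 84, 84)
with torch.no_grad():
    features, scores = model(dummy)
print(features.shape, scores.shape)  # torch.Size([2, 640]) torch.Size([2, 200])
```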
/FSLTask.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | import torch
5 | import sys
6 | from pprint import pprint
7 | # from tqdm import tqdm
8 | nc = 34
9 | # ========================================================
10 | # Useful paths
11 | _datasetFeaturesFiles = {"train": "/cvhci/data/activity/kpeng/logits_split0_chunk90_swin_base_last_logits768_last22train.pkl",
12 | #"train": "/cvhci/data/activity/kpeng/logits_split1_chunk90_swin_base_last_logits768_last_augmentedtrain.pkl",
13 | #"train": "/cvhci/data/activity/kpeng/logits_split2_chunk90_swin_base_last_logits768_lasttrain.pkl",
14 | #"train": "/cvhci/data/activity/kpeng/task_level_logits_split0_chunk90_swin_base_last_logits768_lasttrain.pkl",
15 | #"train": "/cvhci/data/activity/kpeng/object_level_logits_split2_chunk90_swin_base_last_logits768_lasttrain.pkl",
16 | #"train": "/cvhci/data/activity/kpeng/location_level_logits_split1_chunk90_swin_base_last_logits768_lasttrain.pkl",
17 | "train_aug": "/cvhci/data/activity/kpeng/logits_split0_chunk90_swin_base_last_logits768_last22augmentedtrain",
18 | #"train_aug": "/cvhci/data/activity/kpeng/logits_split1_chunk90_swin_base_last_logits768_lastaugmentedrealtrain.pkl",
19 | #"train_aug": "/cvhci/data/activity/kpeng/logits_split2_chunk90_swin_base_last_logits768_last_augmentedtrain.pkl",
20 | #"train_aug": "/cvhci/data/activity/kpeng/task_level_logits_split0_chunk90_swin_base_last_logits768_lastaugmentedtrain.pkl",
21 | #"train_aug": "/cvhci/data/activity/kpeng/object_level_logits_split2_chunk90_swin_base_last_logits768_lastaugmentedtrain.pkl",
22 | #"train_aug": "/cvhci/data/activity/kpeng/location_level_logits_split1_chunk90_swin_base_last_logits768_lastaugmentedtrain.pkl",
23 | "eval": "/cvhci/data/activity/kpeng/logits_split0_chunk90_swin_base_last_logits768_last22val.pkl",
24 | #"eval": "/cvhci/data/activity/kpeng/logits_split2_chunk90_swin_base_last_logits768_lastval.pkl",
25 | #"eval": "/cvhci/data/activity/kpeng/logits_split1_chunk90_swin_base_last_logits768_lastaugmentedval.pkl",
26 | #"eval": "/cvhci/data/activity/kpeng/task_level_logits_split0_chunk90_swin_base_last_logits768_lastval.pkl",
27 | #"test": "/cvhci/data/activity/kpeng/logits_split0_chunk90_swin_base_last_logits768_last_ids2test.pkl",
28 | #"eval": "/cvhci/data/activity/kpeng/object_level_logits_split2_chunk90_swin_base_last_logits768_lastval.pkl",
29 | #"eval": "/cvhci/data/activity/kpeng/location_level_logits_split1_chunk90_swin_base_last_logits768_lastval.pkl",
30 | "test": "/cvhci/data/activity/kpeng/logits_split0_chunk90_swin_base_last_logits768_last22test.pkl"
31 | #"test": "/cvhci/data/activity/kpeng/logits_split1_chunk90_swin_base_last_logits768_lastaugmentedtest.pkl"
32 | #"test": "/cvhci/data/activity/kpeng/task_level_logits_split0_chunk90_swin_base_last_logits768_lasttest.pkl"
33 | #"test": "/cvhci/data/activity/kpeng/object_level_logits_split2_chunk90_swin_base_last_logits768_lasttest.pkl",
34 | #"test": "/cvhci/data/activity/kpeng/location_level_logits_split1_chunk90_swin_base_last_logits768_lasttest.pkl",
35 | }
36 | _cacheDir = "./cache"
37 | _maxRuns = 10000
38 | _min_examples = -1
39 |
40 | # ========================================================
41 | # Module internal functions and variables
42 |
43 | _randStates = None
44 | _rsCfg = None
45 |
46 |
47 |
48 | def load_label_feature(item):
49 | #f = open("/cvhci/data/activity/Drive&Act/kunyu/annotation_list.pkl", 'rb')
50 | f = open("/cvhci/data/activity/Drive&Act/kunyu/annotation_list.pkl", 'rb')
51 | annotation = []
52 | class_index = pickle.load(f)
53 | print(class_index)
54 | f.close()
55 | infos = item.keys()
56 | class_label = []
57 | for info in infos:
58 | info = ''.join([item[0] for item in list(info)])
59 | #print(info)
60 | activity = info.split(',')[-2]
61 | if activity not in class_label:
62 | class_label.append(activity)
63 | #print(activity)
64 | label = class_index.index(activity)
65 | annotation.append(label)
66 | features = item.values()
67 | #print(features)
68 | print(class_label)
69 | feature = [term for term in features]
70 | #print(feature)
71 | return feature, annotation
72 |
73 | def _load_pickle(file_train,file_train_aug,file_eval,file_test):
74 | dataset = dict()
75 | with open(file_train, 'rb') as f:
76 | data = pickle.load(f)
77 | feature, annotation = load_label_feature(data)
78 | #labels = [np.full(shape=len(data[key]), fill_value=key)
79 | # for key in data]
80 | #data = [features for key in data for features in data[key]]
81 | dataset['data_train'] = torch.FloatTensor(np.stack(feature, axis=0))
82 | #print(dataset['data_train'].size())
83 | dataset['labels_train'] = torch.LongTensor(np.stack(annotation, axis=0))
84 | with open(file_train_aug, 'rb') as f:
85 | data = pickle.load(f)
86 | feature, annotation = load_label_feature(data)
87 | # labels = [np.full(shape=len(data[key]), fill_value=key)
88 | # for key in data]
89 | # data = [features for key in data for features in data[key]]
90 | dataset['data_train_aug'] = torch.FloatTensor(np.stack(feature, axis=0))
91 | # print(dataset['data_train'].size())
92 | dataset['labels_train_aug'] = torch.LongTensor(np.stack(annotation, axis=0))
93 | with open(file_eval, 'rb') as f:
94 | data = pickle.load(f)
95 | feature, annotation = load_label_feature(data)
96 | #labels = [np.full(shape=len(data[key]), fill_value=key)
97 | # for key in data]
98 | #data = [features for key in data for features in data[key]]
99 |
100 | dataset['data_eval'] = torch.FloatTensor(np.stack(feature, axis=0))
101 | dataset['labels_eval'] = torch.LongTensor(np.stack(annotation, axis=0))
102 | #print(dataset['labels_eval'])
103 | #sys.exit()
104 | with open(file_test, 'rb') as f:
105 | data = pickle.load(f)
106 | feature, annotation = load_label_feature(data)
107 | #labels = [np.full(shape=len(data[key]), fill_value=key)
108 | # for key in data]
109 | #data = [features for key in data for features in data[key]]
110 | dataset['data_test'] = torch.FloatTensor(np.stack(feature, axis=0))
111 | dataset['labels_test'] = torch.LongTensor(np.stack(annotation,axis=0))
112 | return dataset
113 |
114 | def calculate_samples_per_class(annotation):
115 | class_index_max = torch.max(annotation)+1
116 | number_per_class = torch.zeros(class_index_max)
117 | for i in range(class_index_max):
118 | number_per_class[i] = (annotation == i).sum()
119 | return number_per_class
120 | def arrange_dataset(data, annotation):
121 | class_index_max = torch.max(annotation)+1
122 | arranged_dataset = []
123 | for i in range(class_index_max):
124 | mask = annotation == i
125 | arranged_dataset.append(data[mask,:,:])
126 | return arranged_dataset
127 | def rare_class_selection(data, sam_number):
128 | #print(data.size())
129 | rare_class_threshold = 100
130 | annotation = torch.arange(nc)
131 | print(sam_number)
132 | rare_mask = (sam_number < rare_class_threshold).bool()
133 | rich_mask = sam_number >= rare_class_threshold
134 | rare_mask_list = [int(item) for item in annotation[rare_mask].tolist()]
135 | #print(rare_mask_list)
136 | #print(data)
137 | rare_data = [data[i].squeeze() for i in rare_mask_list]
138 | rare_classes = annotation[rare_mask]
139 | #rich_data = data[rich_mask, :,:]
140 | rich_classes = annotation[rich_mask]
141 | rich_mask_list = [int(item) for item in annotation[rich_mask].tolist()]
142 | #print(rich_mask_list)
143 | # print(data)
144 | rich_data = [data[i].squeeze() for i in rich_mask_list]
145 | return rare_data, rare_classes, rich_data, rich_classes
146 | def load_dataset_driveact():
147 | train_path = _datasetFeaturesFiles['train']
148 | train_aug_path = _datasetFeaturesFiles['train_aug']
149 |
150 | val_path= _datasetFeaturesFiles['eval']
151 | test_path = _datasetFeaturesFiles['test']
152 | #load_train_dataset
153 |
154 | dataset = _load_pickle(train_path,train_aug_path, val_path, test_path)
155 | feature_train = dataset['data_train']
156 | annotation_train = dataset['labels_train']
157 | feature_val, annotation_val = dataset['data_eval'], dataset['labels_eval']
158 | feature_test, annotation_test = dataset['data_test'], dataset['labels_test']
159 | feature_train_aug, annotation_train_aug = dataset['data_train_aug'], dataset['labels_train_aug']
160 | #print(feature_train-feature_train_aug)
161 | #sys.exit()
162 | sample_per_class_train = calculate_samples_per_class(annotation_train_aug)
163 | sample_per_class_train_aug = calculate_samples_per_class(annotation_train)
164 | arrange_dataset_train = arrange_dataset(feature_train, annotation_train)
165 | arrange_dataset_train_aug = arrange_dataset(feature_train_aug, annotation_train_aug)
166 | rare_data, rare_classes, rich_data, rich_classes = rare_class_selection(arrange_dataset_train, sample_per_class_train)
167 | rare_data_aug, rare_classes_aug, rich_data_aug, rich_classes_aug = rare_class_selection(arrange_dataset_train_aug, sample_per_class_train_aug)
168 |
169 | return rare_data, rare_classes, rich_data, rich_classes, feature_val,annotation_val,feature_test, annotation_test, sample_per_class_train,rare_data_aug, rare_classes_aug, rich_data_aug, rich_classes_aug,sample_per_class_train_aug
170 |
171 |
172 | # =========================================================
173 | # Callable variables and functions from outside the module
174 |
175 | data = None
176 | labels = None
177 | dsName = None
178 |
179 |
180 | def loadDataSet(dsname):
181 | if dsname not in _datasetFeaturesFiles:
182 | raise NameError('Unknown dataset: {}'.format(dsname))
183 |
184 | global dsName, data, labels, _randStates, _rsCfg, _min_examples
185 | dsName = dsname
186 | _randStates = None
187 | _rsCfg = None
188 |
189 | # Loading data from files on computer
190 | # home = expanduser("~")
191 | dataset = _load_pickle(_datasetFeaturesFiles[dsname])
192 |
193 | # Computing the number of items per class in the dataset
194 | _min_examples = dataset["labels"].shape[0]
195 | for i in range(dataset["labels"].shape[0]):
196 | if torch.where(dataset["labels"] == dataset["labels"][i])[0].shape[0] > 0:
197 | _min_examples = min(_min_examples, torch.where(
198 | dataset["labels"] == dataset["labels"][i])[0].shape[0])
199 | print("Guaranteed number of items per class: {:d}\n".format(_min_examples))
200 |
201 | # Generating data tensors
202 | data = torch.zeros((0, _min_examples, dataset["data"].shape[1]))
203 | labels = dataset["labels"].clone()
204 | while labels.shape[0] > 0:
205 | indices = torch.where(dataset["labels"] == labels[0])[0]
206 | data = torch.cat([data, dataset["data"][indices, :]
207 | [:_min_examples].view(1, _min_examples, -1)], dim=0)
208 | indices = torch.where(labels != labels[0])[0]
209 | labels = labels[indices]
210 | print("Total of {:d} classes, {:d} elements each, with dimension {:d}\n".format(
211 | data.shape[0], data.shape[1], data.shape[2]))
212 |
213 |
214 | def GenerateRun(iRun, cfg, regenRState=False, generate=True):
215 | global _randStates, data, _min_examples
216 | if not regenRState:
217 | np.random.set_state(_randStates[iRun])
218 |
219 | classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]]
220 | shuffle_indices = np.arange(_min_examples)
221 | dataset = None
222 | if generate:
223 | dataset = torch.zeros(
224 | (cfg['ways'], cfg['shot']+cfg['queries'], data.shape[2]))
225 | for i in range(cfg['ways']):
226 | shuffle_indices = np.random.permutation(shuffle_indices)
227 | if generate:
228 | dataset[i] = data[classes[i], shuffle_indices,
229 | :][:cfg['shot']+cfg['queries']]
230 |
231 | return dataset
232 |
233 |
234 | def ClassesInRun(iRun, cfg):
235 | global _randStates, data
236 | np.random.set_state(_randStates[iRun])
237 | classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]]
238 | return classes
239 |
240 |
241 | def setRandomStates(cfg):
242 | global _randStates, _maxRuns, _rsCfg
243 | if _rsCfg == cfg:
244 | return
245 | rsFile = os.path.join(_cacheDir, "RandStates_{}_s{}_q{}_w{}".format(
246 | dsName, cfg['shot'], cfg['queries'], cfg['ways']))
247 | if not os.path.exists(rsFile):
248 | print("{} does not exist, regenerating it...".format(rsFile))
249 | np.random.seed(0)
250 | _randStates = []
251 | for iRun in range(_maxRuns):
252 | _randStates.append(np.random.get_state())
253 | GenerateRun(iRun, cfg, regenRState=True, generate=False)
254 | torch.save(_randStates, rsFile)
255 | else:
256 | print("reloading random states from file....")
257 | _randStates = torch.load(rsFile)
258 | _rsCfg = cfg
259 |
260 |
261 | def GenerateRunSet(start=None, end=None, cfg=None):
262 | global dataset, _maxRuns
263 | if start is None:
264 | start = 0
265 | if end is None:
266 | end = _maxRuns
267 | if cfg is None:
268 | cfg = {"shot": 1, "ways": 5, "queries": 15}
269 |
270 | setRandomStates(cfg)
271 | print("generating task from {} to {}".format(start, end))
272 |
273 | dataset = torch.zeros(
274 | (end-start, cfg['ways'], cfg['shot']+cfg['queries'], data.shape[2]))
275 | for iRun in range(end-start):
276 | dataset[iRun] = GenerateRun(start+iRun, cfg)
277 |
278 | return dataset
279 |
280 |
281 | # define a main code to test this module
282 | if __name__ == "__main__":
283 |
284 | print("Testing Task loader for Few Shot Learning")
285 | loadDataSet('miniimagenet')
286 |
287 | cfg = {"shot": 1, "ways": 5, "queries": 15}
288 | setRandomStates(cfg)
289 |
290 | run10 = GenerateRun(10, cfg)
291 | print("First call:", run10[:2, :2, :2])
292 |
293 | run10 = GenerateRun(10, cfg)
294 | print("Second call:", run10[:2, :2, :2])
295 |
296 | ds = GenerateRunSet(start=2, end=12, cfg=cfg)
297 | print("Third call:", ds[8, :2, :2, :2])
298 | print(ds.size())
299 |
--------------------------------------------------------------------------------
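Not part of the original repository: a sketch of how load_dataset_driveact above is consumed (evaluate_DC.py imports it); it only runs if the pickle paths in _datasetFeaturesFiles exist.

```python
from FSLTask import load_dataset_driveact

# The loader returns per-class training feature lists split into rare (< 100 samples)
# and rich classes, plus validation/test features with their labels, for both the
# plain and the augmented training sets.
(rare_data, rare_classes, rich_data, rich_classes,
 feature_val, annotation_val, feature_test, annotation_test,
 sample_per_class_train,
 rare_data_aug, rare_classes_aug, rich_data_aug, rich_classes_aug,
 sample_per_class_train_aug) = load_dataset_driveact()

print(len(rich_data), len(rare_data))  # classes above / below the rare-class threshold
```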
/evaluate_DC.py:
--------------------------------------------------------------------------------
1 | import torch.optim.lr_scheduler
2 | import pickle
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch
6 | from sklearn.linear_model import LogisticRegression
7 | from tqdm import tqdm
8 | from FSLTask import load_dataset_driveact
9 | use_gpu = torch.cuda.is_available()
10 | import sklearn
11 | import sys
12 | from torch.utils.data import Dataset
13 | from torchvision.transforms import ToTensor
14 | from torchvision import datasets
15 | from torch.utils.data import DataLoader
16 | import random
17 | SEED = 42
18 |
19 | random.seed(SEED)
20 | torch.manual_seed(SEED)
21 | torch.cuda.manual_seed(SEED)
22 | torch.cuda.manual_seed_all(SEED)
23 | np.random.seed(SEED)
24 | torch.backends.cudnn.deterministic = True
25 | torch.backends.cudnn.benchmark = False
26 | from sklearn.metrics import confusion_matrix
27 |
28 | from einops import rearrange, repeat
29 | from einops.layers.torch import Rearrange
30 | global nc
31 | nc = 34
32 | # helpers
33 |
34 | def pair(t):
35 | return t if isinstance(t, tuple) else (t, t)
36 |
37 | # classes
38 |
39 | class PreNorm(nn.Module):
40 | def __init__(self, dim, fn):
41 | super().__init__()
42 | self.norm = nn.LayerNorm(dim)
43 | self.fn = fn
44 | def forward(self, x, **kwargs):
45 | return self.fn(self.norm(x), **kwargs)
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, hidden_dim, dropout = 0.):
49 | super().__init__()
50 | self.net = nn.Sequential(
51 | nn.Linear(dim, hidden_dim),
52 | nn.GELU(),
53 | nn.Dropout(dropout),
54 | nn.Linear(hidden_dim, dim),
55 | nn.Dropout(dropout)
56 | )
57 | def forward(self, x):
58 | return self.net(x)
59 |
60 | class Attention(nn.Module):
61 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
62 | super().__init__()
63 | inner_dim = dim_head * heads
64 | project_out = not (heads == 1 and dim_head == dim)
65 |
66 | self.heads = heads
67 | self.scale = dim_head ** -0.5
68 |
69 | self.attend = nn.Softmax(dim = -1)
70 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
71 |
72 | self.to_out = nn.Sequential(
73 | nn.Linear(inner_dim, dim),
74 | nn.Dropout(dropout)
75 | ) if project_out else nn.Identity()
76 |
77 | def forward(self, x):
78 | qkv = self.to_qkv(x).chunk(3, dim = -1)
79 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
80 |
81 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
82 |
83 | attn = self.attend(dots)
84 |
85 | out = torch.matmul(attn, v)
86 | out = rearrange(out, 'b h n d -> b n (h d)')
87 | return self.to_out(out)
88 |
89 |
90 | def random_feature_interpolation(selected_mean, query, k, num):
91 | num_q = query.shape[0]
92 | #print(num_q)
93 | num_base = selected_mean.shape[0]
94 | origin = np.random.choice(num_q, num)
95 | target = np.random.choice(num_base, num)
96 | alpha = np.stack([np.random.rand(num)]*1024, axis=-1)*0.07
97 | #print(query[origin].shape,selected_mean[target].shape)
98 | generated_feature = query[origin,:] + alpha*selected_mean[target,:]
99 | #print(generated_feature.shape)
100 | #sys.exit()
101 | return generated_feature
102 |
103 | def distribution_calibration(query, base_means, base_cov, k,alpha, num):
104 | query = query.numpy()
105 | dist = []
106 | k=1
107 | alpha=0.21
108 | #print(query.shape)
109 | for i in range(len(base_means)):
110 | dist.append(np.linalg.norm(query-base_means[i]))
111 | index = np.argpartition(dist, k)[:k]
112 | #mean_basics = np.array(base_means)[index]
113 | #print(mean_basics.shape)
114 | #sys.exit()
115 | selected_mean = np.array(base_means)[index]
116 | mean = np.concatenate([np.array(base_means)[index], np.squeeze(query[np.newaxis, :])])
117 | #mean = np.squeeze(query[np.newaxis, :])
118 | calibrated_mean = np.mean(mean, axis=0)
119 | calibrated_cov = np.mean(np.array(base_cov)[index], axis=0)+alpha
120 | samples = random_feature_interpolation(selected_mean,query,k, num)
121 | #print(calibrated_mean)
122 | #print(calibrated_cov)
123 | #feature interpolation based feature augmentation
124 |
125 | return calibrated_mean, calibrated_cov, samples
126 |
127 | class CustomImageDataset(Dataset):
128 | def __init__(self, feature, annotation, transform=None, target_transform=None):
129 |
130 | self.feature = feature #.view(-1,34)
131 | #print(self.feature.shape)
132 | #sys-exit()
133 | self.annotations = annotation
134 | def __len__(self):
135 | return len(self.annotations)
136 | def __getitem__(self, idx):
137 | return self.feature[idx], self.annotations[idx]
138 |
139 | class Net(nn.Module):
140 | def __init__(self):
141 | super(Net, self).__init__()
142 | self.fc1 = nn.Linear(1024, 256)
143 | self.attention = nn.Linear(256,128)
144 | self.attention2 = nn.Linear(128, 256)
145 | #self.pool = nn.MaxPool1d(128)
146 | self.relu = nn.ReLU()
147 | #self.bn = torch.nn.BatchNorm1d(128)
148 | self.dropout = torch.nn.Dropout(p=0.5, inplace=False)
149 |         self.softmax = nn.Sigmoid()   # note: a sigmoid gate for the attention branch, despite the name
150 | #self.conv2 = nn.Conv2d(6, 16, 5)
151 | # self.fc1 = nn.Linear(16 * 5 * 5, 120)
152 | # self.fc2 = nn.Linear(120, 84)
153 | self.fc2 = nn.Linear(256,nc)
154 | self.fc3=nn.Linear(1024,nc)
155 |
156 |     def forward(self, x, y):
157 |         y = y.float()
158 |         x = self.fc1(x.float())
159 |         att = self.softmax(self.attention2(self.relu(self.attention(x))))
160 |         return self.fc2(self.relu(self.dropout(x + att * x))), self.fc3(y)
161 |
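# Hypothetical shape check for the classifier head (not called during training): Net maps
# 1024-d features to `nc` logits through a sigmoid-gated residual attention branch, and a
# second linear head scores the raw 1024-d input directly.
def _demo_net_forward():
    net = Net()
    x = torch.randn(4, 1024)
    logits, aux = net(x, x)               # both arguments receive the same features here
    assert logits.shape == (4, nc)        # attention-gated classifier output
    assert aux.shape == (4, nc)           # auxiliary head on the raw features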
162 | def calculate_samples_per_class(annotation):
163 | class_index_max = nc
164 | #print(annotation)
165 | number_per_class = torch.zeros(int(class_index_max))
166 | for i in range(class_index_max):
167 | number_per_class[i] = (annotation == i).sum()
168 | return number_per_class
169 | def calculate_weights(annotation, weights):
170 | class_index_max = nc
171 | #number_per_class = torch.zeros(class_index_max)
172 | sampler_weight = torch.zeros_like(annotation)
173 | for i in range(class_index_max):
174 | mask= annotation == i
175 | sampler_weight[mask] = weights[i]
176 | return sampler_weight
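# Sketch of how the two helpers above are combined (mirrors the __main__ block below):
# inverse-frequency class weights turned into per-sample weights make the
# WeightedRandomSampler draw each class roughly equally often. The labels here are
# made up for illustration; the helper is never called by the pipeline.
def _demo_inverse_frequency_weights():
    labels = torch.Tensor([0., 0., 0., 1.])                 # class 0 is three times as frequent
    counts = calculate_samples_per_class(labels)            # length-nc tensor: [3., 1., 0., ...]
    weights = 1 / counts                                     # unseen classes get inf but are never indexed
    per_sample = calculate_weights(labels, weights)          # -> [1/3, 1/3, 1/3, 1.0]
    sampler = torch.utils.data.sampler.WeightedRandomSampler(per_sample, num_samples=4)
    return sampler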
177 | def generate_train(rare_data, rare_classes, rich_data, rich_classes,sample_num_per_class):
178 | base_means = []
179 | base_cov = []
180 | for key in range(len(rich_data)):
181 | feature = np.array(rich_data[key])
182 | print(feature.shape)
183 | mean = np.mean(feature, axis=0)
184 | cov = np.cov(feature.T)
185 | base_means.append(mean)
186 | base_cov.append(cov)
187 |
188 | # ---- classification for each task
189 | acc_list = []
190 | print('Start classification for %d tasks...'%(n_runs))
191 |
192 | #support_data = ndatas[i][:n_lsamples].numpy()
193 | #support_label = labels[i][:n_lsamples].numpy()
194 | #query_data = ndatas[i][n_lsamples:].numpy()
195 | #query_label = labels[i][n_lsamples:].numpy()
196 | # ---- Tukey's transform
197 | beta = 0.5
198 | #support_data = np.power(support_data[:, ] ,beta)
199 | #query_data = np.power(query_data[:, ] ,beta)
200 | # ---- cross distribution calibration for rare classes
201 | sampled_data = []
202 | sampled_label = []
203 | count = 0
204 | np.set_printoptions(threshold=sys.maxsize)
205 | #for i in range(len(rare_classes)):
206 | # print(rare_data[i].shape)
207 | for i in range(len(rare_classes)):
208 | print(sample_num_per_class[i])
209 |
210 | #if sample_num_per_class[rare_classes[i]] == 0:
211 | # continue
212 | num_sampled = 1000 #(int(torch.max(sample_num_per_class) - sample_num_per_class[rare_classes[i]]))
213 | count += num_sampled
214 | #print(rare_data[i].shape)
215 | mean, cov, samples = distribution_calibration(rare_data[i], base_means, base_cov, 2, 0.21, 1000)
216 | #print(num_sampled)
217 | #print(samples.shape)
218 | #print(cov)
219 | sampled_data.append(samples)
220 | #sampled_data.append(np.random.multivariate_normal(list(mean), list(cov), num_sampled, 'warn'))
221 | #print(np.mean(sampled_data[i], axis=0))
222 | #print(np.max(sampled_data[i], axis=0))
223 | #val_data = feature_val[annotation_val==rare_classes[i]].numpy()
224 | #print(val_data)
225 | #print(np.mean(val_data, axis=0)-np.mean(sampled_data[i], axis=0))
226 | sampled_label.extend([rare_classes[i]]*num_sampled)
227 | #sampled_label.extend([rare_classes[i]] * int(sample_num_per_class[rare_classes[i]]))
228 | #sys.exit()
229 | #sys.exit()
230 | sampled_data = np.concatenate(sampled_data).reshape(count, 1024)
231 | rare_data = np.concatenate(rare_data, axis=0)
232 | rare_label = []#torch.zeros(rare_data.shape[0])
233 | for i in range(len(rare_classes)):
234 | rare_label.extend([rare_classes[i]] * int(sample_num_per_class[rare_classes[i]]))
235 | rare_label = np.array(rare_label)
236 | X_aug_1 = sampled_data #np.concatenate([rare_data, sampled_data])
237 | Y_aug_1 = sampled_label #np.concatenate([rare_label,sampled_label])
238 | #print(X_aug_1.shape)
239 | #print(Y_aug_1.shape)
240 |
241 | # ---- self distribution calibration for rich classes
242 | #num_sampled = int(750 / n_shot)
243 | sampled_data = []
244 | sampled_label = []
245 | count = 0
246 |
247 | for i in range(len(rich_classes)):
248 | num_sampled = 1000 #int(torch.max(sample_num_per_class) - sample_num_per_class[rich_classes[i]])
249 | #if sample_num_per_class>500:
250 | # continue
251 | count += num_sampled
252 | #print(rich_classes[i])
253 | #mean, conv, samples = base_means[i], base_cov[i]
254 | #print(samples.shape)
255 | mean, cov, samples = distribution_calibration(rich_data[i], base_means, base_cov, 2, 0.21, 1000)
256 | #sampled_data.append(np.random.multivariate_normal(mean=mean, cov=cov, size=num_sampled))
257 | sampled_data.append(samples)
258 | sampled_label.extend([rich_classes[i]] * num_sampled)
259 | #sampled_label.extend([rich_classes[i]] * int(sample_num_per_class[rich_classes[i]]))
260 | sampled_data = np.concatenate(sampled_data).reshape(count, 1024)
261 | rich_label = []#torch.zeros(rare_data.shape[0])
262 | for i in range(len(rich_classes)):
263 | rich_label.extend([rich_classes[i]] * int(sample_num_per_class[rich_classes[i]]))
264 |     rich_label = np.array(rich_label)
265 | rich_data = np.concatenate(rich_data, axis=0)
266 | X_aug_2 = sampled_data #rich_data #sampled_data #np.concatenate([rich_data, sampled_data])
267 | Y_aug_2 = sampled_label #rich_label#sampled_label #np.concatenate([rich_label,sampled_label])
268 | X_aug = np.concatenate([X_aug_1, X_aug_2])
269 | Y_aug = np.concatenate([Y_aug_1, Y_aug_2])
270 | #X_aug += np.random.normal(0, .1, X_aug.shape)
271 | return X_aug, Y_aug
272 |
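# Usage sketch (mirrors the calls in the __main__ block below): given per-class feature
# lists for rare and rich classes plus the per-class counts, generate_train returns 1000
# calibrated samples per class, e.g.
#
#   X_aug, Y_aug = generate_train(rare_data, rare_classes, rich_data, rich_classes,
#                                 sample_num_per_class)
#   X_aug.shape  # -> (1000 * (len(rare_classes) + len(rich_classes)), 1024)
#   Y_aug.shape  # -> (1000 * (len(rare_classes) + len(rich_classes)),)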
273 | if __name__ == '__main__':
274 | # ---- data loading
275 | n_runs = 10000
276 | import FSLTask
277 | import torch.optim as optim
278 |
279 |
280 |
281 | rare_data, rare_classes, rich_data, rich_classes, feature_val, annotation_val, feature_test, annotation_test, sample_num_per_class,rare_data_aug, rare_classes_aug, rich_data_aug, rich_classes_aug,sample_num_per_class_aug = load_dataset_driveact()
282 | #acc = sklearn.metrics.top_k_accuracy_score(annotation_test, np.squeeze(feature_test), k=1)
283 | #print(acc)
284 | #noise_mean = np.zeros(34, 1024)
285 |
286 | #print(len(rare_data)+len(rich_data))
287 | #print(len(annotation_test))
288 | #sys.exit()
289 | #print(len(annotation_test))
290 | #length= rare_classes.shape[0]+rich_data.shape[]
291 | #cfg = {'shot': n_shot, 'ways': n_ways, 'queries': n_queries}
292 | #FSLTask.loadDataSet(dataset)
293 | #FSLTask.setRandomStates(cfg)
294 | #rich_datas, rare_data = FSLTask.GenerateRunSet(end=n_runs, cfg=cfg)
295 | #ndatas = ndatas.permute(0, 2, 1, 3).reshape(n_runs, n_samples, -1)
296 | #labels = torch.arange(n_ways).view(1, 1, n_ways).expand(n_runs, n_shot + n_queries, 5).clone().view(n_runs, n_samples)
297 | # ---- Base class statistics
298 |
299 | #base_features_path = "./checkpoints/%s/WideResNet28_10_S2M2_R/last/base_features.plk"%dataset
300 |
301 | X_aug, Y_aug = generate_train(rare_data,rare_classes,rich_data,rich_classes,sample_num_per_class)
302 | X_aug2, Y_aug2 = generate_train(rare_data_aug,rare_classes_aug,rich_data_aug,rich_classes_aug,sample_num_per_class_aug)
303 | #print(torch.Tensor(Y_aug))
304 | X_aug = np.concatenate([X_aug, X_aug2])
305 | Y_aug = np.concatenate([Y_aug, Y_aug2])
306 | sample_number = calculate_samples_per_class(torch.Tensor(Y_aug))
307 | #print(sample_number)
308 |     weights = 1 / sample_number   # inverse-frequency weight per class
309 | #print(weights)
310 | sampler_weight = calculate_weights(torch.Tensor(Y_aug), weights)
311 |
312 | # ---- train classifier
313 | #print(X_aug.shape)
314 | #if mode == 'train':
315 | # samples_weight = torch.from_numpy(np.array([weight[t] for t in dataset.gt_labels]))
316 | #print(sampler_weight)
317 |     sampler = torch.utils.data.sampler.WeightedRandomSampler(sampler_weight, 3000)   # 3000 roughly class-balanced draws per epoch
318 |
319 |
320 | dataset_train = CustomImageDataset(X_aug, Y_aug)
321 | #train_GAN(DataLoader(dataset_train, batch_size=256, sampler=sampler))
322 | dataset_val = CustomImageDataset(np.squeeze(feature_val), annotation_val)
323 | dataset_test = CustomImageDataset(np.squeeze(feature_test), annotation_test)
324 |
325 |
326 |
327 | train_dataloader = DataLoader(dataset_train, batch_size=256, sampler=sampler)
328 |
329 |
330 | infer_dataloader = DataLoader(dataset_train, batch_size=256, shuffle=False)
331 | test_dataloader = DataLoader(dataset_test, batch_size=64, shuffle=False)
332 | val_dataloader = DataLoader(dataset_val, batch_size=64, shuffle=False)
333 |
334 | #model = LogisticRegression(max_iter=10000,verbose=10).fit(X=X_aug, y=Y_aug)
335 | model = Net()
336 | #resume = '/cvhci/temp/kpeng/driveact/models_swin_base/best_top1_acc_epoch_24.pth'
337 | #checkpoint = torch.load(resume)
338 | #print(checkpoint['state_dict']['cls_head.fc_cls.weight'])
339 | #print(checkpoint['state_dict']['cls_head.fc_cls.bias'])
340 | #model.fc1.weight.data = checkpoint['state_dict']['cls_head.fc_cls.weight']
341 | #model.fc1.bias.data = checkpoint['state_dict']['cls_head.fc_cls.bias']
342 | #sys.exit()
343 | model=model.cuda()
344 |     criterion = nn.CrossEntropyLoss(reduction='mean')
345 |     criterion2 = nn.CrossEntropyLoss(reduction='none')   # per-sample losses, used below to mine hard examples
346 | #optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
347 | #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
348 |
349 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.000993, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)
350 | #optimizer = torch.optim.SGD(model.parameters(), 0.001,
351 | # momentum=0.9,
352 | # weight_decay=0.01)
353 |     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=0, last_epoch=-1, verbose=False)
354 |
355 | #print(np.min(Y_aug))
356 | #sys.exit()
357 | #criterion2 = torch.nn.CosineEmbeddingLoss(margin=0.0, size_average=None, reduce=False,)
358 | for epoch in range(1500):
359 | hard_samples = []
360 | model.train()
361 | for step, (data,label) in enumerate(train_dataloader):
362 | optimizer.zero_grad()
363 | predicts,y = model(data.cuda(), data.cuda())
364 |
365 | loss = criterion(predicts, label.cuda()) #+ criterion2(predicts,y, torch.ones(y.size()[0]).cuda())
366 | loss.backward()
367 | optimizer.step()
368 |             scheduler.step()   # note: the cosine schedule is stepped per batch rather than per epoch
369 |         if (epoch > 50) and (epoch % 30 == 0):   # periodically mine hard examples and fine-tune on them
370 | model.eval()
371 | for index, (data, label) in enumerate(infer_dataloader):
372 | with torch.no_grad():
373 | #label = torch.Tensor(label).cuda().double()
374 | predicts,y = model(data.cuda(), data.cuda())
375 | #print(predicts.size())
376 | difficulty = criterion2(predicts, label.cuda())
377 | #print(difficulty)
378 | hard_samples.append(difficulty.data)
379 | #print(hard_samples)
380 | difficulty = torch.cat(hard_samples, dim=0)
381 |             threshold = 1.2 * torch.mean(difficulty)   # samples whose loss exceeds 1.2x the mean loss are treated as hard
382 | mask = (difficulty>threshold).cpu().numpy()
383 | hard_set = X_aug[mask]
384 | hard_label = Y_aug[mask]
385 | hard_dataset = CustomImageDataset(hard_set, hard_label)
386 |
387 | hard_train_dataloader = DataLoader(hard_dataset, batch_size=256, shuffle=True)
388 | for sub_epoch in range(1):
389 | model.train()
390 | for step, (data,label) in enumerate(hard_train_dataloader):
391 | optimizer.zero_grad()
392 | predicts,y = model(data.cuda(), data.cuda())
393 | loss = 3*criterion(predicts, label.cuda()) #+ criterion2(predicts,y, torch.ones(y.size()[0]).cuda())
394 | loss.backward()
395 | optimizer.step()
396 | #scheduler.step()
397 | print(epoch, 'hard_loss', loss)
398 | print(epoch, 'loss', loss)
399 |
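    # ---- evaluation: score the validation and test splits with the trained classifier and
    # dump the raw logits and labels so per-class accuracies can be recomputed offline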
400 | val_predict = []
401 | model.eval()
402 | for step, (data,label) in enumerate(val_dataloader):
403 | with torch.no_grad():
404 | #data = torch.nn.functional.normalize(data, dim=-1)
405 |
406 | predicts,y = model(data.cuda(), data.cuda())
407 | val_predict.append(predicts.cpu())
408 | val_predict = torch.cat(val_predict, dim=0).cpu().numpy()
409 | test_predict = []
410 | #val_predict = np.argmax(val_predict, axis=-1)
411 | #print(predicts)
412 | acc = sklearn.metrics.top_k_accuracy_score(annotation_val, val_predict, k=1)
413 |     with open('/cvhci/data/activity/kpeng/ts_val_midlevel_predict_split0.pkl', 'wb') as f:
414 |         pickle.dump(val_predict, f)
415 |     with open('/cvhci/data/activity/kpeng/ts_val_midlevel_label_split0.pkl', 'wb') as f:
416 |         pickle.dump(annotation_val, f)
419 | print('two-stage calibration eval ACC : %f'%acc)
420 | #predicts = model.predict(np.squeeze(feature_test))
421 | cm = confusion_matrix(annotation_val, np.argmax(val_predict, axis=-1))
422 |     with open("/cvhci/data/activity/Drive&Act/kunyu/annotation_list.pkl", 'rb') as f:
423 |         class_index = pickle.load(f)   # human-readable class names (only used by the commented-out per-class prints)
424 |     annotation = []
425 | # We will store the results in a dictionary for easy access later
426 | per_class_accuracies = {}
427 | # Calculate the accuracy for each one of our classes
428 | for idx, cls in enumerate(range(nc)):
429 | # True negatives are all the samples that are not our current GT class (not the current row)
430 | # and were not predicted as the current class (not the current column)
431 | true_negatives = np.sum(np.delete(np.delete(cm, idx, axis=0), idx, axis=1))
432 | # True positives are all the samples of our current GT class that were predicted as such
433 | true_positives = cm[idx, idx]
434 |         # The accuracy for the current class is the ratio of correct predictions to all predictions
435 | per_class_accuracies[cls] = (true_positives + true_negatives) / np.sum(cm)
436 | #print(class_index[idx], 'val_accuracy', per_class_accuracies[cls])
437 | model.eval()
438 |
439 |
440 | for step, (data,label) in enumerate(test_dataloader):
441 | with torch.no_grad():
442 |             data = torch.nn.functional.normalize(data, dim=-1)   # note: test features are L2-normalized here, unlike the validation pass above
443 | predicts,y = model(data.cuda(), data.cuda())
444 | test_predict.append(predicts.cpu())
445 | test_predict = torch.cat(test_predict, dim=0).cpu().numpy()
446 | #print(np.squeeze(feature_test).shape)
447 | #sys.exit()
448 | #predicts = model.predict(np.squeeze(feature_val))
449 | #predicts = np.argmax(predicts, axis=-1)
450 | #test_predict = np.argmax(test_predict, axis=-1)
451 | acc = sklearn.metrics.top_k_accuracy_score(annotation_test, test_predict, k=1)
452 |     with open('/cvhci/data/activity/kpeng/ts_test_midlevel_predict_split0.pkl', 'wb') as f:
453 |         pickle.dump(test_predict, f)
454 |     with open('/cvhci/data/activity/kpeng/ts_test_midlevel_label_split0.pkl', 'wb') as f:
455 |         pickle.dump(annotation_test, f)
458 | print('two-stage calibration test ACC : %f' % acc)
459 | #for i in range(34):
460 | # mask = annotation_test == i
461 | # acc = sklearn.metrics.top_k_accuracy_score(torch.argmax(annotation_test[mask], dim=-1), test_predict[mask], k=1)
462 | # print('class', class_index[i], 'accuracy:', acc)
463 | #filename = 'finalized_model.sav'
464 | #pickle.dump(model, open(filename, 'wb'))
465 | # Get the confusion matrix
466 | cm = confusion_matrix(annotation_test, np.argmax(test_predict, axis=-1))
467 | # We will store the results in a dictionary for easy access later
468 | per_class_accuracies = {}
469 | # Calculate the accuracy for each one of our classes
470 | for idx, cls in enumerate(range(nc)):
471 | # True negatives are all the samples that are not our current GT class (not the current row)
472 | # and were not predicted as the current class (not the current column)
473 | true_negatives = np.sum(np.delete(np.delete(cm, idx, axis=0), idx, axis=1))
474 | # True positives are all the samples of our current GT class that were predicted as such
475 | true_positives = cm[idx, idx]
476 |         # The accuracy for the current class is the ratio of correct predictions to all predictions
477 | per_class_accuracies[cls] = (true_positives + true_negatives) / np.sum(cm)
478 | #print(class_index[idx], 'test_accuracy', per_class_accuracies[cls])
479 | cm = confusion_matrix(annotation_test, np.argmax(feature_test, axis=-1))
480 | # We will store the results in a dictionary for easy access later
481 | per_class_accuracies = {}
482 | # Calculate the accuracy for each one of our classes
483 | for idx, cls in enumerate(range(nc)):
484 | # True negatives are all the samples that are not our current GT class (not the current row)
485 | # and were not predicted as the current class (not the current column)
486 | true_negatives = np.sum(np.delete(np.delete(cm, idx, axis=0), idx, axis=1))
487 | # True positives are all the samples of our current GT class that were predicted as such
488 | true_positives = cm[idx, idx]
489 |         # The accuracy for the current class is the ratio of correct predictions to all predictions
490 | per_class_accuracies[cls] = (true_positives + true_negatives) / np.sum(cm)
491 | #print(class_index[idx], 'test_accuracy', per_class_accuracies[cls])
492 |
493 |
494 |
495 |
--------------------------------------------------------------------------------