├── prototypical_module
│   ├── __init__.py
│   ├── convnet.py
│   ├── samplers.py
│   ├── mini_imagenet.py
│   ├── utils.py
│   ├── extract_feature.py
│   ├── logreg.py
│   └── resnet.py
├── domain_adaptive_module
│   ├── __init__.py
│   ├── eval.sh
│   ├── lr_schedule.py
│   ├── loss.py
│   ├── data_list.py
│   ├── lr.py
│   ├── pre_process.py
│   ├── eval.py
│   └── network.py
├── dataset
│   ├── tiered-imagenet
│   │   ├── readme.md
│   │   ├── label_dict.txt
│   │   ├── train.txt
│   │   ├── test_new_domain_fsl.txt
│   │   └── test.txt
│   ├── DomainNet
│   │   ├── val_classes.txt
│   │   ├── test_classes.txt
│   │   └── train_classes.txt
│   ├── mini-imagenet
│   │   ├── label_dict.txt
│   │   ├── test_new_domain.txt
│   │   └── val_source_domain.txt
│   └── get_datasetlist.py
├── pretrain
│   ├── train.sh
│   ├── dataloader.py
│   ├── main_resnet.py
│   └── resnet.py
├── train.sh
├── train_cross.sh
├── Readme.md
├── .gitignore
├── data_loader.py
├── test.py
├── train_lambda.py
└── train.py
/prototypical_module/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/domain_adaptive_module/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/dataset/tiered-imagenet/readme.md: -------------------------------------------------------------------------------- 1 | Due to the large file size, we provide only a few sample images here. 2 | --------------------------------------------------------------------------------
/pretrain/train.sh: -------------------------------------------------------------------------------- 1 | python -u main_resnet.py --epochs 50 --batch_size 1024 2>&1 | tee log.txt & 2 | --------------------------------------------------------------------------------
/train.sh: -------------------------------------------------------------------------------- 1 | python -u train.py --gpu_id 0 --net ResNet50 --dset tiered-imagenet --s_dset_path dataset/tiered-imagenet/ --fsl_test_path dataset/tiered-imagenet/ --shot 5 --train-way 5 --pretrained 'tiered_checkpoint.pth.tar' 2 | --------------------------------------------------------------------------------
/train_cross.sh: -------------------------------------------------------------------------------- 1 | python -u train_cross.py --gpu_id 0 --net ResNet50 --dset mini-imagenet --s_dset_path dataset/mini-imagenet/ --fsl_test_path dataset/mini-imagenet/ --shot 5 --train-way 16 --pretrained 'mini_checkpoint.pth.tar' --output_dir mini_16 2 | --------------------------------------------------------------------------------
/domain_adaptive_module/eval.sh: -------------------------------------------------------------------------------- 1 | now=`date +%Y-%m-%d,%H:%M:%S` 2 | srun -p Test --gres=gpu:8 -n1 --job-name=cdan python -u eval.py --net ResNet50 --dset imagenet --t_dset_path list/test_transfer_list.txt --output_dir test 2>&1|tee logs/eval-${now}.log 3 | --------------------------------------------------------------------------------
/domain_adaptive_module/lr_schedule.py: -------------------------------------------------------------------------------- 1 | def inv_lr_scheduler(optimizer, iter_num, gamma, power, lr=0.001, weight_decay=0.0005): 2 | """Inverse decay: scale the base lr by (1 + gamma * iter_num) ** (-power) at every iteration.""" 3 | lr = lr * (1 + gamma * iter_num) ** (-power) 4 | for param_group in optimizer.param_groups: 5 | param_group['lr'] = lr * param_group['lr_mult'] 6 | param_group['weight_decay'] = weight_decay * param_group['decay_mult'] 7 | 8 | return optimizer 9 | 10 | 11 | schedule_dict = {"inv":inv_lr_scheduler} 12 | 
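The scheduler above implements inverse decay: the base rate shrinks as `lr0 * (1 + gamma * iter_num) ** (-power)`, and each parameter group is additionally scaled by its own `lr_mult`/`decay_mult` keys, which this codebase attaches when it builds the optimizer. A minimal usage sketch (illustrative only; the toy model, multiplier values, and `gamma`/`power` settings are assumptions, not taken from the repo):

```python
# Minimal sketch of driving inv_lr_scheduler once per iteration.
# torch optimizers keep any extra keys passed in a param-group dict,
# so 'lr_mult' and 'decay_mult' ride along with 'params'.
# (Import path assumes running from inside domain_adaptive_module/.)
import torch
import torch.nn as nn
from lr_schedule import inv_lr_scheduler

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(
    [{'params': model.parameters(), 'lr_mult': 10, 'decay_mult': 2}],
    lr=0.001, momentum=0.9)

for iter_num in range(3):
    # effective lr = 0.001 * (1 + gamma * iter_num) ** (-power) * lr_mult
    optimizer = inv_lr_scheduler(optimizer, iter_num, gamma=0.001, power=0.75)
    print(iter_num, optimizer.param_groups[0]['lr'])
```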
--------------------------------------------------------------------------------
/dataset/DomainNet/val_classes.txt: -------------------------------------------------------------------------------- 1 | airplane/ 2 | axe/ 3 | backpack/ 4 | baseball_bat/ 5 | bat/ 6 | bee/ 7 | binoculars/ 8 | blueberry/ 9 | brain/ 10 | bucket/ 11 | calendar/ 12 | church/ 13 | cooler/ 14 | crayon/ 15 | crocodile/ 16 | cup/ 17 | dishwasher/ 18 | door/ 19 | feather/ 20 | finger/ 21 | garden/ 22 | golf_club/ 23 | helicopter/ 24 | house_plant/ 25 | jacket/ 26 | lollipop/ 27 | mountain/ 28 | mouse/ 29 | nose/ 30 | oven/ 31 | paint_can/ 32 | parrot/ 33 | piano/ 34 | pond/ 35 | power_outlet/ 36 | raccoon/ 37 | rifle/ 38 | river/ 39 | sandwich/ 40 | shovel/ 41 | snail/ 42 | snowflake/ 43 | spreadsheet/ 44 | square/ 45 | steak/ 46 | stereo/ 47 | stove/ 48 | strawberry/ 49 | swan/ 50 | table/ 51 | trombone/ 52 | wheel/ 53 | windmill/ 54 | wine_glass/ 55 | zigzag/ 56 | --------------------------------------------------------------------------------
/prototypical_module/convnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def conv_block(in_channels, out_channels): 5 | return nn.Sequential( 6 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 7 | nn.BatchNorm2d(out_channels), 8 | nn.ReLU(), 9 | nn.MaxPool2d(2) 10 | ) 11 | 12 | 13 | class Convnet(nn.Module): 14 | 15 | def __init__(self, x_dim=3, hid_dim=64, z_dim=64): 16 | super().__init__() 17 | self.encoder = nn.Sequential( 18 | conv_block(x_dim, hid_dim), 19 | conv_block(hid_dim, hid_dim), 20 | conv_block(hid_dim, hid_dim), 21 | conv_block(hid_dim, z_dim), 22 | ) 23 | self.out_channels = 1600 24 | 25 | def forward(self, x): 26 | x = self.encoder(x) 27 | return x.view(x.size(0), -1) 28 | 29 | --------------------------------------------------------------------------------
/dataset/DomainNet/test_classes.txt: -------------------------------------------------------------------------------- 1 | barn/ 2 | basketball/ 3 | beach/ 4 | bicycle/ 5 | birthday_cake/ 6 | bracelet/ 7 | bread/ 8 | bus/ 9 | camouflage/ 10 | castle/ 11 | compass/ 12 | computer/ 13 | dragon/ 14 | drill/ 15 | eye/ 16 | eyeglasses/ 17 | flip_flops/ 18 | foot/ 19 | frog/ 20 | frying_pan/ 21 | harp/ 22 | headphones/ 23 | hockey_stick/ 24 | hot_dog/ 25 | ice_cream/ 26 | key/ 27 | knife/ 28 | ladder/ 29 | leaf/ 30 | leg/ 31 | light_bulb/ 32 | lightning/ 33 | lipstick/ 34 | microwave/ 35 | motorbike/ 36 | moustache/ 37 | necklace/ 38 | octagon/ 39 | passport/ 40 | pliers/ 41 | pool/ 42 | potato/ 43 | rabbit/ 44 | rain/ 45 | rollerskates/ 46 | screwdriver/ 47 | see_saw/ 48 | shark/ 49 | skull/ 50 | snake/ 51 | snorkel/ 52 | snowman/ 53 | soccer_ball/ 54 | sock/ 55 | stitches/ 56 | stop_sign/ 57 | streetlight/ 58 | sweater/ 59 | swing_set/ 60 | tennis_racquet/ 61 | The_Eiffel_Tower/ 62 | The_Mona_Lisa/ 63 | toaster/ 64 | toilet/ 65 | tooth/ 66 | toothpaste/ 67 | tree/ 68 | truck/ 69 | wine_bottle/ 70 | zebra/ 71 | --------------------------------------------------------------------------------
/prototypical_module/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class CategoriesSampler(): 6 | 7 | def __init__(self, label, n_batch, n_cls, n_per): 8 | self.n_batch = n_batch 9 | self.n_cls = n_cls 10 | self.n_per = n_per 11 | 12 | label = np.array(label) 13 | # print
max(label) 14 | self.m_ind = [] 15 | for i in range(max(label) + 1): 16 | ind = np.argwhere(label == i).reshape(-1) 17 | ind = torch.from_numpy(ind) 18 | self.m_ind.append(ind) 19 | 20 | def __len__(self): 21 | return self.n_batch 22 | 23 | def __iter__(self): 24 | for i_batch in range(self.n_batch): 25 | batch = [] 26 | classes = torch.randperm(len(self.m_ind))[:self.n_cls] 27 | for c in classes: 28 | l = self.m_ind[c] 29 | pos = torch.randperm(len(l))[:self.n_per] 30 | batch.append(l[pos]) 31 | batch = torch.stack(batch).t().reshape(-1) 32 | yield batch 33 | 34 | -------------------------------------------------------------------------------- /pretrain/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os 3 | from PIL import Image 4 | 5 | 6 | class MyDataset(data.Dataset): 7 | def __init__(self, file, dir_path, new_width, new_height, transform=None): 8 | imgs = [] 9 | fw = open(file, 'r') 10 | lines = fw.readlines() 11 | for line in lines: 12 | words = line.split() 13 | imgs.append((words[0], int(words[1]))) 14 | self.imgs = imgs 15 | self.dir_path = dir_path 16 | self.height = new_height 17 | self.width = new_width 18 | self.transform = transform 19 | 20 | def __getitem__(self, index): 21 | path, label = self.imgs[index] 22 | idx = path.split('/')[1].split('.')[0] 23 | path = os.path.join(self.dir_path, path) 24 | img = Image.open(path).convert('RGB') 25 | # img = img.resize((self.width, self.height), Image.ANTIALIAS) 26 | if self.transform is not None: 27 | img = self.transform(img) 28 | return img, label, idx 29 | 30 | def __len__(self): 31 | return len(self.imgs) 32 | -------------------------------------------------------------------------------- /dataset/mini-imagenet/label_dict.txt: -------------------------------------------------------------------------------- 1 | n03775546 0 2 | n02219486 1 3 | n02443484 2 4 | n02116738 3 5 | n03272010 4 6 | n02110063 5 7 | n02871525 6 8 | n02099601 7 9 | n07613480 8 10 | n04522168 9 11 | n03127925 10 12 | n01981276 11 13 | n04149813 12 14 | n02129165 13 15 | n03544143 14 16 | n04418357 15 17 | n03146219 16 18 | n04146614 17 19 | n02110341 18 20 | n01930112 19 21 | n04251144 20 22 | n04596742 21 23 | n03017168 22 24 | n04296562 23 25 | n03838899 24 26 | n04515003 25 27 | n04389033 26 28 | n02105505 27 29 | n01532829 28 30 | n04612504 29 31 | n03998194 30 32 | n02165456 31 33 | n01843383 32 34 | n02606052 33 35 | n02111277 34 36 | n07747607 35 37 | n03476684 36 38 | n01910747 37 39 | n04258138 38 40 | n03924679 39 41 | n01558993 40 42 | n03347037 41 43 | n04067472 42 44 | n09246464 43 45 | n02747177 44 46 | n02795169 45 47 | n04275548 46 48 | n04443257 47 49 | n04435653 48 50 | n02113712 49 51 | n02074367 50 52 | n03854065 51 53 | n04604644 52 54 | n01770081 53 55 | n02120079 54 56 | n07697537 55 57 | n03400231 56 58 | n02108089 57 59 | n03220513 58 60 | n02966193 59 61 | n01704323 60 62 | n03337140 61 63 | n02101006 62 64 | n13133613 63 65 | -------------------------------------------------------------------------------- /prototypical_module/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | from PIL import ImageFile 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | class MiniImageNet(Dataset): 10 | 11 | def __init__(self, root='../dataset/mini-imagenet/train'): 12 | self.root = 
root 13 | self.data = [] 14 | self.label = [] 15 | self.label_dict = self._get_label() 16 | self._load_dataset() 17 | 18 | self.transform = transforms.Compose([ 19 | transforms.Resize(84), 20 | transforms.CenterCrop(84), 21 | transforms.ToTensor(), 22 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 23 | std=[0.229, 0.224, 0.225]) 24 | ]) 25 | 26 | def _get_label(self): 27 | labels = {} 28 | idx = 0 29 | train_list = os.listdir(self.root) 30 | for name in train_list: 31 | labels[name] = idx 32 | idx += 1 33 | return labels 34 | 35 | def _load_dataset(self): 36 | path = self.root 37 | subdirs = os.listdir(path) 38 | for subdir in subdirs: 39 | labels = self.label_dict[subdir] 40 | imgs = os.listdir(os.path.join(path, subdir)) 41 | for img in imgs: 42 | img_path = os.path.join(path, subdir, img) 43 | self.data.append(img_path) 44 | self.label.append(labels) 45 | 46 | def __getitem__(self, i): 47 | path, label = self.data[i], self.label[i] 48 | image = self.transform(Image.open(path).convert('RGB')) 49 | return image, label 50 | 51 | def __len__(self): 52 | return len(self.data) 53 | -------------------------------------------------------------------------------- /prototypical_module/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | import pprint 5 | 6 | import torch 7 | 8 | 9 | def set_gpu(x): 10 | os.environ['CUDA_VISIBLE_DEVICES'] = x 11 | print('using gpu:', x) 12 | 13 | 14 | def ensure_path(path): 15 | if os.path.exists(path): 16 | if input('{} exists, remove? ([y]/n)'.format(path)) != 'n': 17 | shutil.rmtree(path) 18 | os.mkdir(path) 19 | else: 20 | os.mkdir(path) 21 | 22 | 23 | class Averager(): 24 | 25 | def __init__(self): 26 | self.n = 0 27 | self.v = 0 28 | 29 | def add(self, x): 30 | self.v = (self.v * self.n + x) / (self.n + 1) 31 | self.n += 1 32 | 33 | def item(self): 34 | return self.v 35 | 36 | 37 | def count_acc(logits, label): 38 | pred = torch.argmax(logits, dim=1) 39 | return (pred == label).type(torch.cuda.FloatTensor).mean().item() 40 | 41 | 42 | def dot_metric(a, b): 43 | return torch.mm(a, b.t()) 44 | 45 | 46 | def euclidean_metric(a, b): 47 | n = a.shape[0] 48 | m = b.shape[0] 49 | a = a.unsqueeze(1).expand(n, m, -1) 50 | b = b.unsqueeze(0).expand(n, m, -1) 51 | logits = -((a - b)**2).sum(dim=2) 52 | return logits 53 | 54 | 55 | class Timer(): 56 | 57 | def __init__(self): 58 | self.o = time.time() 59 | 60 | def measure(self, p=1): 61 | x = (time.time() - self.o) / p 62 | x = int(x) 63 | if x >= 3600: 64 | return '{:.1f}h'.format(x / 3600) 65 | if x >= 60: 66 | return '{}m'.format(round(x / 60)) 67 | return '{}s'.format(x) 68 | 69 | _utils_pp = pprint.PrettyPrinter() 70 | def pprint(x): 71 | _utils_pp.pprint(x) 72 | 73 | 74 | def l2_loss(pred, label): 75 | return ((pred - label)**2).sum() / len(pred) / 2 76 | 77 | -------------------------------------------------------------------------------- /dataset/get_datasetlist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | root = 'dataset/' 4 | dataset_list = ['mini-imagenet', 'tiered-imagenet'] 5 | split_list = ['train', 'val', 'test', 'test_new_domain'] 6 | for dataset in dataset_list: 7 | label_dict = {} 8 | label_dict_reverse = {} 9 | idx = 0 10 | for split in split_list: 11 | if split != 'test_new_domain': 12 | sw = open(os.path.join('dataset', dataset, split+'.txt'), 'w+') 13 | path = os.path.join(root, dataset, split) 14 | subdirs = os.listdir(path) 15 
| for subdir in subdirs: 16 | label_dict[subdir] = idx 17 | label_dict_reverse[idx] = subdir 18 | imgs = os.listdir(os.path.join(path, subdir)) 19 | for img in imgs: 20 | img_path = os.path.join(split, subdir, img) 21 | sw.writelines(img_path+" "+str(idx)+"\n") 22 | idx += 1 23 | sw.close() 24 | else: 25 | sw = open(os.path.join('dataset', dataset, split+'.txt'), 'w+') 26 | sw2 = open(os.path.join('dataset', dataset, split+'_fsl.txt'), 'w+') 27 | path = os.path.join(root, dataset, split) 28 | subdirs = os.listdir(path) 29 | for subdir in subdirs: 30 | idx = label_dict[subdir] 31 | imgs = os.listdir(os.path.join(path, subdir)) 32 | random.shuffle(imgs) 33 | for i, img in enumerate(imgs): 34 | img_path = os.path.join(split, subdir, img) 35 | if i < 5: 36 | sw.writelines(img_path+" "+str(idx)+"\n") 37 | else: 38 | sw2.writelines(img_path+" "+str(idx)+"\n") 39 | sw.close() 40 | sw2.close() 41 | sw = open(os.path.join('dataset', dataset, 'label_dict.txt'), 'w+') 42 | for i in range(len(label_dict_reverse.keys())): 43 | sw.writelines(label_dict_reverse[i]+" "+str(i)+"\n") 44 | sw.close() 45 | -------------------------------------------------------------------------------- /prototypical_module/extract_feature.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | import torch.utils.model_zoo as model_zoo 8 | from imagenet import ImageNet 9 | from resnet import * 10 | import numpy as np 11 | from utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric 12 | 13 | def main(): 14 | set_gpu('0') 15 | save_path = 'features/test_new_domain_miniimagenet/' 16 | test_set = ImageNet(root='../cross-domain-fsl/dataset/mini-imagenet/test_new_domain') 17 | val_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False, num_workers=8, 18 | pin_memory=True) 19 | model = resnet50() 20 | model = torch.nn.DataParallel(model).cuda() 21 | model.load_state_dict(torch.load('save/proto-5/max-acc.pth')) 22 | 23 | # model_dict = model.state_dict() 24 | # pretrained_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth') 25 | # # 1. filter out unnecessary keys 26 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 27 | # # 2. overwrite entries in the existing state dict 28 | # print pretrained_dict 29 | # model_dict.update(pretrained_dict) 30 | # # 3. 
load the new state dict 31 | # model.load_state_dict(model_dict) 32 | model = model.cuda() 33 | model.eval() 34 | # model = torch.nn.DataParallel(model).cuda() 35 | features = [[] for i in range(359)] 36 | for (image, label) in val_loader: 37 | image = image.cuda() 38 | label = label.numpy() 39 | feature = model(image) 40 | feature = feature.data.cpu().numpy() 41 | # print feature.shape[0] 42 | for j in range(feature.shape[0]): 43 | features[int(label[j])].append(feature[j]) 44 | for i in range(359): 45 | save_file = os.path.join(save_path, str(i)+'.txt') 46 | feature_np = np.asarray(features[i]) 47 | np.savetxt(save_file, feature_np) 48 | 49 | if __name__ == '__main__': 50 | main() 51 | --------------------------------------------------------------------------------
/domain_adaptive_module/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import math 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | def Entropy(input_): 10 | bs = input_.size(0) 11 | epsilon = 1e-5 12 | entropy = -input_ * torch.log(input_ + epsilon) 13 | entropy = torch.sum(entropy, dim=1) 14 | return entropy 15 | 16 | def grl_hook(coeff): 17 | def fun1(grad): 18 | return -coeff*grad.clone() 19 | return fun1 20 | 21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None): 22 | softmax_output = input_list[1].detach() 23 | feature = input_list[0] 24 | if random_layer is None: 25 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1)) 26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1))) 27 | else: 28 | random_out = random_layer.forward([feature, softmax_output]) 29 | ad_out = ad_net(random_out.view(-1, random_out.size(1))) 30 | batch_size = softmax_output.size(0) // 2 31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 32 | if entropy is not None: 33 | entropy.register_hook(grl_hook(coeff)) 34 | entropy = 1.0+torch.exp(-entropy) 35 | source_mask = torch.ones_like(entropy) 36 | source_mask[feature.size(0)//2:] = 0 37 | source_weight = entropy*source_mask 38 | target_mask = torch.ones_like(entropy) 39 | target_mask[0:feature.size(0)//2] = 0 40 | target_weight = entropy*target_mask 41 | weight = source_weight / torch.sum(source_weight).detach().item() + \ 42 | target_weight / torch.sum(target_weight).detach().item() 43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item() 44 | else: 45 | return nn.BCELoss()(ad_out, dc_target) 46 | 47 | def DANN(features, ad_net): 48 | ad_out = ad_net(features) 49 | batch_size = ad_out.size(0) // 2 50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 51 | return nn.BCELoss()(ad_out, dc_target) 52 | 
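`CDAN` takes the concatenated source-and-target features and classifier softmax outputs (source half first, target half second), conditions the domain discriminator `ad_net` on their outer product, and returns a binary domain-classification loss; passing `entropy` (from `Entropy`) switches on the entropy-weighted variant. A shape-level sketch with a hypothetical stand-in discriminator (the repo's own adversarial network lives in `network.py` and is not reproduced here); note that `CDAN` builds its domain labels with `.cuda()`, so this needs a GPU:

```python
# Shape-level sketch of calling CDAN (illustrative, not from the repo).
# (Import path assumes running from inside domain_adaptive_module/.)
import torch
import torch.nn as nn
from loss import CDAN

batch, feat_dim, n_class = 8, 256, 5             # 4 source + 4 target samples
ad_net = nn.Sequential(                          # stand-in domain discriminator
    nn.Linear(feat_dim * n_class, 64), nn.ReLU(),
    nn.Linear(64, 1), nn.Sigmoid()).cuda()

feature = torch.randn(batch, feat_dim).cuda()
softmax_output = torch.softmax(torch.randn(batch, n_class), dim=1).cuda()

# entropy=None selects the plain (unweighted) BCE domain loss.
loss = CDAN([feature, softmax_output], ad_net)
print(loss.item())
```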
--------------------------------------------------------------------------------
/Readme.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | The framework is implemented and tested with Ubuntu 16.04, CUDA 8.0/9.0, Python 3, PyTorch 0.4/1.0/1.1, and NVIDIA TITAN X GPUs. 4 | 5 | ## Requirements 6 | 7 | - **CUDA & cuDNN & Python & PyTorch** 8 | 9 | This project is tested with CUDA 8.0/9.0, Python 3, PyTorch 0.4/1.0, and NVIDIA TITAN X GPUs. 10 | 11 | Please install the proper CUDA and cuDNN versions, then install Anaconda3 and PyTorch. Almost all the packages we use are covered by Anaconda. 12 | 13 | - **My settings** 14 | 15 | ```shell 16 | source ~/anaconda3/bin/activate (python 3.6.5) 17 | (base) pip list 18 | torch 0.4.1 19 | torchvision 0.2.2.post3 20 | numpy 1.18.1 21 | numpydoc 0.8.0 22 | numba 0.42.0 23 | opencv-python 4.0.0.21 24 | ``` 25 | 26 | 27 | ## Data preparation 28 | 29 | Download and unzip the datasets: **MiniImageNet, TieredImageNet, DomainNet**. 30 | 31 | Here we provide the target-domain datasets on Google Drive: [miniImageNet](https://drive.google.com/file/d/1Yxzw2kJarXCV2tldKzXt6rlGqcLuv24W), [tieredImageNet](https://drive.google.com/file/d/1Unqwgiuoy7br8vKiEZo8Jhib-eNDxc5p). 32 | 33 | Format: 34 | (E.g. mini-imagenet) 35 | ```shell 36 | MINI_DIR/ 37 | -- train/ 38 | -- n01532829/ 39 | -- n01558993/ 40 | ... 41 | -- train_new_domain/ 42 | -- val/ 43 | -- val_new_domain/ 44 | -- test/ 45 | -- test_new_domain/ 46 | ``` 47 | 48 | 49 | ## Training 50 | 51 | First set the dataset paths `MINI_DIR/`, `TIERED_DIR/`, and `DOMAIN_DIR/` for the three datasets. 52 | 53 | For each dataset, we first pre-train a model on its training set (e.g. tiered-imagenet): 54 | 55 | ``` 56 | cd pretrain/ 57 | python -u main_resnet.py --epochs 50 --batch_size 1024 --dir_path TIERED_DIR 2>&1 | tee log.txt & 58 | ``` 59 | 60 | We then start from the corresponding pre-trained model to train on each dataset (e.g. mini-imagenet): 61 | 62 | ``` 63 | python -u train_cross.py --gpu_id 0 --net ResNet50 --dset mini-imagenet --s_dset_path MINI_DIR --fsl_test_path MINI_DIR --shot 5 --train-way 16 --pretrained 'mini_checkpoint.pth.tar' --output_dir mini_way_16 64 | ``` 65 | 66 | 67 | ## Testing 68 | 69 | ``` 70 | python -u test.py --load MODEL_PATH --root MINI_DIR 71 | ``` 72 | 
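For reference, `test.py` evaluates with the standard N-way K-shot protocol: each class prototype is the mean of its support embeddings, and queries are scored by negative squared Euclidean distance (`euclidean_metric` in `prototypical_module/utils.py`). A self-contained sketch of one episode, with random tensors standing in for backbone features:

```python
# One N-way K-shot episode, mirroring the logic of test.py.
import torch

way, shot, query = 5, 5, 30
feat_dim = 2048                      # ResNet50 feature dimension

# Support set: [shot*way, D], shot-major as CategoriesSampler yields it.
support = torch.randn(shot * way, feat_dim)
prototypes = support.reshape(shot, way, -1).mean(dim=0)      # [way, D]

query_feat = torch.randn(query * way, feat_dim)
# Negative squared euclidean distance to each prototype -> logits.
logits = -((query_feat.unsqueeze(1) - prototypes.unsqueeze(0)) ** 2).sum(dim=2)

label = torch.arange(way).repeat(query)                      # [query*way]
acc = (logits.argmax(dim=1) == label).float().mean().item()
print('episode accuracy: {:.4f}'.format(acc))                # ~0.2 for random features
```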
--------------------------------------------------------------------------------
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | --------------------------------------------------------------------------------
/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | from PIL import ImageFile 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | 11 | class CategoriesSampler(): 12 | 13 | def __init__(self, label, n_batch, n_cls, n_per): 14 | self.n_batch = n_batch 15 | self.n_cls = n_cls 16 | self.n_per = n_per 17 | 18 | label = np.array(label) 19 | print(max(label)) 20 | self.m_ind = [] 21 | for i in range(max(label) + 1): 22 | ind = np.argwhere(label == i).reshape(-1) 23 | ind = torch.from_numpy(ind) 24 | self.m_ind.append(ind) 25 | 26 | def __len__(self): 27 | return self.n_batch 28 | 29 | def __iter__(self): 30 | for i_batch in range(self.n_batch): 31 | batch = [] 32 | classes = torch.randperm(len(self.m_ind))[:self.n_cls] 33 | for c in classes: 34 | l = self.m_ind[c] 35 | pos = torch.randperm(len(l))[:self.n_per] 36 | batch.append(l[pos]) 37 | batch = torch.stack(batch).t().reshape(-1) 38 | yield batch 39 | 40 | class MiniImageNet(Dataset): 41 | 42 | def __init__(self, root='dataset/mini-imagenet/train', dataset='mini-imagenet', mode='train'): 43 | self.root = root 44 | self.data = [] 45 | self.label = [] 46 | self.dataset = dataset 47 | self.mode = mode 48 | self._load_dataset() 49 | self.transform = transforms.Compose([ 50 | transforms.RandomHorizontalFlip(), 51 | transforms.ToTensor(), 52 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 53 | std=[0.229, 0.224, 0.225]) 54 | ]) 55 | 56 | def _load_dataset(self): 57 | path = self.root 58 | fw = open(os.path.join('dataset', self.dataset, self.mode+'.txt')) 59 | lines = fw.readlines() 60 | for line in lines: 61 | img_path = os.path.join(path, line.split()[0]) 62 | labels = int(line.split()[1]) 63 | self.data.append(img_path) 64 | self.label.append(labels) 65 | fw.close() 66 | 67 | def __getitem__(self, i): 68 | path, label = self.data[i], self.label[i] 69 | image = self.transform(Image.open(path).convert('RGB')) 70 | return image, label 71 | 72 | def __len__(self): 73 | return len(self.data) 74 | 
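`CategoriesSampler` turns any dataset exposing a flat `label` list into an episodic loader: every batch draws `n_per` samples from each of `n_cls` random classes, transposed so that each consecutive block of `n_cls` indices holds one sample per class. A usage sketch (the paths are the repo defaults and assume the image files and split lists are on disk):

```python
# Sketch: building an episodic DataLoader from data_loader.py.
from torch.utils.data import DataLoader
from data_loader import MiniImageNet, CategoriesSampler

way, shot, query = 5, 5, 30
dataset = MiniImageNet(root='dataset/mini-imagenet/', dataset='mini-imagenet', mode='test')
sampler = CategoriesSampler(dataset.label, n_batch=10, n_cls=way, n_per=shot + query)
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=8, pin_memory=True)

for images, labels in loader:
    # images: [(shot+query)*way, 3, H, W]; the first shot*way rows are the
    # support set and the remainder the query set, as test.py slices them.
    support, queries = images[:way * shot], images[way * shot:]
```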
--------------------------------------------------------------------------------
/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | import numpy as np 7 | from data_loader import * 8 | import domain_adaptive_module.network as network 9 | from prototypical_module.utils import pprint, set_gpu, count_acc, Averager, euclidean_metric 10 | 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--gpu', default='0,1,2,3') 15 | parser.add_argument('--load', default='snapshot/mini_16/iter_09500_model.pth.tar') 16 | parser.add_argument('--batch', type=int, default=2000) 17 | parser.add_argument('--way', type=int, default=5) 18 | parser.add_argument('--shot', type=int, default=5) 19 | parser.add_argument('--query', type=int, default=30) 20 | parser.add_argument('--root', default='dataset/mini-imagenet/') 21 | args = parser.parse_args() 22 | pprint(vars(args)) 23 | 24 | set_gpu(args.gpu) 25 | 26 | # dataset = MiniImageNet('test') 27 | dataset = MiniImageNet(root=args.root, dataset='mini-imagenet', mode='test_new_domain_fsl') #transfer 28 | #dataset = MiniImageNet(root=args.root, dataset='mini-imagenet', mode='test') #origin 29 | sampler = CategoriesSampler(dataset.label, 30 | args.batch, args.way, args.shot + args.query) 31 | loader = DataLoader(dataset, batch_sampler=sampler, 32 | num_workers=8, pin_memory=True) 33 | 34 | # model = Convnet().cuda() 35 | model = torch.load(args.load) 36 | # model= list(model.children())[0] 37 | # model = model.module 38 | # for key, value in base_network.state_dict().items(): 39 | # print(key) 40 | # model = list(model.children())[9].cuda() 41 | # base_network= torch.nn.Sequential(*list(base_network.children())[:-1]).cuda() 42 | model = nn.DataParallel(model) 43 | print(model) 44 | model.eval() 45 | 46 | ave_acc = Averager() 47 | test_accuracies = [] 48 | for i, batch in enumerate(loader, 1): 49 | data, _ = [_.cuda() for _ in batch] 50 | k = args.way * args.shot 51 | data_shot, data_query = data[:k], data[k:] 52 | 53 | x, _ = model(data_shot) 54 | x = x.reshape(args.shot, args.way, -1).mean(dim=0) 55 | p = x 56 | 57 | proto_query, _ = model(data_query) 58 | logits = euclidean_metric(proto_query, p) 59 | 60 | label = torch.arange(args.way).repeat(args.query) 61 | label = label.type(torch.cuda.LongTensor) 62 | 63 | acc = count_acc(logits, label) 64 | test_accuracies.append(acc) 65 | 66 | avg = np.mean(np.array(test_accuracies)) 67 | std = np.std(np.array(test_accuracies)) 68 | ci95 = 1.96 * std / np.sqrt(i) 69 | print('batch {}: Accuracy: {:.4f} +- {:.4f} % ({:.4f} %)'.format(i, avg, ci95, acc)) 70 | x = None; p = None; logits = None 71 | 72 | --------------------------------------------------------------------------------
/dataset/DomainNet/train_classes.txt: -------------------------------------------------------------------------------- 1 | aircraft_carrier/ 2 | alarm_clock/ 3 | ambulance/ 4 | angel/ 5 | animal_migration/ 6 | ant/ 7 | anvil/ 8 | apple/ 9 | arm/ 10 | asparagus/ 11 | banana/ 12 | bandage/ 13 | baseball/ 14 | basket/ 15 | bathtub/ 16 | bear/ 17 | beard/ 18 | bed/ 19 | belt/ 20 | bench/ 21 | bird/ 22 | blackberry/ 23 | book/ 24 | boomerang/ 25 | bottlecap/ 26 | bowtie/ 27 | bridge/ 28 | broccoli/ 29 | broom/ 30 | bulldozer/ 31 | bush/ 32 | butterfly/ 33 | cactus/ 34 | cake/ 35 | calculator/ 36 | camel/ 37 | camera/ 38 | campfire/ 39 | candle/ 40 | cannon/ 41 | canoe/ 42 | car/ 43 | carrot/ 44 | cat/ 45 | ceiling_fan/ 46 | cello/ 47 | cell_phone/ 48 | chair/ 49 | chandelier/ 50 | circle/ 51 | clarinet/ 52 | clock/ 53 | cloud/ 54 | coffee_cup/ 55 | cookie/ 56 | couch/ 57 | cow/ 58 | crab/ 59 | crown/ 60 | cruise_ship/ 61 | diamond/ 62 | diving_board/ 63 | dog/ 64 | dolphin/ 65 | donut/ 66 | dresser/ 67 | drums/ 68 | duck/ 69 | dumbbell/ 70 | 
ear/ 71 | elbow/ 72 | elephant/ 73 | envelope/ 74 | eraser/ 75 | face/ 76 | fan/ 77 | fence/ 78 | fire_hydrant/ 79 | fireplace/ 80 | firetruck/ 81 | fish/ 82 | flamingo/ 83 | flashlight/ 84 | floor_lamp/ 85 | flower/ 86 | flying_saucer/ 87 | fork/ 88 | garden_hose/ 89 | giraffe/ 90 | goatee/ 91 | grapes/ 92 | grass/ 93 | guitar/ 94 | hamburger/ 95 | hammer/ 96 | hand/ 97 | hat/ 98 | hedgehog/ 99 | helmet/ 100 | hexagon/ 101 | hockey_puck/ 102 | horse/ 103 | hospital/ 104 | hot_air_balloon/ 105 | hot_tub/ 106 | hourglass/ 107 | house/ 108 | hurricane/ 109 | jail/ 110 | kangaroo/ 111 | keyboard/ 112 | knee/ 113 | lantern/ 114 | laptop/ 115 | lighter/ 116 | lighthouse/ 117 | line/ 118 | lion/ 119 | lobster/ 120 | mailbox/ 121 | map/ 122 | marker/ 123 | matches/ 124 | megaphone/ 125 | mermaid/ 126 | microphone/ 127 | monkey/ 128 | moon/ 129 | mosquito/ 130 | mouth/ 131 | mug/ 132 | mushroom/ 133 | nail/ 134 | ocean/ 135 | octopus/ 136 | onion/ 137 | owl/ 138 | paintbrush/ 139 | palm_tree/ 140 | panda/ 141 | pants/ 142 | paper_clip/ 143 | parachute/ 144 | peanut/ 145 | pear/ 146 | peas/ 147 | pencil/ 148 | penguin/ 149 | pickup_truck/ 150 | picture_frame/ 151 | pig/ 152 | pillow/ 153 | pineapple/ 154 | pizza/ 155 | police_car/ 156 | popsicle/ 157 | postcard/ 158 | purse/ 159 | radio/ 160 | rainbow/ 161 | rake/ 162 | remote_control/ 163 | rhinoceros/ 164 | roller_coaster/ 165 | sailboat/ 166 | saw/ 167 | saxophone/ 168 | school_bus/ 169 | scissors/ 170 | scorpion/ 171 | sea_turtle/ 172 | sheep/ 173 | shoe/ 174 | shorts/ 175 | sink/ 176 | skateboard/ 177 | skyscraper/ 178 | sleeping_bag/ 179 | smiley_face/ 180 | speedboat/ 181 | spider/ 182 | spoon/ 183 | squiggle/ 184 | squirrel/ 185 | stairs/ 186 | star/ 187 | stethoscope/ 188 | string_bean/ 189 | submarine/ 190 | suitcase/ 191 | sun/ 192 | sword/ 193 | syringe/ 194 | teapot/ 195 | teddy-bear/ 196 | telephone/ 197 | television/ 198 | tent/ 199 | The_Great_Wall_of_China/ 200 | tiger/ 201 | toe/ 202 | toothbrush/ 203 | tornado/ 204 | tractor/ 205 | traffic_light/ 206 | train/ 207 | triangle/ 208 | trumpet/ 209 | t-shirt/ 210 | umbrella/ 211 | underwear/ 212 | van/ 213 | vase/ 214 | violin/ 215 | washing_machine/ 216 | watermelon/ 217 | waterslide/ 218 | whale/ 219 | wristwatch/ 220 | yoga/ 221 | -------------------------------------------------------------------------------- /domain_adaptive_module/data_list.py: -------------------------------------------------------------------------------- 1 | #from __future__ import print_function, division 2 | 3 | import torch 4 | import numpy as np 5 | import random 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | import os 9 | import os.path 10 | 11 | from PIL import ImageFile 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | 15 | def make_dataset(image_list, labels): 16 | if labels: 17 | len_ = len(image_list) 18 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)] 19 | else: 20 | if len(image_list[0].split()) > 2: 21 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list] 22 | else: 23 | images = [(val.split()[0], int(val.split()[1])) for val in image_list] 24 | return images 25 | 26 | 27 | def rgb_loader(path): 28 | with open('dataset/imagenet/' + path, 'rb') as f: 29 | with Image.open(f) as img: 30 | return img.convert('RGB') 31 | 32 | def l_loader(path): 33 | with open(path, 'rb') as f: 34 | with Image.open(f) as img: 35 | return img.convert('L') 36 | 37 | class ImageList(Dataset): 38 | def __init__(self, 
image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 39 | imgs = make_dataset(image_list, labels) 40 | if len(imgs) == 0: 41 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 42 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 43 | 44 | self.imgs = imgs 45 | self.transform = transform 46 | self.target_transform = target_transform 47 | if mode == 'RGB': 48 | self.loader = rgb_loader 49 | elif mode == 'L': 50 | self.loader = l_loader 51 | 52 | def __getitem__(self, index): 53 | path, target = self.imgs[index] 54 | img = self.loader(path) 55 | if self.transform is not None: 56 | img = self.transform(img) 57 | if self.target_transform is not None: 58 | target = self.target_transform(target) 59 | 60 | return img, target 61 | 62 | def __len__(self): 63 | return len(self.imgs) 64 | 65 | class ImageValueList(Dataset): 66 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, 67 | loader=rgb_loader): 68 | imgs = make_dataset(image_list, labels) 69 | if len(imgs) == 0: 70 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 71 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 72 | 73 | self.imgs = imgs 74 | self.values = [1.0] * len(imgs) 75 | self.transform = transform 76 | self.target_transform = target_transform 77 | self.loader = loader 78 | 79 | def set_values(self, values): 80 | self.values = values 81 | 82 | def __getitem__(self, index): 83 | path, target = self.imgs[index] 84 | img = self.loader(path) 85 | if self.transform is not None: 86 | img = self.transform(img) 87 | if self.target_transform is not None: 88 | target = self.target_transform(target) 89 | 90 | return img, target 91 | 92 | def __len__(self): 93 | return len(self.imgs) 94 | 95 | -------------------------------------------------------------------------------- /dataset/mini-imagenet/test_new_domain.txt: -------------------------------------------------------------------------------- 1 | test_new_domain/n03775546/8490.png 0 2 | test_new_domain/n03775546/8612.png 0 3 | test_new_domain/n03775546/8419.png 0 4 | test_new_domain/n03775546/8536.png 0 5 | test_new_domain/n03775546/8750.png 0 6 | test_new_domain/n02219486/4615.png 1 7 | test_new_domain/n02219486/4400.png 1 8 | test_new_domain/n02219486/4631.png 1 9 | test_new_domain/n02219486/4511.png 1 10 | test_new_domain/n02219486/4420.png 1 11 | test_new_domain/n02443484/4898.png 2 12 | test_new_domain/n02443484/5106.png 2 13 | test_new_domain/n02443484/5216.png 2 14 | test_new_domain/n02443484/4871.png 2 15 | test_new_domain/n02443484/5033.png 2 16 | test_new_domain/n02116738/3084.png 3 17 | test_new_domain/n02116738/3051.png 3 18 | test_new_domain/n02116738/3117.png 3 19 | test_new_domain/n02116738/3450.png 3 20 | test_new_domain/n02116738/3442.png 3 21 | test_new_domain/n03272010/7212.png 4 22 | test_new_domain/n03272010/7767.png 4 23 | test_new_domain/n03272010/7540.png 4 24 | test_new_domain/n03272010/7221.png 4 25 | test_new_domain/n03272010/7742.png 4 26 | test_new_domain/n02110063/2329.png 5 27 | test_new_domain/n02110063/1881.png 5 28 | test_new_domain/n02110063/2314.png 5 29 | test_new_domain/n02110063/2040.png 5 30 | test_new_domain/n02110063/1801.png 5 31 | test_new_domain/n02871525/5773.png 6 32 | test_new_domain/n02871525/5906.png 6 33 | test_new_domain/n02871525/5425.png 6 34 | test_new_domain/n02871525/5752.png 6 35 | test_new_domain/n02871525/5575.png 6 36 | test_new_domain/n02099601/1765.png 7 37 | test_new_domain/n02099601/1514.png 7 38 | 
test_new_domain/n02099601/1350.png 7 39 | test_new_domain/n02099601/1375.png 7 40 | test_new_domain/n02099601/1317.png 7 41 | test_new_domain/n07613480/11433.png 8 42 | test_new_domain/n07613480/11611.png 8 43 | test_new_domain/n07613480/11510.png 8 44 | test_new_domain/n07613480/11957.png 8 45 | test_new_domain/n07613480/11630.png 8 46 | test_new_domain/n04522168/11269.png 9 47 | test_new_domain/n04522168/11243.png 9 48 | test_new_domain/n04522168/10905.png 9 49 | test_new_domain/n04522168/11002.png 9 50 | test_new_domain/n04522168/11170.png 9 51 | test_new_domain/n03127925/6572.png 10 52 | test_new_domain/n03127925/6234.png 10 53 | test_new_domain/n03127925/6017.png 10 54 | test_new_domain/n03127925/6064.png 10 55 | test_new_domain/n03127925/6016.png 10 56 | test_new_domain/n01981276/1062.png 11 57 | test_new_domain/n01981276/607.png 11 58 | test_new_domain/n01981276/1154.png 11 59 | test_new_domain/n01981276/1164.png 11 60 | test_new_domain/n01981276/1113.png 11 61 | test_new_domain/n04149813/10097.png 12 62 | test_new_domain/n04149813/9794.png 12 63 | test_new_domain/n04149813/9926.png 12 64 | test_new_domain/n04149813/9719.png 12 65 | test_new_domain/n04149813/9877.png 12 66 | test_new_domain/n02129165/4038.png 13 67 | test_new_domain/n02129165/4108.png 13 68 | test_new_domain/n02129165/3824.png 13 69 | test_new_domain/n02129165/4163.png 13 70 | test_new_domain/n02129165/3660.png 13 71 | test_new_domain/n03544143/7907.png 14 72 | test_new_domain/n03544143/7856.png 14 73 | test_new_domain/n03544143/8189.png 14 74 | test_new_domain/n03544143/8386.png 14 75 | test_new_domain/n03544143/7863.png 14 76 | test_new_domain/n04418357/10601.png 15 77 | test_new_domain/n04418357/10554.png 15 78 | test_new_domain/n04418357/10312.png 15 79 | test_new_domain/n04418357/10451.png 15 80 | test_new_domain/n04418357/10675.png 15 81 | test_new_domain/n03146219/6857.png 16 82 | test_new_domain/n03146219/6781.png 16 83 | test_new_domain/n03146219/7010.png 16 84 | test_new_domain/n03146219/6922.png 16 85 | test_new_domain/n03146219/6901.png 16 86 | test_new_domain/n04146614/9033.png 17 87 | test_new_domain/n04146614/9348.png 17 88 | test_new_domain/n04146614/9106.png 17 89 | test_new_domain/n04146614/9158.png 17 90 | test_new_domain/n04146614/9518.png 17 91 | test_new_domain/n02110341/2681.png 18 92 | test_new_domain/n02110341/2694.png 18 93 | test_new_domain/n02110341/2799.png 18 94 | test_new_domain/n02110341/2733.png 18 95 | test_new_domain/n02110341/2622.png 18 96 | test_new_domain/n01930112/60.png 19 97 | test_new_domain/n01930112/539.png 19 98 | test_new_domain/n01930112/475.png 19 99 | test_new_domain/n01930112/409.png 19 100 | test_new_domain/n01930112/553.png 19 101 | -------------------------------------------------------------------------------- /domain_adaptive_module/lr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | from sklearn import model_selection 5 | from sklearn.linear_model import LogisticRegression 6 | from sklearn import metrics 7 | 8 | #log_model = LogisticRegression() 9 | #log_model.fit(x, y) 10 | 11 | import argparse 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | from torch.autograd import Variable 17 | import sys 18 | import time 19 | 20 | 21 | def train(model, X, labels, opts): 22 | optimizer = optim.SGD(model.parameters(), 23 | lr = opts.lr, 24 | momentum = opts.mom, 25 | 
weight_decay=opts.wd) 26 | N = X.shape[0] 27 | er = 0 28 | for it in range(0, opts.maxiter): 29 | dt = opts.lr 30 | model.train() 31 | optimizer.zero_grad() 32 | 33 | idx = np.random.randint(0,N,opts.batchsize) 34 | x = Variable(torch.Tensor(X[idx])) 35 | y = Variable(torch.from_numpy(labels[idx]).long()) 36 | yhat = model(x) 37 | print(x.size(),y.size(),yhat.size()) 38 | loss = F.nll_loss(yhat, y) 39 | er = er + loss.data.item() 40 | loss.backward() 41 | optimizer.step() 42 | 43 | if it % opts.verbose == 1: 44 | print(er/opts.verbose) 45 | er = 0 46 | return er/opts.verbose 47 | 48 | 49 | 50 | def train_balanced(model, X, labels, opts, freq): 51 | optimizer = optim.SGD(model.parameters(), 52 | lr = opts.lr, 53 | momentum = opts.mom, 54 | weight_decay=opts.wd) 55 | unq, inv, cnt = np.unique(labels, 56 | return_inverse=True, 57 | return_counts=True) 58 | lid = np.split(np.argsort(inv), np.cumsum(cnt[:-1])) 59 | N = X.shape[0] 60 | er = 0 61 | nlabels = len(lid) 62 | llid = np.zeros(nlabels).astype('int') 63 | for i in range(nlabels): 64 | llid[i] = len(lid[i]) 65 | t0 = time.time() 66 | for it in range(opts.maxiter): 67 | dt = opts.lr 68 | model.train() 69 | optimizer.zero_grad() 70 | idx = np.random.randint(0,nlabels,opts.batchsize) 71 | for t in range(opts.batchsize): 72 | i = idx[t] 73 | idx[t] = lid[i][np.random.randint(0,llid[i])] 74 | x = Variable(torch.Tensor(X[idx])) 75 | y = Variable(torch.from_numpy(labels[idx]).long()) 76 | yhat = model(x) 77 | # print(x.size(),y.size(),yhat.size()) 78 | loss = F.nll_loss(yhat, y) 79 | er = er + loss.data.item() 80 | loss.backward() 81 | optimizer.step() 82 | 83 | if it % opts.verbose == 1: 84 | print(er/opts.verbose) 85 | er = er/opts.verbose 86 | if (it+1) % freq == 0: 87 | print('[%.3fs] iteration %d' % (time.time() - t0, it + 1), end=' ') 88 | test_acc = validate(Xte, Yte, model) 89 | train_acc = validate(Xtr, Ytr, model) 90 | print ('all acc is: ', (test_acc * len(Yte) - train_acc * len(Ytr))/ (len(Yte)-len(Ytr))) 91 | return er 92 | 93 | 94 | def validate(tX, tlabels, model): 95 | """ compute top-1 and top-5 error of the model """ 96 | model.eval() 97 | N = tX.shape[0] 98 | batch = tX 99 | batchlabels = tlabels 100 | x = Variable(torch.Tensor(batch), volatile=True) 101 | yhat = model(x) 102 | values, indices = torch.max(yhat, 1) 103 | indices = indices.data.numpy() 104 | print(indices.shape, tlabels.shape) 105 | print('################### acc is : ', sum(indices==tlabels)/len(tlabels)) 106 | return sum(indices==tlabels)/len(tlabels) 107 | 108 | class Net(nn.Module): 109 | def __init__(self, d, nclasses): 110 | super(Net, self).__init__() 111 | self.l1 = nn.Linear(d,nclasses) 112 | def forward(self, x): 113 | return F.log_softmax(self.l1(x)) 114 | 115 | 116 | 117 | 118 | ##########################################################3 119 | # 120 | 121 | 122 | 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument('--nlabeled', type=int, default = 5) 125 | parser.add_argument('--maxiter', type=int, default = 20000) 126 | parser.add_argument('--batchsize', type=int, default=128) 127 | parser.add_argument('--verbose', type=int, default=100) 128 | parser.add_argument('--pcadim', type=int, default=2048) 129 | parser.add_argument('--seed', type=int, default=123) 130 | parser.add_argument('--lr', type=float, default=.005) 131 | parser.add_argument('--wd', type=float, default=0.0) 132 | parser.add_argument('--mom', type=float, default=0.0) 133 | parser.add_argument('--mode', default='val') 134 | parser.add_argument('--storemodel', type=str, 
default='lr_model_15.pth') 135 | parser.add_argument('--storeL', type=str, default='') 136 | 137 | 138 | opts = parser.parse_args() 139 | mode = opts.mode 140 | 141 | 142 | 143 | df=pd.read_csv('feature15.txt', header=None, sep=' ') 144 | 145 | 146 | df = np.array(df) 147 | Xtr = df[:,:-1] 148 | Ytr = np.array(df[:,-1], dtype=np.int) 149 | 150 | 151 | print('load train') 152 | 153 | 154 | df=pd.read_csv('feature15_all.txt', header=None, sep=' ') 155 | 156 | 157 | df = np.array(df) 158 | Xte = df[:,:-1] 159 | Yte = np.array(df[:,-1], dtype=np.int) 160 | 161 | 162 | 163 | print('load test') 164 | 165 | 166 | 167 | nclasses = 359 168 | Xtr_orig = Xtr 169 | Xte_orig = Xte 170 | 171 | print("dataset sizes: Xtr %s (%d labeled), Xte %s, %d classes (eval on %s)" % ( 172 | Xtr.shape, (Ytr >= 0).sum(), 173 | Xte.shape, nclasses, Yte)) 174 | 175 | net= Net(opts.pcadim, nclasses) 176 | 177 | 178 | print('============== start logreg') 179 | 180 | 181 | if opts.mode == 'val': 182 | eval_freq = 500 183 | else: 184 | eval_freq = opts.maxiter 185 | 186 | train_balanced(net, Xtr, Ytr, opts, eval_freq) 187 | 188 | # if opts.storeL: 189 | # L = validate(Xte, Yte, net) 190 | # print('writing', opts.storeL) 191 | # np.save(opts.storeL, L) 192 | 193 | if opts.storemodel: 194 | print('writing', opts.storemodel) 195 | torch.save(net.state_dict(), opts.storemodel) 196 | 197 | 198 | -------------------------------------------------------------------------------- /dataset/tiered-imagenet/label_dict.txt: -------------------------------------------------------------------------------- 1 | n02105251 0 2 | n07873807 1 3 | n02536864 2 4 | n09468604 3 5 | n09256479 4 6 | n02108551 5 7 | n07714571 6 8 | n07615774 7 9 | n07734744 8 10 | n02236044 9 11 | n03063599 10 12 | n02823428 11 13 | n07720875 12 14 | n07753592 13 15 | n02219486 14 16 | n07836838 15 17 | n07880968 16 18 | n02174001 17 19 | n02815834 18 20 | n02190166 19 21 | n09193705 20 22 | n02788148 21 23 | n09428293 22 24 | n03314780 23 25 | n04557648 24 26 | n07745940 25 27 | n09288635 26 28 | n02104365 27 29 | n03633091 28 30 | n01496331 29 31 | n03160309 30 32 | n02640242 31 33 | n07892512 32 34 | n03063689 33 35 | n07768694 34 36 | n07565083 35 37 | n02105412 36 38 | n04562935 37 39 | n07584110 38 40 | n02108915 39 41 | n02909870 40 42 | n07730033 41 43 | n01498041 42 44 | n07754684 43 45 | n02105855 44 46 | n02277742 45 47 | n04398044 46 48 | n07831146 47 49 | n02105505 48 50 | n02110063 49 51 | n02939185 50 52 | n07760859 51 53 | n02105162 52 54 | n02110627 53 55 | n02256656 54 56 | n02169497 55 57 | n04553703 56 58 | n02165456 57 59 | n07930864 58 60 | n07583066 59 61 | n07716358 60 62 | n02606052 61 63 | n02167151 62 64 | n02655020 63 65 | n02108000 64 66 | n02808440 65 67 | n07747607 66 68 | n03937543 67 69 | n07579787 68 70 | n04326547 69 71 | n04560804 70 72 | n02268853 71 73 | n07860988 72 74 | n03000134 73 75 | n07802026 74 76 | n02106382 75 77 | n03930313 76 78 | n02105641 77 79 | n02526121 78 80 | n07718747 79 81 | n03983396 80 82 | n07932039 81 83 | n07715103 82 84 | n02264363 83 85 | n02206856 84 86 | n01440764 85 87 | n02281406 86 88 | n09246464 87 89 | n07614500 88 90 | n02104029 89 91 | n02107142 90 92 | n09332890 91 93 | n07613480 92 94 | n04501370 93 95 | n02795169 94 96 | n02280649 95 97 | n04493381 96 98 | n07742313 97 99 | n09399592 98 100 | n09472597 99 101 | n02514041 100 102 | n04522168 101 103 | n02106166 102 104 | n02106030 103 105 | n01491361 104 106 | n04239074 105 107 | n02165105 106 108 | n02106550 107 109 | n02109525 108 110 | 
n02279972 109 111 | n02107574 110 112 | n02607072 111 113 | n02268443 112 114 | n02109961 113 115 | n04604644 114 116 | n02281787 115 117 | n04579145 116 118 | n02276258 117 119 | n03950228 118 120 | n03459775 119 121 | n07714990 120 122 | n07753113 121 123 | n07875152 122 124 | n02643566 123 125 | n02177972 124 126 | n02106662 125 127 | n07717410 126 128 | n09421951 127 129 | n02168699 128 130 | n02110185 129 131 | n01484850 130 132 | n07697537 131 133 | n04591713 132 134 | n07716906 133 135 | n02107908 134 136 | n02105056 135 137 | n02229544 136 138 | n02108089 137 139 | n01443537 138 140 | n02894605 139 141 | n07749582 140 142 | n02108422 141 143 | n02259212 142 144 | n07590611 143 145 | n07717556 144 146 | n07697313 145 147 | n02107683 146 148 | n02107312 147 149 | n02233338 148 150 | n02109047 149 151 | n02641379 150 152 | n07718472 151 153 | n01494475 152 154 | n07920052 153 155 | n02226429 154 156 | n03786901 155 157 | n02231487 156 158 | n04049303 157 159 | n02172182 158 160 | n07753275 159 161 | n03877845 160 162 | n02097130 161 163 | n01682714 162 164 | n03529860 163 165 | n04147183 164 166 | n04590129 165 167 | n01693334 166 168 | n01534433 167 169 | n01582220 168 170 | n02389026 169 171 | n03461385 170 172 | n02992529 171 173 | n02397096 172 174 | n01558993 173 175 | n02486261 174 176 | n01675722 175 177 | n01748264 176 178 | n03075370 177 179 | n04536866 178 180 | n03347037 179 181 | n03967562 180 182 | n03208938 181 183 | n02492035 182 184 | n03617480 183 185 | n02493793 184 186 | n03249569 185 187 | n02093256 186 188 | n02124075 187 189 | n01530575 188 190 | n03372029 189 191 | n02422106 190 192 | n03657121 191 193 | n02088094 192 194 | n03394916 193 195 | n01692333 194 196 | n02090622 195 197 | n02011460 196 198 | n01729322 197 199 | n02988304 198 200 | n04162706 199 201 | n02097047 200 202 | n03447447 201 203 | n04487394 202 204 | n04443257 203 205 | n02692877 204 206 | n02088632 205 207 | n02018795 206 208 | n02979186 207 209 | n02487347 208 210 | n03495258 209 211 | n04591157 210 212 | n02793495 211 213 | n03187595 212 214 | n04435653 213 215 | n04154565 214 216 | n02088364 215 217 | n04548280 216 218 | n01753488 217 219 | n02028035 218 220 | n03840681 219 221 | n02483708 220 222 | n02492660 221 223 | n02802426 222 224 | n03452741 223 225 | n03325584 224 226 | n01677366 225 227 | n02058221 226 228 | n04355338 227 229 | n04266014 228 230 | n02690373 229 231 | n02093859 230 232 | n03344393 231 233 | n04479046 232 234 | n04192698 233 235 | n02412080 234 236 | n03662601 235 237 | n03884397 236 238 | n02490219 237 239 | n01855032 238 240 | n01755581 239 241 | n04023962 240 242 | n03673027 241 243 | n03627232 242 244 | n02895154 243 245 | n02129604 244 246 | n04462240 245 247 | n03854065 246 248 | n02488702 247 249 | n02992211 248 250 | n01592084 249 251 | n02437312 250 252 | n01694178 251 253 | n02672831 252 254 | n02096051 253 255 | n03089624 254 256 | n02776631 255 257 | n01531178 256 258 | n02089973 257 259 | n03866082 258 260 | n02129165 259 261 | n02676566 260 262 | n02091032 261 263 | n02804610 262 264 | n04347754 263 265 | n02497673 264 266 | n02403003 265 267 | n02096294 266 268 | n02883205 267 269 | n04376876 268 270 | n03544143 269 271 | n03970156 270 272 | n01688243 271 273 | n04005630 272 274 | n03467068 273 275 | n04273569 274 276 | n04141975 275 277 | n02879718 276 278 | n03445777 277 279 | n02128385 278 280 | n04141076 279 281 | n03000247 280 282 | n04392985 281 283 | n03032252 282 284 | n02423022 283 285 | n01728572 284 286 | n01735189 285 287 | n04483307 286 
288 | n02088238 287 289 | n02951358 288 290 | n03146219 289 291 | n02837789 290 292 | n01537544 291 293 | n02097298 292 294 | n02018207 293 295 | n02128757 294 296 | n03498962 295 297 | n03595614 296 298 | n01734418 297 299 | n02410509 298 300 | n02092002 299 301 | n02056570 300 302 | n04346328 301 303 | n03220513 302 304 | n04009552 303 305 | n04356056 304 306 | n02091635 305 307 | n01751748 306 308 | n03630383 307 309 | n03777754 308 310 | n03788195 309 311 | n02486410 310 312 | n03109150 311 313 | n03534580 312 314 | n02096585 313 315 | n03902125 314 316 | n03857828 315 317 | n03095699 316 318 | n04136333 317 319 | n03770439 318 320 | n02123045 319 321 | n02794156 320 322 | n02096437 321 323 | n03197337 322 324 | n02484975 323 325 | n04606251 324 326 | n02910353 325 327 | n04317175 326 328 | n03841143 327 329 | n04153751 328 330 | n02037110 329 331 | n04371430 330 332 | n02033041 331 333 | n04409515 332 334 | n04507155 333 335 | n02009229 334 336 | n02437616 335 337 | n02093991 336 338 | n02130308 337 339 | n03110669 338 340 | n01744401 339 341 | n02090379 340 342 | n02127052 341 343 | n03424325 342 344 | n02892767 343 345 | n02006656 344 346 | n04118776 345 347 | n01855672 346 348 | n02098413 347 349 | n02481823 348 350 | n02092339 349 351 | n03649909 350 352 | --------------------------------------------------------------------------------
/prototypical_module/logreg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | import os, sys 8 | import time 9 | import numpy as np 10 | import random 11 | 12 | def get_feature(root='./features/test/', k=5): 13 | Xtr = [] 14 | Xte = [] 15 | Ytr = [] 16 | Yte = [] 17 | train = [] 18 | for i in range(359): 19 | fw = open(os.path.join(root, str(i)+'.txt'), 'r') 20 | lines = fw.readlines() 21 | for j in range(len(lines)): 22 | line = lines[j].strip() 23 | features = list(map(float, line.split())) 24 | if j < k: 25 | train.append((features, i)) 26 | else: 27 | Xte.append(features) 28 | Yte.append(i) 29 | random.shuffle(train) 30 | for i in range(len(train)): 31 | Xtr.append(train[i][0]) 32 | Ytr.append(train[i][1]) 33 | print(Ytr) 34 | Xtr = np.array(Xtr, np.float) 35 | Xte = np.array(Xte, np.float) 36 | Ytr = np.array(Ytr, np.int) 37 | Yte = np.array(Yte, np.int) 38 | return Xtr, Xte, Ytr, Yte 39 | 40 | 41 | def train(model, X, labels, opts): 42 | optimizer = optim.SGD(model.parameters(), 43 | lr = opts.lr, 44 | momentum = opts.mom, 45 | weight_decay=opts.wd) 46 | N = X.shape[0] 47 | er = 0 48 | for it in range(0, opts.maxiter): 49 | dt = opts.lr 50 | model.train() 51 | optimizer.zero_grad() 52 | 53 | idx = np.random.randint(0,N,opts.batchsize) 54 | x = Variable(torch.Tensor(X[idx])) 55 | y = Variable(torch.from_numpy(labels[idx]).long()) 56 | yhat = model(x) 57 | print(x.size(),y.size(),yhat.size()) 58 | loss = F.nll_loss(yhat, y) 59 | er = er + loss.data.item() 60 | loss.backward() 61 | optimizer.step() 62 | 63 | if it % opts.verbose == 1: 64 | print(er/opts.verbose) 65 | er = 0 66 | return er/opts.verbose 67 | 68 | 69 | 70 | def train_balanced(model, X, labels, opts, freq): 71 | optimizer = optim.SGD(model.parameters(), 72 | lr = opts.lr, 73 | momentum = opts.mom, 74 | weight_decay=opts.wd) 75 | unq, inv, cnt = np.unique(labels, 76 | return_inverse=True, 77 | return_counts=True) 78 | lid = np.split(np.argsort(inv), np.cumsum(cnt[:-1])) 79 | N = 
X.shape[0] 80 | er = 0 81 | nlabels = len(lid) 82 | llid = np.zeros(nlabels).astype('int') 83 | for i in range(nlabels): 84 | llid[i] = len(lid[i]) 85 | t0 = time.time() 86 | model = model.cuda() 87 | for it in range(opts.maxiter): 88 | dt = opts.lr 89 | model.train() 90 | optimizer.zero_grad() 91 | idx = np.random.randint(0,nlabels,opts.batchsize) 92 | for t in range(opts.batchsize): 93 | i = idx[t] 94 | idx[t] = lid[i][np.random.randint(0,llid[i])] 95 | x = Variable(torch.Tensor(X[idx])) 96 | y = Variable(torch.from_numpy(labels[idx]).long()) 97 | x = x.cuda() 98 | y = y.cuda() 99 | yhat = model(x) 100 | # print(x.size(),y.size(),yhat.size()) 101 | loss = F.nll_loss(yhat, y) 102 | er = er + loss.data.item() 103 | loss.backward() 104 | optimizer.step() 105 | 106 | if it % opts.verbose == 1: 107 | print(er/opts.verbose) 108 | er = er/opts.verbose 109 | if (it+1) % freq == 0: 110 | print('[%.3fs] iteration %d' % (time.time() - t0, it + 1)) 111 | train_acc = validate(Xtr, Ytr, model) 112 | test_acc = validate(Xte, Yte, model) 113 | return er 114 | 115 | 116 | def validate(tX, tlabels, model): 117 | """ compute top-1 and top-5 error of the model """ 118 | model.eval() 119 | N = tX.shape[0] 120 | batch = tX 121 | batchlabels = tlabels 122 | x = Variable(torch.Tensor(batch), volatile=True) 123 | x = x.cuda() 124 | yhat = model(x) 125 | values, indices = torch.max(yhat, 1) 126 | indices = indices.data.cpu().numpy() 127 | print(indices.shape, tlabels.shape) 128 | print('################### acc is : ', float(sum(indices==tlabels))/float(len(tlabels))) 129 | return float(sum(indices==tlabels))/float(len(tlabels)) 130 | 131 | class Net(nn.Module): 132 | def __init__(self, d, nclasses): 133 | super(Net, self).__init__() 134 | self.l1 = nn.Linear(d,nclasses) 135 | def forward(self, x): 136 | return F.log_softmax(self.l1(x)) 137 | 138 | 139 | 140 | 141 | ##########################################################3 142 | # 143 | 144 | 145 | 146 | parser = argparse.ArgumentParser() 147 | parser.add_argument('--nlabeled', type=int, default = 2) 148 | parser.add_argument('--maxiter', type=int, default = 1500) 149 | parser.add_argument('--batchsize', type=int, default=128) 150 | parser.add_argument('--verbose', type=int, default=500) 151 | parser.add_argument('--pcadim', type=int, default=2048) 152 | parser.add_argument('--seed', type=int, default=123) 153 | parser.add_argument('--lr', type=float, default=.01) 154 | parser.add_argument('--wd', type=float, default=0.0) 155 | parser.add_argument('--mom', type=float, default=0.0) 156 | parser.add_argument('--mode', default='val') 157 | parser.add_argument('--dataset', default='test') 158 | parser.add_argument('--storemodel', type=str, default='') 159 | parser.add_argument('--storeL', type=str, default='') 160 | 161 | 162 | opts = parser.parse_args() 163 | mode = opts.mode 164 | 165 | root = os.path.join('./features', opts.dataset) 166 | Xtr, Xte, Ytr, Yte = get_feature(root=root, k=opts.nlabeled) 167 | 168 | 169 | nclasses = 359 170 | Xtr_orig = Xtr 171 | Xte_orig = Xte 172 | 173 | print("dataset sizes: Xtr %s (%d labeled), Xte %s, %d classes (eval on %s)" % ( 174 | Xtr.shape, (Ytr >= 0).sum(), 175 | Xte.shape, nclasses, Yte)) 176 | 177 | net= Net(opts.pcadim, nclasses) 178 | 179 | 180 | print('============== start logreg') 181 | 182 | 183 | if opts.mode == 'val': 184 | eval_freq = 500 185 | else: 186 | eval_freq = opts.maxiter 187 | 188 | train_balanced(net, Xtr, Ytr, opts, eval_freq) 189 | 190 | # if opts.storeL: 191 | # L = validate(Xte, Yte, net) 192 | # 
print('writing', opts.storeL) 193 | # np.save(opts.storeL, L) 194 | 195 | if opts.storemodel: 196 | print('writing', opts.storemodel) 197 | torch.save(net.state_dict(), opts.storemodel) -------------------------------------------------------------------------------- /prototypical_module/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | def conv1x1(in_planes, out_planes, stride=1): 26 | """1x1 convolution""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | identity = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | identity = self.downsample(x) 55 | 56 | out += identity 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = conv1x1(inplanes, planes) 68 | self.bn1 = nn.BatchNorm2d(planes) 69 | self.conv2 = conv3x3(planes, planes, stride) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv3 = conv1x1(planes, planes * self.expansion) 72 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | identity = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | identity = self.downsample(x) 93 | 94 | out += identity 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | 102 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 103 | super(ResNet, self).__init__() 104 | self.inplanes = 64 105 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 106 | bias=False) 107 | self.bn1 = nn.BatchNorm2d(64) 108 | self.relu = nn.ReLU(inplace=True) 109 | 
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | self.layer1 = self._make_layer(block, 64, layers[0]) 111 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 112 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 113 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 114 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 115 | self.fc1 = nn.Linear(512 * block.expansion, 1024) 116 | self.fc2 = nn.Linear(1024, 1) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 121 | elif isinstance(m, nn.BatchNorm2d): 122 | nn.init.constant_(m.weight, 1) 123 | nn.init.constant_(m.bias, 0) 124 | 125 | # Zero-initialize the last BN in each residual branch, 126 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 127 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 128 | if zero_init_residual: 129 | for m in self.modules(): 130 | if isinstance(m, Bottleneck): 131 | nn.init.constant_(m.bn3.weight, 0) 132 | elif isinstance(m, BasicBlock): 133 | nn.init.constant_(m.bn2.weight, 0) 134 | 135 | def _make_layer(self, block, planes, blocks, stride=1): 136 | downsample = None 137 | if stride != 1 or self.inplanes != planes * block.expansion: 138 | downsample = nn.Sequential( 139 | conv1x1(self.inplanes, planes * block.expansion, stride), 140 | nn.BatchNorm2d(planes * block.expansion), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for _ in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | x = self.conv1(x) 153 | x = self.bn1(x) 154 | x = self.relu(x) 155 | x = self.maxpool(x) 156 | 157 | x = self.layer1(x) 158 | x = self.layer2(x) 159 | x = self.layer3(x) 160 | x = self.layer4(x) 161 | 162 | x = self.avgpool(x) 163 | x = x.view(x.size(0), -1) 164 | return x 165 | 166 | 167 | def resnet18(pretrained=False, **kwargs): 168 | """Constructs a ResNet-18 model. 169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | """ 172 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 173 | if pretrained: 174 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 175 | return model 176 | 177 | 178 | def resnet34(pretrained=False, **kwargs): 179 | """Constructs a ResNet-34 model. 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 184 | if pretrained: 185 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 186 | return model 187 | 188 | 189 | def resnet50(pretrained=False, **kwargs): 190 | """Constructs a ResNet-50 model. 191 | Args: 192 | pretrained (bool): If True, returns a model pre-trained on ImageNet 193 | """ 194 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 195 | if pretrained: 196 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 197 | return model 198 | 199 | 200 | def resnet101(pretrained=False, **kwargs): 201 | """Constructs a ResNet-101 model. 
202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | """ 205 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 206 | if pretrained: 207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 208 | return model 209 | 210 | 211 | def resnet152(pretrained=False, **kwargs): 212 | """Constructs a ResNet-152 model. 213 | Args: 214 | pretrained (bool): If True, returns a model pre-trained on ImageNet 215 | """ 216 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 217 | if pretrained: 218 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 219 | return model -------------------------------------------------------------------------------- /domain_adaptive_module/pre_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import transforms 3 | import os, random # random is needed by RandomSizedCrop below 4 | from PIL import Image, ImageOps 5 | import numbers 6 | import torch 7 | 8 | class ResizeImage(): 9 | def __init__(self, size): 10 | if isinstance(size, int): 11 | self.size = (int(size), int(size)) 12 | else: 13 | self.size = size 14 | def __call__(self, img): 15 | th, tw = self.size 16 | return img.resize((th, tw)) 17 | 18 | class RandomSizedCrop(object): 19 | """Randomly crop a (C, H, W) tensor/array to a square of side `size`. 20 | Despite the name, no rescaling is done here: a random spatial offset is 21 | drawn and a fixed-size crop is returned. (The Inception-style random 22 | resized crop used by image_train below is torchvision's 23 | transforms.RandomResizedCrop.) 24 | Args: 25 | size: side length of the square crop 26 | interpolation: unused, kept for API compatibility (default: PIL.Image.BILINEAR) 27 | """ 28 | 29 | def __init__(self, size, interpolation=Image.BILINEAR): 30 | self.size = size 31 | self.interpolation = interpolation 32 | 33 | def __call__(self, img): 34 | h_off = random.randint(0, img.shape[1]-self.size) 35 | w_off = random.randint(0, img.shape[2]-self.size) 36 | img = img[:, h_off:h_off+self.size, w_off:w_off+self.size] 37 | return img 38 | 39 | 40 | class Normalize(object): 41 | """Normalize a tensor image with mean and standard deviation. 42 | Given mean: (R, G, B), 43 | will normalize each channel of the torch.*Tensor, i.e. 44 | channel = channel - mean 45 | Args: 46 | mean (sequence): Sequence of means for R, G, B channels respectively. 47 | """ 48 | 49 | def __init__(self, mean=None, meanfile=None): 50 | if mean: 51 | self.mean = mean 52 | else: 53 | arr = np.load(meanfile) 54 | self.mean = torch.from_numpy(arr.astype('float32')/255.0)[[2,1,0],:,:] 55 | 56 | def __call__(self, tensor): 57 | """ 58 | Args: 59 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 60 | Returns: 61 | Tensor: Normalized image. 62 | """ 63 | # TODO: make efficient 64 | for t, m in zip(tensor, self.mean): 65 | t.sub_(m) 66 | return tensor 67 | 68 | 69 | 70 | class PlaceCrop(object): 71 | """Crops the given PIL.Image at the particular index. 72 | Args: 73 | size (sequence or int): Desired output size of the crop. If size is an 74 | int instead of sequence like (w, h), a square crop (size, size) is 75 | made. 76 | """ 77 | 78 | def __init__(self, size, start_x, start_y): 79 | if isinstance(size, int): 80 | self.size = (int(size), int(size)) 81 | else: 82 | self.size = size 83 | self.start_x = start_x 84 | self.start_y = start_y 85 | 86 | def __call__(self, img): 87 | """ 88 | Args: 89 | img (PIL.Image): Image to be cropped. 90 | Returns: 91 | PIL.Image: Cropped image. 92 | """ 93 | th, tw = self.size 94 | return img.crop((self.start_x, self.start_y, self.start_x + tw, self.start_y + th)) 95 | 96 |
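# PlaceCrop and ForceFlip below are deterministic building blocks:
# image_test_10crop combines the four corner crops and the center crop, each
# with and without a horizontal flip, into the standard 10-crop evaluation set.
# Usage sketch (hypothetical file name): a 224x224 center crop of a 256x256 image:
#   patch = PlaceCrop(224, 16, 16)(Image.open('img.jpg').resize((256, 256)))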
92 | """ 93 | th, tw = self.size 94 | return img.crop((self.start_x, self.start_y, self.start_x + tw, self.start_y + th)) 95 | 96 | 97 | class ForceFlip(object): 98 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" 99 | 100 | def __call__(self, img): 101 | """ 102 | Args: 103 | img (PIL.Image): Image to be flipped. 104 | Returns: 105 | PIL.Image: Randomly flipped image. 106 | """ 107 | return img.transpose(Image.FLIP_LEFT_RIGHT) 108 | 109 | class CenterCrop(object): 110 | """Crops the given PIL.Image at the center. 111 | Args: 112 | size (sequence or int): Desired output size of the crop. If size is an 113 | int instead of sequence like (h, w), a square crop (size, size) is 114 | made. 115 | """ 116 | 117 | def __init__(self, size): 118 | if isinstance(size, numbers.Number): 119 | self.size = (int(size), int(size)) 120 | else: 121 | self.size = size 122 | 123 | def __call__(self, img): 124 | """ 125 | Args: 126 | img (PIL.Image): Image to be cropped. 127 | Returns: 128 | PIL.Image: Cropped image. 129 | """ 130 | w, h = (img.shape[1], img.shape[2]) 131 | th, tw = self.size 132 | w_off = int((w - tw) / 2.) 133 | h_off = int((h - th) / 2.) 134 | img = img[:, h_off:h_off+th, w_off:w_off+tw] 135 | return img 136 | 137 | 138 | def image_train(resize_size=256, crop_size=224, alexnet=False): 139 | if not alexnet: 140 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 141 | std=[0.229, 0.224, 0.225]) 142 | else: 143 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 144 | return transforms.Compose([ 145 | ResizeImage(resize_size), 146 | transforms.RandomResizedCrop(crop_size), 147 | transforms.RandomHorizontalFlip(), 148 | transforms.ToTensor(), 149 | normalize 150 | ]) 151 | 152 | def image_test(resize_size=256, crop_size=224, alexnet=False): 153 | if not alexnet: 154 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 155 | std=[0.229, 0.224, 0.225]) 156 | else: 157 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 158 | start_first = 0 159 | start_center = (resize_size - crop_size - 1) / 2 160 | start_last = resize_size - crop_size - 1 161 | 162 | return transforms.Compose([ 163 | ResizeImage(resize_size), 164 | PlaceCrop(crop_size, start_center, start_center), 165 | transforms.ToTensor(), 166 | normalize 167 | ]) 168 | 169 | def image_test_10crop(resize_size=256, crop_size=224, alexnet=False): 170 | if not alexnet: 171 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 172 | std=[0.229, 0.224, 0.225]) 173 | else: 174 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 175 | start_first = 0 176 | start_center = (resize_size - crop_size - 1) / 2 177 | start_last = resize_size - crop_size - 1 178 | data_transforms = [ 179 | transforms.Compose([ 180 | ResizeImage(resize_size),ForceFlip(), 181 | PlaceCrop(crop_size, start_first, start_first), 182 | transforms.ToTensor(), 183 | normalize 184 | ]), 185 | transforms.Compose([ 186 | ResizeImage(resize_size),ForceFlip(), 187 | PlaceCrop(crop_size, start_last, start_last), 188 | transforms.ToTensor(), 189 | normalize 190 | ]), 191 | transforms.Compose([ 192 | ResizeImage(resize_size),ForceFlip(), 193 | PlaceCrop(crop_size, start_last, start_first), 194 | transforms.ToTensor(), 195 | normalize 196 | ]), 197 | transforms.Compose([ 198 | ResizeImage(resize_size),ForceFlip(), 199 | PlaceCrop(crop_size, start_first, start_last), 200 | transforms.ToTensor(), 201 | normalize 202 | ]), 203 | transforms.Compose([ 204 | ResizeImage(resize_size),ForceFlip(), 205 | 
203 | transforms.Compose([ 204 | ResizeImage(resize_size),ForceFlip(), 205 | PlaceCrop(crop_size, start_center, start_center), 206 | transforms.ToTensor(), 207 | normalize 208 | ]), 209 | transforms.Compose([ 210 | ResizeImage(resize_size), 211 | PlaceCrop(crop_size, start_first, start_first), 212 | transforms.ToTensor(), 213 | normalize 214 | ]), 215 | transforms.Compose([ 216 | ResizeImage(resize_size), 217 | PlaceCrop(crop_size, start_last, start_last), 218 | transforms.ToTensor(), 219 | normalize 220 | ]), 221 | transforms.Compose([ 222 | ResizeImage(resize_size), 223 | PlaceCrop(crop_size, start_last, start_first), 224 | transforms.ToTensor(), 225 | normalize 226 | ]), 227 | transforms.Compose([ 228 | ResizeImage(resize_size), 229 | PlaceCrop(crop_size, start_first, start_last), 230 | transforms.ToTensor(), 231 | normalize 232 | ]), 233 | transforms.Compose([ 234 | ResizeImage(resize_size), 235 | PlaceCrop(crop_size, start_center, start_center), 236 | transforms.ToTensor(), 237 | normalize 238 | ]) 239 | ] 240 | return data_transforms 241 | -------------------------------------------------------------------------------- /pretrain/main_resnet.py: -------------------------------------------------------------------------------- 1 | import os, shutil # os was previously available only via `from dataloader import *` 2 | import time 3 | import argparse 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | from torchvision import transforms 11 | from torch.utils.data import DataLoader 12 | 13 | from dataloader import * 14 | import resnet 15 | from resnet import ResNetFc 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--dir_path', default='DATASET_DIR') 19 | parser.add_argument('--arch', default='resnet50', 20 | choices=['resnet34', 'resnet50', 'resnet101', 'resnet152']) 21 | parser.add_argument('--workers', default=32, type=int, metavar='N', 22 | help='number of data loading workers (default: 32)') 23 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 24 | help='number of total epochs to run') 25 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 26 | help='manual epoch number (useful on restarts)') 27 | parser.add_argument('--batch_size', default=4096, type=int, 28 | metavar='N', help='mini-batch size (default: 4096)') 29 | parser.add_argument('--iter-size', default=4, type=int, 30 | metavar='I', help='iter size as in Caffe to reduce memory usage (default: 4)') 31 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 32 | metavar='LR', help='initial learning rate') 33 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 34 | help='momentum') 35 | parser.add_argument('--new_length', default=400, type=int) 36 | parser.add_argument('--new_width', default=400, type=int) 37 | 38 | parser.add_argument('--weight-decay', default=1e-4, type=float, 39 | metavar='W', help='weight decay (default: 1e-4)') 40 | parser.add_argument('--print-freq', default=10, type=int, 41 | metavar='N', help='print frequency (default: 10)') 42 | parser.add_argument('--save-freq', default=5, type=int, 43 | metavar='N', help='save frequency in epochs (default: 5)') 44 | parser.add_argument('--resume', default='output', type=str, metavar='PATH', 45 | help='directory where checkpoints are written (default: output)') 46 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 47 | help='evaluate model on validation set') 48 | 49 | best_prec = 0 50 | 51 |
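# Note: --iter-size is parsed for Caffe-style gradient accumulation but is not
# used in the training loop below; every mini-batch triggers a full SGD step.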
52 | def main(): 53 | global args, best_prec 54 | args = parser.parse_args() 55 | print("Build model ...") 56 | model = build_model() 57 | if not os.path.exists(args.resume): 58 | os.makedirs(args.resume) 59 | print("Saving everything to directory %s." % (args.resume)) 60 | 61 | # define loss function (criterion) and optimizer 62 | criterion = nn.CrossEntropyLoss().cuda() 63 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 64 | momentum=args.momentum, 65 | weight_decay=args.weight_decay) 66 | cudnn.benchmark = True 67 | 68 | # data transform 69 | mean = [0.485, 0.456, 0.406] 70 | std = [0.229, 0.224, 0.225] 71 | train_transform = transforms.Compose([ 72 | transforms.RandomRotation(20), 73 | transforms.RandomResizedCrop(84, scale=(0.8, 1.2), ratio=(0.75, 1.3333333333333333), interpolation=2), 74 | #transforms.RandomCrop(84), 75 | transforms.RandomHorizontalFlip(), 76 | # transforms.RandomVerticalFlip(), 77 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4), 78 | transforms.ToTensor(), 79 | transforms.Normalize(mean=mean, std=std)]) 80 | val_transform = transforms.Compose([ 81 | transforms.ToTensor(), 82 | transforms.Normalize(mean=mean, std=std)]) 83 | 84 | train_data = MyDataset(os.path.join(args.dir_path, 'trainval_list.txt'), args.dir_path, args.new_width, args.new_length,train_transform) 85 | # val_data = MyDataset(os.path.join(args.dir_path, 'val.txt'), args.dir_path, args.new_width, args.new_length, val_transform) 86 | 87 | train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 88 | pin_memory=True) 89 | # val_loader = DataLoader(dataset=val_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, 90 | #pin_memory=True) 91 | 92 | for epoch in range(args.start_epoch, args.epochs): 93 | print('epoch: ' + str(epoch + 1)) 94 | adjust_learning_rate(optimizer, epoch) 95 | 96 | # train for one epoch 97 | train(train_loader, model, criterion, optimizer, epoch) 98 | 99 | # evaluate on validation set 100 | #prec = validate(val_loader, model, criterion) 101 | 102 | # remember best prec and save checkpoint 103 | #is_best = prec > best_prec 104 | #best_prec = max(prec, best_prec) 105 | 106 | if (epoch + 1) % args.save_freq == 0: 107 | checkpoint_name = "%03d_%s" % (epoch + 1, "checkpoint.pth.tar") 108 | save_checkpoint({ 109 | 'epoch': epoch + 1, 110 | 'arch': args.arch, 111 | 'state_dict': model.state_dict(), 112 | 'optimizer' : optimizer.state_dict(), 113 | }, checkpoint_name, args.resume) 114 | 115 | 116 | def build_model(): 117 | model = ResNetFc(class_num=448) 118 | model = torch.nn.DataParallel(model).cuda() 119 | # model = model.cuda() 120 | return model 121 | 122 | 123 | def train(train_loader, model, criterion, optimizer, epoch): 124 | batch_time = AverageMeter() 125 | data_time = AverageMeter() 126 | losses = AverageMeter() 127 | top1 = AverageMeter() 128 | 129 | # switch to train mode 130 | model.train() 131 | 132 | end = time.time() 133 | for i, (input, target, _) in enumerate(train_loader): 134 | # measure data loading time 135 | data_time.update(time.time() - end) 136 | 137 | input = input.float().cuda(non_blocking=True) 138 | #print(input.size(),target.size()) 139 | target = target.cuda(non_blocking=True) 140 | input_var = torch.autograd.Variable(input) 141 | target_var = torch.autograd.Variable(target) 142 | 143 | output = model(input_var) 144 | loss = criterion(output, target_var) 145 | 146 | # measure accuracy and record loss 147 | prec = accuracy(output.data, target) 148 | losses.update(loss.data.item(), input.size(0)) 149 | top1.update(prec.item(), input.size(0)) 150 |
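# losses/top1 are AverageMeters (class defined below): .val is the latest batch
# value and .avg the running mean over the epoch, which is what gets printed.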
151 | # compute gradient and do SGD step 152 | optimizer.zero_grad() 153 | loss.backward() 154 | optimizer.step() 155 | 156 | # measure elapsed time 157 | batch_time.update(time.time() - end) 158 | end = time.time() 159 | 160 | if i % args.print_freq == 0: 161 | print('Epoch: [{0}][{1}/{2}]\t' 162 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 163 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 164 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 165 | 'Prec {top1.val:.3f} ({top1.avg:.3f})\t'.format( 166 | epoch, i, len(train_loader), batch_time=batch_time, 167 | data_time=data_time, loss=losses, top1=top1)) 168 | 169 | 170 | def validate(val_loader, model, criterion): 171 | batch_time = AverageMeter() 172 | losses = AverageMeter() 173 | top1 = AverageMeter() 174 | 175 | # switch to evaluate mode 176 | model.eval() 177 | 178 | end = time.time() 179 | for i, (input, target, _) in enumerate(val_loader): 180 | input = input.float().cuda(non_blocking=True) 181 | target = target.cuda(non_blocking=True) 182 | input_var = input 183 | target_var = target 184 | 185 | # compute output (no_grad: no autograd graph is needed at eval time) 186 | with torch.no_grad(): output = model(input_var) 187 | loss = criterion(output, target_var) 188 | 189 | # measure accuracy and record loss 190 | prec = accuracy(output.data, target) 191 | losses.update(loss.data.item(), input.size(0)) 192 | top1.update(prec.item(), input.size(0)) 193 | 194 | # measure elapsed time 195 | batch_time.update(time.time() - end) 196 | end = time.time() 197 | 198 | if i % args.print_freq == 0: 199 | print('Test: [{0}/{1}]\t' 200 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 201 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 202 | 'Prec {top1.val:.3f} ({top1.avg:.3f})'.format( 203 | i, len(val_loader), batch_time=batch_time, loss=losses, 204 | top1=top1)) 205 | 206 | print(' * Prec {top1.avg:.3f} ' 207 | .format(top1=top1)) 208 | 209 | return top1.avg 210 | 211 | 212 | def save_checkpoint(state, filename, resume_path): 213 | cur_path = os.path.join(resume_path, filename) 214 | torch.save(state, cur_path) 215 | 216 | 217 | class AverageMeter(object): 218 | """Computes and stores the average and current value""" 219 | def __init__(self): 220 | self.reset() 221 | 222 | def reset(self): 223 | self.val = 0 224 | self.avg = 0 225 | self.sum = 0 226 | self.count = 0 227 | 228 | def update(self, val, n=1): 229 | self.val = val 230 | self.sum += val * n 231 | self.count += n 232 | self.avg = self.sum / self.count 233 | 234 | 235 | def adjust_learning_rate(optimizer, epoch): 236 | """Sets the learning rate to the initial LR decayed by 10 every 20 epochs""" 237 | lr = args.lr * (0.1 ** (epoch // 20)) 238 | for param_group in optimizer.param_groups: 239 | param_group['lr'] = lr 240 | # param_group['lr'] = param_group['lr']/2 241 | 242 | 243 | def accuracy(output, target): 244 | """Computes the top-1 precision""" 245 | maxk = 1 246 | batch_size = target.size(0) 247 | 248 | _, pred = output.topk(maxk, 1, True, True) 249 | pred = pred.t() 250 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 251 | 252 | correct_k = correct[0].view(-1).float().sum(0) 253 | res = correct_k.mul_(100.0 / batch_size) 254 | return res 255 | 256 | if __name__ == '__main__': 257 | main() 258 | -------------------------------------------------------------------------------- /pretrain/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | import torch.utils.model_zoo as
model_zoo 5 | from torchvision import models 6 | 7 | 8 | def init_weights(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1: 11 | nn.init.kaiming_uniform_(m.weight) 12 | nn.init.zeros_(m.bias) 13 | elif classname.find('BatchNorm') != -1: 14 | nn.init.normal_(m.weight, 1.0, 0.02) 15 | nn.init.zeros_(m.bias) 16 | elif classname.find('Linear') != -1: 17 | nn.init.xavier_normal_(m.weight) 18 | nn.init.zeros_(m.bias) 19 | 20 | 21 | class ResNetFc(nn.Module): 22 | def __init__(self, class_num=1000): 23 | super(ResNetFc, self).__init__() 24 | model_resnet = models.resnet50(pretrained=True) 25 | self.conv1 = model_resnet.conv1 26 | self.bn1 = model_resnet.bn1 27 | self.relu = model_resnet.relu 28 | self.maxpool = model_resnet.maxpool 29 | self.layer1 = model_resnet.layer1 30 | self.layer2 = model_resnet.layer2 31 | self.layer3 = model_resnet.layer3 32 | self.layer4 = model_resnet.layer4 33 | self.avgpool = nn.AdaptiveMaxPool2d((1,1)) 34 | 35 | self.fc = nn.Linear(model_resnet.fc.in_features, class_num) 36 | self.fc.apply(init_weights) 37 | 38 | def forward(self, x): 39 | x = self.conv1(x) 40 | x = self.bn1(x) 41 | x = self.relu(x) 42 | x = self.maxpool(x) 43 | 44 | x = self.layer1(x) 45 | x = self.layer2(x) 46 | x = self.layer3(x) 47 | x = self.layer4(x) 48 | 49 | x = self.avgpool(x) 50 | x = x.view(x.size(0), -1) 51 | x = self.fc(x) 52 | return x 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | model_urls = { 62 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 63 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 64 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 65 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 66 | } 67 | 68 | 69 | def conv3x3(in_planes, out_planes, stride=1): 70 | "3x3 convolution with padding" 71 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 72 | padding=1, bias=False) 73 | 74 | 75 | class BasicBlock(nn.Module): 76 | expansion = 1 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None): 79 | super(BasicBlock, self).__init__() 80 | self.conv1 = conv3x3(inplanes, planes, stride) 81 | self.bn1 = nn.BatchNorm2d(planes) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.conv2 = conv3x3(planes, planes) 84 | self.bn2 = nn.BatchNorm2d(planes) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | residual = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class Bottleneck(nn.Module): 108 | expansion = 4 109 | 110 | def __init__(self, inplanes, planes, stride=1, downsample=None): 111 | super(Bottleneck, self).__init__() 112 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 113 | self.bn1 = nn.BatchNorm2d(planes) 114 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 115 | padding=1, bias=False) 116 | self.bn2 = nn.BatchNorm2d(planes) 117 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 118 | self.bn3 = nn.BatchNorm2d(planes * 4) 119 | self.relu = nn.ReLU(inplace=True) 120 | self.downsample = downsample 121 | self.stride = stride 122 | 123 | def forward(self, x): 124 | residual = x 125 | 126 | out = 
self.conv1(x) 127 | out = self.bn1(out) 128 | out = self.relu(out) 129 | 130 | out = self.conv2(out) 131 | out = self.bn2(out) 132 | out = self.relu(out) 133 | 134 | out = self.conv3(out) 135 | out = self.bn3(out) 136 | 137 | if self.downsample is not None: 138 | residual = self.downsample(x) 139 | 140 | out += residual 141 | out = self.relu(out) 142 | 143 | return out 144 | 145 | 146 | class ResNet(nn.Module): 147 | 148 | def __init__(self, block, layers, num_classes=128): 149 | self.inplanes = 64 150 | super(ResNet, self).__init__() 151 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 152 | bias=False) 153 | self.bn1 = nn.BatchNorm2d(64) 154 | self.relu = nn.ReLU(inplace=True) 155 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 156 | self.layer1 = self._make_layer(block, 64, layers[0]) 157 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 158 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 159 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 160 | self.avgpool = nn.AvgPool2d(7) 161 | 162 | # self.dp = nn.Dropout(p=0.8) 163 | self.fc_action = nn.Linear(512 * block.expansion, num_classes) 164 | 165 | for m in self.modules(): 166 | if isinstance(m, nn.Conv2d): 167 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 168 | m.weight.data.normal_(0, math.sqrt(2. / n)) 169 | elif isinstance(m, nn.BatchNorm2d): 170 | m.weight.data.fill_(1) 171 | m.bias.data.zero_() 172 | 173 | def _make_layer(self, block, planes, blocks, stride=1): 174 | downsample = None 175 | if stride != 1 or self.inplanes != planes * block.expansion: 176 | downsample = nn.Sequential( 177 | nn.Conv2d(self.inplanes, planes * block.expansion, 178 | kernel_size=1, stride=stride, bias=False), 179 | nn.BatchNorm2d(planes * block.expansion), 180 | ) 181 | 182 | layers = [] 183 | layers.append(block(self.inplanes, planes, stride, downsample)) 184 | self.inplanes = planes * block.expansion 185 | for i in range(1, blocks): 186 | layers.append(block(self.inplanes, planes)) 187 | 188 | return nn.Sequential(*layers) 189 | 190 | def forward(self, x): 191 | x = self.conv1(x) 192 | x = self.bn1(x) 193 | x = self.relu(x) 194 | x = self.maxpool(x) 195 | 196 | x = self.layer1(x) 197 | x = self.layer2(x) 198 | x = self.layer3(x) 199 | x = self.layer4(x) 200 | 201 | x = self.avgpool(x) 202 | x = x.view(x.size(0), -1) 203 | # x = self.dp(x) 204 | x = self.fc_action(x) 205 | 206 | return x 207 | 208 | 209 | def rgb_resnet18(pretrained=False, **kwargs): 210 | """Constructs a ResNet-18 model. 211 | Args: 212 | pretrained (bool): If True, returns a model pre-trained on ImageNet 213 | """ 214 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) # fix: model was never constructed here 215 | if pretrained: 216 | pretrained_dict = model_zoo.load_url(model_urls['resnet18']) 217 | model_dict = model.state_dict() 218 | # 1. filter out unnecessary keys 219 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 220 | # 2. overwrite entries in the existing state dict, 3. load the new state dict 221 | model_dict.update(pretrained_dict) 222 | model.load_state_dict(model_dict) 223 | return model 224 | 225 |
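# The constructors below share one pattern: build the randomly initialized
# model, then, if pretrained, copy over only the ImageNet weights whose keys
# match, so the fc_action head keeps its fresh initialization. Usage sketch:
#   model = rgb_resnet50(pretrained=True, num_classes=128)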
226 | def rgb_resnet34(pretrained=False, **kwargs): 227 | """Constructs a ResNet-34 model. 228 | Args: 229 | pretrained (bool): If True, returns a model pre-trained on ImageNet 230 | """ 231 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) # fix: model was never constructed here 232 | if pretrained: 233 | pretrained_dict = model_zoo.load_url(model_urls['resnet34']) 234 | model_dict = model.state_dict() 235 | # filter to matching keys, then overwrite and load 236 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 237 | model_dict.update(pretrained_dict) 238 | model.load_state_dict(model_dict) 239 | return model 240 | 241 | 242 | 243 | def rgb_resnet50(pretrained=False, **kwargs): 244 | """Constructs a ResNet-50 model. 245 | Args: 246 | pretrained (bool): If True, returns a model pre-trained on ImageNet 247 | """ 248 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 249 | if pretrained: 250 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 251 | model_dict = model.state_dict() 252 | # 1. filter out unnecessary keys 253 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 254 | # 2. overwrite entries in the existing state dict 255 | model_dict.update(pretrained_dict) 256 | # 3. load the new state dict 257 | model.load_state_dict(model_dict) 258 | 259 | return model 260 | 261 | 262 | def rgb_resnet101(pretrained=False, **kwargs): 263 | """Constructs a ResNet-101 model. 264 | Args: 265 | pretrained (bool): If True, returns a model pre-trained on ImageNet 266 | """ 267 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 268 | if pretrained: 269 | # model_urls above has no 'resnet101' entry and torch.load cannot read a 270 | # URL, so (sketch of a fix) take the reference weights from torchvision 271 | # and keep only the matching keys, as in the other constructors. 272 | pretrained_dict = models.resnet101(pretrained=True).state_dict() 273 | model_dict = model.state_dict() 274 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 275 | model_dict.update(pretrained_dict) 276 | model.load_state_dict(model_dict) 277 | 278 | 279 | 280 | return model 281 | 282 | 283 | def rgb_resnet152(pretrained=False, **kwargs): 284 | """Constructs a ResNet-152 model. 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | """ 288 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 289 | if pretrained: 290 | pretrained_dict = model_zoo.load_url(model_urls['resnet152']) 291 | model_dict = model.state_dict() 292 | # 1. filter out unnecessary keys 293 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 294 | # 2. overwrite entries in the existing state dict 295 | model_dict.update(pretrained_dict) 296 | # 3.
load the new state dict 297 | model.load_state_dict(model_dict) 298 | return model 299 | -------------------------------------------------------------------------------- /dataset/tiered-imagenet/train.txt: -------------------------------------------------------------------------------- 1 | train/n04552348/n0455234800326574.png 0 2 | train/n04552348/n0455234800325700.png 0 3 | train/n04552348/n0455234800325660.png 0 4 | train/n04552348/n0455234800326229.png 0 5 | train/n04552348/n0455234800326709.png 0 6 | train/n04552348/n0455234800326433.png 0 7 | train/n04552348/n0455234800325494.png 0 8 | train/n04552348/n0455234800326131.png 0 9 | train/n04552348/n0455234800326241.png 0 10 | train/n04552348/n0455234800325636.png 0 11 | train/n04552348/n0455234800326306.png 0 12 | train/n04552348/n0455234800326594.png 0 13 | train/n04552348/n0455234800325887.png 0 14 | train/n04552348/n0455234800325699.png 0 15 | train/n04552348/n0455234800326283.png 0 16 | train/n04552348/n0455234800325654.png 0 17 | train/n04552348/n0455234800326442.png 0 18 | train/n04552348/n0455234800325847.png 0 19 | train/n04552348/n0455234800326330.png 0 20 | train/n04552348/n0455234800326611.png 0 21 | train/n04552348/n0455234800325667.png 0 22 | train/n04552348/n0455234800325548.png 0 23 | train/n04552348/n0455234800326136.png 0 24 | train/n04552348/n0455234800325586.png 0 25 | train/n04552348/n0455234800326005.png 0 26 | train/n04552348/n0455234800325483.png 0 27 | train/n04552348/n0455234800325993.png 0 28 | train/n04552348/n0455234800325560.png 0 29 | train/n04552348/n0455234800326339.png 0 30 | train/n04552348/n0455234800325786.png 0 31 | train/n04552348/n0455234800326191.png 0 32 | train/n04552348/n0455234800326683.png 0 33 | train/n04552348/n0455234800326635.png 0 34 | train/n04552348/n0455234800326456.png 0 35 | train/n04552348/n0455234800326730.png 0 36 | train/n04552348/n0455234800325891.png 0 37 | train/n04552348/n0455234800326119.png 0 38 | train/n04552348/n0455234800325874.png 0 39 | train/n04552348/n0455234800326419.png 0 40 | train/n04552348/n0455234800326728.png 0 41 | train/n04552348/n0455234800326591.png 0 42 | train/n04552348/n0455234800326334.png 0 43 | train/n04552348/n0455234800326159.png 0 44 | train/n04552348/n0455234800326467.png 0 45 | train/n04552348/n0455234800325834.png 0 46 | train/n04552348/n0455234800326331.png 0 47 | train/n04552348/n0455234800326698.png 0 48 | train/n04552348/n0455234800325596.png 0 49 | train/n04552348/n0455234800326130.png 0 50 | train/n04552348/n0455234800326250.png 0 51 | train/n04552348/n0455234800325517.png 0 52 | train/n04552348/n0455234800325764.png 0 53 | train/n04552348/n0455234800326001.png 0 54 | train/n04552348/n0455234800326407.png 0 55 | train/n04552348/n0455234800325479.png 0 56 | train/n04552348/n0455234800326011.png 0 57 | train/n04552348/n0455234800325981.png 0 58 | train/n04552348/n0455234800325657.png 0 59 | train/n04552348/n0455234800325505.png 0 60 | train/n04552348/n0455234800325508.png 0 61 | train/n04552348/n0455234800326561.png 0 62 | train/n04552348/n0455234800325676.png 0 63 | train/n04552348/n0455234800326459.png 0 64 | train/n04552348/n0455234800325482.png 0 65 | train/n04552348/n0455234800325775.png 0 66 | train/n04552348/n0455234800326464.png 0 67 | train/n04552348/n0455234800326116.png 0 68 | train/n04552348/n0455234800326586.png 0 69 | train/n04552348/n0455234800325726.png 0 70 | train/n04552348/n0455234800326296.png 0 71 | train/n04552348/n0455234800325492.png 0 72 | train/n04552348/n0455234800325973.png 0 73 | 
train/n04552348/n0455234800326363.png 0 74 | train/n04552348/n0455234800326599.png 0 75 | train/n04552348/n0455234800325455.png 0 76 | train/n04552348/n0455234800326534.png 0 77 | train/n04552348/n0455234800326196.png 0 78 | train/n04552348/n0455234800326176.png 0 79 | train/n04552348/n0455234800326097.png 0 80 | train/n04552348/n0455234800326636.png 0 81 | train/n04552348/n0455234800326602.png 0 82 | train/n04552348/n0455234800325658.png 0 83 | train/n04552348/n0455234800325898.png 0 84 | train/n04552348/n0455234800326642.png 0 85 | train/n04552348/n0455234800326038.png 0 86 | train/n04552348/n0455234800325611.png 0 87 | train/n04552348/n0455234800326226.png 0 88 | train/n04552348/n0455234800326273.png 0 89 | train/n04552348/n0455234800326156.png 0 90 | train/n04552348/n0455234800325748.png 0 91 | train/n04552348/n0455234800325965.png 0 92 | train/n04552348/n0455234800326364.png 0 93 | train/n04552348/n0455234800325990.png 0 94 | train/n04552348/n0455234800325921.png 0 95 | train/n04552348/n0455234800325639.png 0 96 | train/n04552348/n0455234800326240.png 0 97 | train/n04552348/n0455234800325582.png 0 98 | train/n04552348/n0455234800326080.png 0 99 | train/n04552348/n0455234800326061.png 0 100 | train/n04552348/n0455234800325604.png 0 101 | train/n04552348/n0455234800326379.png 0 102 | train/n04552348/n0455234800325633.png 0 103 | train/n04552348/n0455234800325626.png 0 104 | train/n04552348/n0455234800325552.png 0 105 | train/n04552348/n0455234800325968.png 0 106 | train/n04552348/n0455234800325911.png 0 107 | train/n04552348/n0455234800325536.png 0 108 | train/n04552348/n0455234800326090.png 0 109 | train/n04552348/n0455234800325730.png 0 110 | train/n04552348/n0455234800326673.png 0 111 | train/n04552348/n0455234800326659.png 0 112 | train/n04552348/n0455234800326460.png 0 113 | train/n04552348/n0455234800326280.png 0 114 | train/n04552348/n0455234800325878.png 0 115 | train/n04552348/n0455234800326532.png 0 116 | train/n04552348/n0455234800325474.png 0 117 | train/n04552348/n0455234800325959.png 0 118 | train/n04552348/n0455234800326140.png 0 119 | train/n04552348/n0455234800326416.png 0 120 | train/n04552348/n0455234800326418.png 0 121 | train/n04552348/n0455234800325784.png 0 122 | train/n04552348/n0455234800326353.png 0 123 | train/n04552348/n0455234800325719.png 0 124 | train/n04552348/n0455234800325936.png 0 125 | train/n04552348/n0455234800326715.png 0 126 | train/n04552348/n0455234800326182.png 0 127 | train/n04552348/n0455234800326624.png 0 128 | train/n04552348/n0455234800325870.png 0 129 | train/n04552348/n0455234800325533.png 0 130 | train/n04552348/n0455234800325524.png 0 131 | train/n04552348/n0455234800326465.png 0 132 | train/n04552348/n0455234800326545.png 0 133 | train/n04552348/n0455234800325881.png 0 134 | train/n04552348/n0455234800325979.png 0 135 | train/n04552348/n0455234800326093.png 0 136 | train/n04552348/n0455234800325463.png 0 137 | train/n04552348/n0455234800326083.png 0 138 | train/n04552348/n0455234800326340.png 0 139 | train/n04552348/n0455234800325821.png 0 140 | train/n04552348/n0455234800326174.png 0 141 | train/n04552348/n0455234800326175.png 0 142 | train/n04552348/n0455234800325950.png 0 143 | train/n04552348/n0455234800325564.png 0 144 | train/n04552348/n0455234800326614.png 0 145 | train/n04552348/n0455234800325771.png 0 146 | train/n04552348/n0455234800326238.png 0 147 | train/n04552348/n0455234800326268.png 0 148 | train/n04552348/n0455234800326481.png 0 149 | train/n04552348/n0455234800326531.png 0 150 | train/n04552348/n0455234800326458.png 
0 151 | train/n04552348/n0455234800325468.png 0 152 | train/n04552348/n0455234800326536.png 0 153 | train/n04552348/n0455234800326350.png 0 154 | train/n04552348/n0455234800325515.png 0 155 | train/n04552348/n0455234800326502.png 0 156 | train/n04552348/n0455234800326076.png 0 157 | train/n04552348/n0455234800326288.png 0 158 | train/n04552348/n0455234800325974.png 0 159 | train/n04552348/n0455234800325729.png 0 160 | train/n04552348/n0455234800325970.png 0 161 | train/n04552348/n0455234800326301.png 0 162 | train/n04552348/n0455234800325668.png 0 163 | train/n04552348/n0455234800326289.png 0 164 | train/n04552348/n0455234800325591.png 0 165 | train/n04552348/n0455234800325671.png 0 166 | train/n04552348/n0455234800326475.png 0 167 | train/n04552348/n0455234800325865.png 0 168 | train/n04552348/n0455234800326434.png 0 169 | train/n04552348/n0455234800326049.png 0 170 | train/n04552348/n0455234800325690.png 0 171 | train/n04552348/n0455234800325876.png 0 172 | train/n04552348/n0455234800325963.png 0 173 | train/n04552348/n0455234800326181.png 0 174 | train/n04552348/n0455234800326242.png 0 175 | train/n04552348/n0455234800325908.png 0 176 | train/n04552348/n0455234800325701.png 0 177 | train/n04552348/n0455234800325507.png 0 178 | train/n04552348/n0455234800325588.png 0 179 | train/n04552348/n0455234800325756.png 0 180 | train/n04552348/n0455234800325703.png 0 181 | train/n04552348/n0455234800325478.png 0 182 | train/n04552348/n0455234800326155.png 0 183 | train/n04552348/n0455234800326081.png 0 184 | train/n04552348/n0455234800325888.png 0 185 | train/n04552348/n0455234800325546.png 0 186 | train/n04552348/n0455234800326203.png 0 187 | train/n04552348/n0455234800325846.png 0 188 | train/n04552348/n0455234800326051.png 0 189 | train/n04552348/n0455234800326383.png 0 190 | train/n04552348/n0455234800326727.png 0 191 | train/n04552348/n0455234800326285.png 0 192 | train/n04552348/n0455234800326551.png 0 193 | train/n04552348/n0455234800326078.png 0 194 | train/n04552348/n0455234800325910.png 0 195 | train/n04552348/n0455234800326135.png 0 196 | train/n04552348/n0455234800326376.png 0 197 | train/n04552348/n0455234800326515.png 0 198 | train/n04552348/n0455234800326057.png 0 199 | train/n04552348/n0455234800326435.png 0 200 | train/n04552348/n0455234800325605.png 0 201 | train/n04552348/n0455234800325815.png 0 202 | train/n04552348/n0455234800325793.png 0 203 | train/n04552348/n0455234800325617.png 0 204 | train/n04552348/n0455234800325645.png 0 205 | train/n04552348/n0455234800326170.png 0 206 | train/n04552348/n0455234800325806.png 0 207 | train/n04552348/n0455234800325969.png 0 208 | train/n04552348/n0455234800325732.png 0 209 | train/n04552348/n0455234800326615.png 0 210 | train/n04552348/n0455234800326189.png 0 211 | train/n04552348/n0455234800326373.png 0 212 | train/n04552348/n0455234800325530.png 0 213 | train/n04552348/n0455234800325949.png 0 214 | train/n04552348/n0455234800326721.png 0 215 | train/n04552348/n0455234800325500.png 0 216 | train/n04552348/n0455234800325574.png 0 217 | train/n04552348/n0455234800326034.png 0 218 | train/n04552348/n0455234800325549.png 0 219 | train/n04552348/n0455234800326409.png 0 220 | train/n04552348/n0455234800326101.png 0 221 | train/n04552348/n0455234800326382.png 0 222 | train/n04552348/n0455234800325933.png 0 223 | train/n04552348/n0455234800326361.png 0 224 | train/n04552348/n0455234800325647.png 0 225 | train/n04552348/n0455234800326186.png 0 226 | train/n04552348/n0455234800326058.png 0 227 | train/n04552348/n0455234800326141.png 0 228 | 
train/n04552348/n0455234800326702.png 0 229 | train/n04552348/n0455234800326207.png 0 230 | train/n04552348/n0455234800325945.png 0 231 | train/n04552348/n0455234800325675.png 0 232 | train/n04552348/n0455234800326341.png 0 233 | train/n04552348/n0455234800326272.png 0 234 | train/n04552348/n0455234800326031.png 0 235 | train/n04552348/n0455234800326616.png 0 236 | train/n04552348/n0455234800325995.png 0 237 | train/n04552348/n0455234800326720.png 0 238 | train/n04552348/n0455234800325918.png 0 239 | train/n04552348/n0455234800325659.png 0 240 | train/n04552348/n0455234800326552.png 0 241 | train/n04552348/n0455234800325651.png 0 242 | train/n04552348/n0455234800325466.png 0 243 | train/n04552348/n0455234800325867.png 0 244 | train/n04552348/n0455234800325721.png 0 245 | train/n04552348/n0455234800325919.png 0 246 | train/n04552348/n0455234800325789.png 0 247 | train/n04552348/n0455234800325984.png 0 248 | train/n04552348/n0455234800325568.png 0 249 | train/n04552348/n0455234800326637.png 0 250 | train/n04552348/n0455234800325987.png 0 251 | train/n04552348/n0455234800326657.png 0 252 | train/n04552348/n0455234800326643.png 0 253 | train/n04552348/n0455234800325579.png 0 254 | train/n04552348/n0455234800325686.png 0 255 | train/n04552348/n0455234800326606.png 0 256 | train/n04552348/n0455234800326535.png 0 257 | train/n04552348/n0455234800326153.png 0 258 | train/n04552348/n0455234800326396.png 0 259 | train/n04552348/n0455234800326256.png 0 260 | train/n04552348/n0455234800326666.png 0 261 | -------------------------------------------------------------------------------- /dataset/tiered-imagenet/test_new_domain_fsl.txt: -------------------------------------------------------------------------------- 1 | test_new_domain/n02105251/n0210525100097217.png 0 2 | test_new_domain/n02105251/n0210525100096267.png 0 3 | test_new_domain/n02105251/n0210525100097401.png 0 4 | test_new_domain/n02105251/n0210525100096782.png 0 5 | test_new_domain/n02105251/n0210525100096484.png 0 6 | test_new_domain/n02105251/n0210525100097493.png 0 7 | test_new_domain/n02105251/n0210525100096459.png 0 8 | test_new_domain/n02105251/n0210525100096802.png 0 9 | test_new_domain/n02105251/n0210525100097219.png 0 10 | test_new_domain/n02105251/n0210525100097293.png 0 11 | test_new_domain/n02105251/n0210525100096930.png 0 12 | test_new_domain/n02105251/n0210525100097302.png 0 13 | test_new_domain/n02105251/n0210525100097387.png 0 14 | test_new_domain/n02105251/n0210525100096592.png 0 15 | test_new_domain/n02105251/n0210525100096536.png 0 16 | test_new_domain/n02105251/n0210525100097123.png 0 17 | test_new_domain/n02105251/n0210525100097130.png 0 18 | test_new_domain/n02105251/n0210525100096596.png 0 19 | test_new_domain/n02105251/n0210525100096921.png 0 20 | test_new_domain/n02105251/n0210525100096783.png 0 21 | test_new_domain/n02105251/n0210525100096263.png 0 22 | test_new_domain/n02105251/n0210525100096262.png 0 23 | test_new_domain/n02105251/n0210525100096401.png 0 24 | test_new_domain/n02105251/n0210525100096879.png 0 25 | test_new_domain/n02105251/n0210525100097458.png 0 26 | test_new_domain/n02105251/n0210525100097023.png 0 27 | test_new_domain/n02105251/n0210525100096522.png 0 28 | test_new_domain/n02105251/n0210525100096771.png 0 29 | test_new_domain/n02105251/n0210525100096366.png 0 30 | test_new_domain/n02105251/n0210525100097501.png 0 31 | test_new_domain/n02105251/n0210525100097494.png 0 32 | test_new_domain/n02105251/n0210525100097414.png 0 33 | test_new_domain/n02105251/n0210525100096775.png 0 34 | 
test_new_domain/n02105251/n0210525100096441.png 0 35 | test_new_domain/n02105251/n0210525100096324.png 0 36 | test_new_domain/n02105251/n0210525100096346.png 0 37 | test_new_domain/n02105251/n0210525100097457.png 0 38 | test_new_domain/n02105251/n0210525100097294.png 0 39 | test_new_domain/n02105251/n0210525100096359.png 0 40 | test_new_domain/n02105251/n0210525100096460.png 0 41 | test_new_domain/n02105251/n0210525100096335.png 0 42 | test_new_domain/n02105251/n0210525100097042.png 0 43 | test_new_domain/n02105251/n0210525100096392.png 0 44 | test_new_domain/n02105251/n0210525100097386.png 0 45 | test_new_domain/n02105251/n0210525100096894.png 0 46 | test_new_domain/n02105251/n0210525100097446.png 0 47 | test_new_domain/n02105251/n0210525100096576.png 0 48 | test_new_domain/n02105251/n0210525100097488.png 0 49 | test_new_domain/n02105251/n0210525100097380.png 0 50 | test_new_domain/n02105251/n0210525100096396.png 0 51 | test_new_domain/n02105251/n0210525100097496.png 0 52 | test_new_domain/n02105251/n0210525100097126.png 0 53 | test_new_domain/n02105251/n0210525100097052.png 0 54 | test_new_domain/n02105251/n0210525100096622.png 0 55 | test_new_domain/n02105251/n0210525100096853.png 0 56 | test_new_domain/n02105251/n0210525100096856.png 0 57 | test_new_domain/n02105251/n0210525100097053.png 0 58 | test_new_domain/n02105251/n0210525100096917.png 0 59 | test_new_domain/n02105251/n0210525100096510.png 0 60 | test_new_domain/n02105251/n0210525100096633.png 0 61 | test_new_domain/n02105251/n0210525100097439.png 0 62 | test_new_domain/n02105251/n0210525100096477.png 0 63 | test_new_domain/n02105251/n0210525100096763.png 0 64 | test_new_domain/n02105251/n0210525100097127.png 0 65 | test_new_domain/n02105251/n0210525100097337.png 0 66 | test_new_domain/n02105251/n0210525100096363.png 0 67 | test_new_domain/n02105251/n0210525100096646.png 0 68 | test_new_domain/n02105251/n0210525100096288.png 0 69 | test_new_domain/n02105251/n0210525100096957.png 0 70 | test_new_domain/n02105251/n0210525100097257.png 0 71 | test_new_domain/n02105251/n0210525100096819.png 0 72 | test_new_domain/n02105251/n0210525100096336.png 0 73 | test_new_domain/n02105251/n0210525100096870.png 0 74 | test_new_domain/n02105251/n0210525100097234.png 0 75 | test_new_domain/n02105251/n0210525100096294.png 0 76 | test_new_domain/n02105251/n0210525100097079.png 0 77 | test_new_domain/n02105251/n0210525100097191.png 0 78 | test_new_domain/n02105251/n0210525100097300.png 0 79 | test_new_domain/n02105251/n0210525100096831.png 0 80 | test_new_domain/n02105251/n0210525100096781.png 0 81 | test_new_domain/n02105251/n0210525100096589.png 0 82 | test_new_domain/n02105251/n0210525100097453.png 0 83 | test_new_domain/n02105251/n0210525100097338.png 0 84 | test_new_domain/n02105251/n0210525100096938.png 0 85 | test_new_domain/n02105251/n0210525100097062.png 0 86 | test_new_domain/n02105251/n0210525100097177.png 0 87 | test_new_domain/n02105251/n0210525100097513.png 0 88 | test_new_domain/n02105251/n0210525100097165.png 0 89 | test_new_domain/n02105251/n0210525100096342.png 0 90 | test_new_domain/n02105251/n0210525100097114.png 0 91 | test_new_domain/n02105251/n0210525100097518.png 0 92 | test_new_domain/n02105251/n0210525100096306.png 0 93 | test_new_domain/n02105251/n0210525100096997.png 0 94 | test_new_domain/n02105251/n0210525100096425.png 0 95 | test_new_domain/n02105251/n0210525100096340.png 0 96 | test_new_domain/n02105251/n0210525100096351.png 0 97 | test_new_domain/n02105251/n0210525100097005.png 0 98 | 
test_new_domain/n02105251/n0210525100096315.png 0 99 | test_new_domain/n02105251/n0210525100097341.png 0 100 | test_new_domain/n02105251/n0210525100097024.png 0 101 | test_new_domain/n02105251/n0210525100097109.png 0 102 | test_new_domain/n02105251/n0210525100097018.png 0 103 | test_new_domain/n02105251/n0210525100097253.png 0 104 | test_new_domain/n02105251/n0210525100097032.png 0 105 | test_new_domain/n02105251/n0210525100096476.png 0 106 | test_new_domain/n02105251/n0210525100097246.png 0 107 | test_new_domain/n02105251/n0210525100096298.png 0 108 | test_new_domain/n02105251/n0210525100096611.png 0 109 | test_new_domain/n02105251/n0210525100097480.png 0 110 | test_new_domain/n02105251/n0210525100096694.png 0 111 | test_new_domain/n02105251/n0210525100096574.png 0 112 | test_new_domain/n02105251/n0210525100096255.png 0 113 | test_new_domain/n02105251/n0210525100096839.png 0 114 | test_new_domain/n02105251/n0210525100096628.png 0 115 | test_new_domain/n02105251/n0210525100097345.png 0 116 | test_new_domain/n02105251/n0210525100097425.png 0 117 | test_new_domain/n02105251/n0210525100096907.png 0 118 | test_new_domain/n02105251/n0210525100096804.png 0 119 | test_new_domain/n02105251/n0210525100096761.png 0 120 | test_new_domain/n02105251/n0210525100096439.png 0 121 | test_new_domain/n02105251/n0210525100096750.png 0 122 | test_new_domain/n02105251/n0210525100096798.png 0 123 | test_new_domain/n02105251/n0210525100097158.png 0 124 | test_new_domain/n02105251/n0210525100096449.png 0 125 | test_new_domain/n02105251/n0210525100096490.png 0 126 | test_new_domain/n02105251/n0210525100097058.png 0 127 | test_new_domain/n02105251/n0210525100096371.png 0 128 | test_new_domain/n02105251/n0210525100096367.png 0 129 | test_new_domain/n02105251/n0210525100097275.png 0 130 | test_new_domain/n02105251/n0210525100097220.png 0 131 | test_new_domain/n02105251/n0210525100097358.png 0 132 | test_new_domain/n02105251/n0210525100097202.png 0 133 | test_new_domain/n02105251/n0210525100096554.png 0 134 | test_new_domain/n02105251/n0210525100096778.png 0 135 | test_new_domain/n02105251/n0210525100096973.png 0 136 | test_new_domain/n02105251/n0210525100096859.png 0 137 | test_new_domain/n02105251/n0210525100097328.png 0 138 | test_new_domain/n02105251/n0210525100096834.png 0 139 | test_new_domain/n02105251/n0210525100097317.png 0 140 | test_new_domain/n02105251/n0210525100097332.png 0 141 | test_new_domain/n02105251/n0210525100097460.png 0 142 | test_new_domain/n02105251/n0210525100096337.png 0 143 | test_new_domain/n02105251/n0210525100096772.png 0 144 | test_new_domain/n02105251/n0210525100097182.png 0 145 | test_new_domain/n02105251/n0210525100097213.png 0 146 | test_new_domain/n02105251/n0210525100097313.png 0 147 | test_new_domain/n02105251/n0210525100097025.png 0 148 | test_new_domain/n02105251/n0210525100096983.png 0 149 | test_new_domain/n02105251/n0210525100096581.png 0 150 | test_new_domain/n02105251/n0210525100096847.png 0 151 | test_new_domain/n02105251/n0210525100097055.png 0 152 | test_new_domain/n02105251/n0210525100097210.png 0 153 | test_new_domain/n02105251/n0210525100097145.png 0 154 | test_new_domain/n02105251/n0210525100097492.png 0 155 | test_new_domain/n02105251/n0210525100096293.png 0 156 | test_new_domain/n02105251/n0210525100097211.png 0 157 | test_new_domain/n02105251/n0210525100096665.png 0 158 | test_new_domain/n02105251/n0210525100096472.png 0 159 | test_new_domain/n02105251/n0210525100097304.png 0 160 | test_new_domain/n02105251/n0210525100096418.png 0 161 | 
test_new_domain/n02105251/n0210525100096385.png 0 162 | test_new_domain/n02105251/n0210525100096766.png 0 163 | test_new_domain/n02105251/n0210525100097330.png 0 164 | test_new_domain/n02105251/n0210525100096701.png 0 165 | test_new_domain/n02105251/n0210525100097110.png 0 166 | test_new_domain/n02105251/n0210525100096302.png 0 167 | test_new_domain/n02105251/n0210525100097408.png 0 168 | test_new_domain/n02105251/n0210525100096397.png 0 169 | test_new_domain/n02105251/n0210525100096693.png 0 170 | test_new_domain/n02105251/n0210525100096721.png 0 171 | test_new_domain/n02105251/n0210525100097187.png 0 172 | test_new_domain/n02105251/n0210525100096395.png 0 173 | test_new_domain/n02105251/n0210525100097509.png 0 174 | test_new_domain/n02105251/n0210525100096523.png 0 175 | test_new_domain/n02105251/n0210525100097273.png 0 176 | test_new_domain/n02105251/n0210525100096773.png 0 177 | test_new_domain/n02105251/n0210525100097061.png 0 178 | test_new_domain/n02105251/n0210525100097477.png 0 179 | test_new_domain/n02105251/n0210525100097083.png 0 180 | test_new_domain/n02105251/n0210525100096355.png 0 181 | test_new_domain/n02105251/n0210525100096329.png 0 182 | test_new_domain/n02105251/n0210525100097417.png 0 183 | test_new_domain/n02105251/n0210525100096673.png 0 184 | test_new_domain/n02105251/n0210525100096594.png 0 185 | test_new_domain/n02105251/n0210525100096785.png 0 186 | test_new_domain/n02105251/n0210525100097482.png 0 187 | test_new_domain/n02105251/n0210525100097423.png 0 188 | test_new_domain/n02105251/n0210525100096368.png 0 189 | test_new_domain/n02105251/n0210525100097310.png 0 190 | test_new_domain/n02105251/n0210525100097035.png 0 191 | test_new_domain/n02105251/n0210525100096512.png 0 192 | test_new_domain/n02105251/n0210525100097499.png 0 193 | test_new_domain/n02105251/n0210525100097278.png 0 194 | test_new_domain/n02105251/n0210525100097339.png 0 195 | test_new_domain/n02105251/n0210525100096672.png 0 196 | test_new_domain/n02105251/n0210525100097163.png 0 197 | test_new_domain/n02105251/n0210525100096902.png 0 198 | test_new_domain/n02105251/n0210525100096988.png 0 199 | test_new_domain/n02105251/n0210525100097131.png 0 200 | test_new_domain/n02105251/n0210525100096866.png 0 201 | test_new_domain/n02105251/n0210525100096709.png 0 202 | test_new_domain/n02105251/n0210525100096478.png 0 203 | test_new_domain/n02105251/n0210525100097355.png 0 204 | test_new_domain/n02105251/n0210525100097036.png 0 205 | test_new_domain/n02105251/n0210525100097251.png 0 206 | test_new_domain/n02105251/n0210525100096838.png 0 207 | test_new_domain/n02105251/n0210525100096537.png 0 208 | test_new_domain/n02105251/n0210525100097020.png 0 209 | test_new_domain/n02105251/n0210525100096422.png 0 210 | test_new_domain/n02105251/n0210525100096699.png 0 211 | test_new_domain/n02105251/n0210525100096936.png 0 212 | test_new_domain/n02105251/n0210525100096892.png 0 213 | test_new_domain/n02105251/n0210525100096530.png 0 214 | test_new_domain/n02105251/n0210525100097010.png 0 215 | test_new_domain/n02105251/n0210525100096533.png 0 216 | test_new_domain/n02105251/n0210525100096796.png 0 217 | test_new_domain/n02105251/n0210525100096817.png 0 218 | test_new_domain/n02105251/n0210525100096551.png 0 219 | test_new_domain/n02105251/n0210525100097203.png 0 220 | test_new_domain/n02105251/n0210525100097034.png 0 221 | test_new_domain/n02105251/n0210525100096822.png 0 222 | test_new_domain/n02105251/n0210525100096735.png 0 223 | test_new_domain/n02105251/n0210525100096725.png 0 224 | 
test_new_domain/n02105251/n0210525100096520.png 0 225 | test_new_domain/n02105251/n0210525100097550.png 0 226 | test_new_domain/n02105251/n0210525100097536.png 0 227 | test_new_domain/n02105251/n0210525100097479.png 0 228 | test_new_domain/n02105251/n0210525100096790.png 0 229 | test_new_domain/n02105251/n0210525100096808.png 0 230 | test_new_domain/n02105251/n0210525100097397.png 0 231 | test_new_domain/n02105251/n0210525100097171.png 0 232 | test_new_domain/n02105251/n0210525100097262.png 0 233 | test_new_domain/n02105251/n0210525100097154.png 0 234 | test_new_domain/n02105251/n0210525100097000.png 0 235 | test_new_domain/n02105251/n0210525100097411.png 0 236 | test_new_domain/n02105251/n0210525100096284.png 0 237 | -------------------------------------------------------------------------------- /dataset/tiered-imagenet/test.txt: -------------------------------------------------------------------------------- 1 | test/n02105251/n0210525100096882.png 0 2 | test/n02105251/n0210525100096671.png 0 3 | test/n02105251/n0210525100097149.png 0 4 | test/n02105251/n0210525100097328.png 0 5 | test/n02105251/n0210525100097345.png 0 6 | test/n02105251/n0210525100096265.png 0 7 | test/n02105251/n0210525100096471.png 0 8 | test/n02105251/n0210525100096403.png 0 9 | test/n02105251/n0210525100097121.png 0 10 | test/n02105251/n0210525100096611.png 0 11 | test/n02105251/n0210525100097050.png 0 12 | test/n02105251/n0210525100096544.png 0 13 | test/n02105251/n0210525100096992.png 0 14 | test/n02105251/n0210525100097053.png 0 15 | test/n02105251/n0210525100097207.png 0 16 | test/n02105251/n0210525100096623.png 0 17 | test/n02105251/n0210525100097173.png 0 18 | test/n02105251/n0210525100096513.png 0 19 | test/n02105251/n0210525100096778.png 0 20 | test/n02105251/n0210525100097371.png 0 21 | test/n02105251/n0210525100096347.png 0 22 | test/n02105251/n0210525100097188.png 0 23 | test/n02105251/n0210525100096576.png 0 24 | test/n02105251/n0210525100097090.png 0 25 | test/n02105251/n0210525100097420.png 0 26 | test/n02105251/n0210525100096435.png 0 27 | test/n02105251/n0210525100096821.png 0 28 | test/n02105251/n0210525100096969.png 0 29 | test/n02105251/n0210525100096820.png 0 30 | test/n02105251/n0210525100096517.png 0 31 | test/n02105251/n0210525100096852.png 0 32 | test/n02105251/n0210525100096437.png 0 33 | test/n02105251/n0210525100096702.png 0 34 | test/n02105251/n0210525100096540.png 0 35 | test/n02105251/n0210525100096452.png 0 36 | test/n02105251/n0210525100096558.png 0 37 | test/n02105251/n0210525100096588.png 0 38 | test/n02105251/n0210525100097299.png 0 39 | test/n02105251/n0210525100096282.png 0 40 | test/n02105251/n0210525100097500.png 0 41 | test/n02105251/n0210525100097049.png 0 42 | test/n02105251/n0210525100096888.png 0 43 | test/n02105251/n0210525100096375.png 0 44 | test/n02105251/n0210525100097412.png 0 45 | test/n02105251/n0210525100096687.png 0 46 | test/n02105251/n0210525100097022.png 0 47 | test/n02105251/n0210525100096636.png 0 48 | test/n02105251/n0210525100096991.png 0 49 | test/n02105251/n0210525100096885.png 0 50 | test/n02105251/n0210525100096899.png 0 51 | test/n02105251/n0210525100097417.png 0 52 | test/n02105251/n0210525100096819.png 0 53 | test/n02105251/n0210525100097452.png 0 54 | test/n02105251/n0210525100097125.png 0 55 | test/n02105251/n0210525100096510.png 0 56 | test/n02105251/n0210525100097230.png 0 57 | test/n02105251/n0210525100097214.png 0 58 | test/n02105251/n0210525100097000.png 0 59 | test/n02105251/n0210525100096848.png 0 60 | test/n02105251/n0210525100096316.png 
0 61 | test/n02105251/n0210525100097436.png 0 62 | test/n02105251/n0210525100096321.png 0 63 | test/n02105251/n0210525100097459.png 0 64 | test/n02105251/n0210525100097157.png 0 65 | test/n02105251/n0210525100097168.png 0 66 | test/n02105251/n0210525100097040.png 0 67 | test/n02105251/n0210525100096664.png 0 68 | test/n02105251/n0210525100097261.png 0 69 | test/n02105251/n0210525100096740.png 0 70 | test/n02105251/n0210525100096520.png 0 71 | test/n02105251/n0210525100097437.png 0 72 | test/n02105251/n0210525100096718.png 0 73 | test/n02105251/n0210525100097162.png 0 74 | test/n02105251/n0210525100096458.png 0 75 | test/n02105251/n0210525100097273.png 0 76 | test/n02105251/n0210525100096610.png 0 77 | test/n02105251/n0210525100096262.png 0 78 | test/n02105251/n0210525100096334.png 0 79 | test/n02105251/n0210525100096271.png 0 80 | test/n02105251/n0210525100097221.png 0 81 | test/n02105251/n0210525100096365.png 0 82 | test/n02105251/n0210525100097526.png 0 83 | test/n02105251/n0210525100096455.png 0 84 | test/n02105251/n0210525100096942.png 0 85 | test/n02105251/n0210525100097418.png 0 86 | test/n02105251/n0210525100096516.png 0 87 | test/n02105251/n0210525100096760.png 0 88 | test/n02105251/n0210525100096696.png 0 89 | test/n02105251/n0210525100096927.png 0 90 | test/n02105251/n0210525100096818.png 0 91 | test/n02105251/n0210525100097414.png 0 92 | test/n02105251/n0210525100096512.png 0 93 | test/n02105251/n0210525100096813.png 0 94 | test/n02105251/n0210525100096627.png 0 95 | test/n02105251/n0210525100096761.png 0 96 | test/n02105251/n0210525100097523.png 0 97 | test/n02105251/n0210525100096280.png 0 98 | test/n02105251/n0210525100096858.png 0 99 | test/n02105251/n0210525100096350.png 0 100 | test/n02105251/n0210525100097528.png 0 101 | test/n02105251/n0210525100097465.png 0 102 | test/n02105251/n0210525100096267.png 0 103 | test/n02105251/n0210525100097066.png 0 104 | test/n02105251/n0210525100096809.png 0 105 | test/n02105251/n0210525100096773.png 0 106 | test/n02105251/n0210525100097241.png 0 107 | test/n02105251/n0210525100096587.png 0 108 | test/n02105251/n0210525100096725.png 0 109 | test/n02105251/n0210525100096542.png 0 110 | test/n02105251/n0210525100097155.png 0 111 | test/n02105251/n0210525100097404.png 0 112 | test/n02105251/n0210525100096484.png 0 113 | test/n02105251/n0210525100096772.png 0 114 | test/n02105251/n0210525100097051.png 0 115 | test/n02105251/n0210525100097489.png 0 116 | test/n02105251/n0210525100096961.png 0 117 | test/n02105251/n0210525100096924.png 0 118 | test/n02105251/n0210525100097488.png 0 119 | test/n02105251/n0210525100097166.png 0 120 | test/n02105251/n0210525100096644.png 0 121 | test/n02105251/n0210525100097483.png 0 122 | test/n02105251/n0210525100097409.png 0 123 | test/n02105251/n0210525100097469.png 0 124 | test/n02105251/n0210525100097453.png 0 125 | test/n02105251/n0210525100096632.png 0 126 | test/n02105251/n0210525100097314.png 0 127 | test/n02105251/n0210525100097084.png 0 128 | test/n02105251/n0210525100097326.png 0 129 | test/n02105251/n0210525100096418.png 0 130 | test/n02105251/n0210525100097342.png 0 131 | test/n02105251/n0210525100096759.png 0 132 | test/n02105251/n0210525100097275.png 0 133 | test/n02105251/n0210525100096338.png 0 134 | test/n02105251/n0210525100097120.png 0 135 | test/n02105251/n0210525100096795.png 0 136 | test/n02105251/n0210525100097058.png 0 137 | test/n02105251/n0210525100096420.png 0 138 | test/n02105251/n0210525100097072.png 0 139 | test/n02105251/n0210525100097441.png 0 140 | 
test/n02105251/n0210525100097231.png 0 141 | test/n02105251/n0210525100097264.png 0 142 | test/n02105251/n0210525100097291.png 0 143 | test/n02105251/n0210525100096469.png 0 144 | test/n02105251/n0210525100096288.png 0 145 | test/n02105251/n0210525100096693.png 0 146 | test/n02105251/n0210525100096359.png 0 147 | test/n02105251/n0210525100096268.png 0 148 | test/n02105251/n0210525100096649.png 0 149 | test/n02105251/n0210525100097135.png 0 150 | test/n02105251/n0210525100097427.png 0 151 | test/n02105251/n0210525100096410.png 0 152 | test/n02105251/n0210525100096692.png 0 153 | test/n02105251/n0210525100096974.png 0 154 | test/n02105251/n0210525100097161.png 0 155 | test/n02105251/n0210525100096488.png 0 156 | test/n02105251/n0210525100097256.png 0 157 | test/n02105251/n0210525100097373.png 0 158 | test/n02105251/n0210525100096429.png 0 159 | test/n02105251/n0210525100097099.png 0 160 | test/n02105251/n0210525100096916.png 0 161 | test/n02105251/n0210525100096864.png 0 162 | test/n02105251/n0210525100097516.png 0 163 | test/n02105251/n0210525100096817.png 0 164 | test/n02105251/n0210525100097416.png 0 165 | test/n02105251/n0210525100096533.png 0 166 | test/n02105251/n0210525100096871.png 0 167 | test/n02105251/n0210525100096881.png 0 168 | test/n02105251/n0210525100096735.png 0 169 | test/n02105251/n0210525100097142.png 0 170 | test/n02105251/n0210525100097530.png 0 171 | test/n02105251/n0210525100096531.png 0 172 | test/n02105251/n0210525100096456.png 0 173 | test/n02105251/n0210525100096880.png 0 174 | test/n02105251/n0210525100096697.png 0 175 | test/n02105251/n0210525100096928.png 0 176 | test/n02105251/n0210525100096326.png 0 177 | test/n02105251/n0210525100096604.png 0 178 | test/n02105251/n0210525100096417.png 0 179 | test/n02105251/n0210525100096297.png 0 180 | test/n02105251/n0210525100096995.png 0 181 | test/n02105251/n0210525100096862.png 0 182 | test/n02105251/n0210525100097103.png 0 183 | test/n02105251/n0210525100097379.png 0 184 | test/n02105251/n0210525100096701.png 0 185 | test/n02105251/n0210525100096579.png 0 186 | test/n02105251/n0210525100097246.png 0 187 | test/n02105251/n0210525100096278.png 0 188 | test/n02105251/n0210525100097503.png 0 189 | test/n02105251/n0210525100096922.png 0 190 | test/n02105251/n0210525100097226.png 0 191 | test/n02105251/n0210525100096401.png 0 192 | test/n02105251/n0210525100096748.png 0 193 | test/n02105251/n0210525100096908.png 0 194 | test/n02105251/n0210525100097384.png 0 195 | test/n02105251/n0210525100097126.png 0 196 | test/n02105251/n0210525100097029.png 0 197 | test/n02105251/n0210525100097054.png 0 198 | test/n02105251/n0210525100096259.png 0 199 | test/n02105251/n0210525100096258.png 0 200 | test/n02105251/n0210525100096327.png 0 201 | test/n02105251/n0210525100097272.png 0 202 | test/n02105251/n0210525100097381.png 0 203 | test/n02105251/n0210525100097546.png 0 204 | test/n02105251/n0210525100096846.png 0 205 | test/n02105251/n0210525100096890.png 0 206 | test/n02105251/n0210525100097495.png 0 207 | test/n02105251/n0210525100096708.png 0 208 | test/n02105251/n0210525100096541.png 0 209 | test/n02105251/n0210525100096967.png 0 210 | test/n02105251/n0210525100096305.png 0 211 | test/n02105251/n0210525100097393.png 0 212 | test/n02105251/n0210525100097204.png 0 213 | test/n02105251/n0210525100097522.png 0 214 | test/n02105251/n0210525100096391.png 0 215 | test/n02105251/n0210525100096581.png 0 216 | test/n02105251/n0210525100096719.png 0 217 | test/n02105251/n0210525100096380.png 0 218 | test/n02105251/n0210525100096536.png 0 219 | 
test/n02105251/n0210525100097109.png 0 220 | test/n02105251/n0210525100097364.png 0 221 | test/n02105251/n0210525100097356.png 0 222 | test/n02105251/n0210525100096949.png 0 223 | test/n02105251/n0210525100097007.png 0 224 | test/n02105251/n0210525100096463.png 0 225 | test/n02105251/n0210525100097146.png 0 226 | test/n02105251/n0210525100097478.png 0 227 | test/n02105251/n0210525100096837.png 0 228 | test/n02105251/n0210525100096549.png 0 229 | test/n02105251/n0210525100096710.png 0 230 | test/n02105251/n0210525100096442.png 0 231 | test/n02105251/n0210525100097124.png 0 232 | test/n02105251/n0210525100096560.png 0 233 | test/n02105251/n0210525100096325.png 0 234 | test/n02105251/n0210525100097348.png 0 235 | test/n02105251/n0210525100097130.png 0 236 | test/n02105251/n0210525100096443.png 0 237 | test/n02105251/n0210525100096367.png 0 238 | test/n02105251/n0210525100096342.png 0 239 | test/n02105251/n0210525100096344.png 0 240 | test/n02105251/n0210525100096913.png 0 241 | test/n02105251/n0210525100096595.png 0 242 | test/n02105251/n0210525100096640.png 0 243 | test/n02105251/n0210525100096565.png 0 244 | test/n02105251/n0210525100097398.png 0 245 | test/n02105251/n0210525100096567.png 0 246 | test/n02105251/n0210525100096907.png 0 247 | test/n02105251/n0210525100096440.png 0 248 | test/n02105251/n0210525100097464.png 0 249 | test/n02105251/n0210525100096478.png 0 250 | test/n02105251/n0210525100096590.png 0 251 | test/n02105251/n0210525100096266.png 0 252 | test/n02105251/n0210525100097202.png 0 253 | test/n02105251/n0210525100096500.png 0 254 | test/n02105251/n0210525100097077.png 0 255 | test/n02105251/n0210525100096323.png 0 256 | test/n02105251/n0210525100097093.png 0 257 | test/n02105251/n0210525100096370.png 0 258 | test/n02105251/n0210525100097547.png 0 259 | test/n02105251/n0210525100096764.png 0 260 | test/n02105251/n0210525100097067.png 0 261 | test/n02105251/n0210525100096324.png 0 262 | test/n02105251/n0210525100096480.png 0 263 | test/n02105251/n0210525100096602.png 0 264 | test/n02105251/n0210525100096331.png 0 265 | test/n02105251/n0210525100097114.png 0 266 | test/n02105251/n0210525100097242.png 0 267 | test/n02105251/n0210525100097081.png 0 268 | test/n02105251/n0210525100096749.png 0 269 | test/n02105251/n0210525100097073.png 0 270 | test/n02105251/n0210525100096675.png 0 271 | test/n02105251/n0210525100096438.png 0 272 | test/n02105251/n0210525100096608.png 0 273 | test/n02105251/n0210525100096685.png 0 274 | test/n02105251/n0210525100096711.png 0 275 | test/n02105251/n0210525100096353.png 0 276 | test/n02105251/n0210525100096292.png 0 277 | test/n02105251/n0210525100096601.png 0 278 | test/n02105251/n0210525100096936.png 0 279 | test/n02105251/n0210525100097497.png 0 280 | test/n02105251/n0210525100096767.png 0 281 | test/n02105251/n0210525100096803.png 0 282 | test/n02105251/n0210525100096930.png 0 283 | test/n02105251/n0210525100096346.png 0 284 | test/n02105251/n0210525100097543.png 0 285 | test/n02105251/n0210525100096349.png 0 286 | test/n02105251/n0210525100096460.png 0 287 | test/n02105251/n0210525100096613.png 0 288 | test/n02105251/n0210525100097518.png 0 289 | test/n02105251/n0210525100097374.png 0 290 | test/n02105251/n0210525100096383.png 0 291 | test/n02105251/n0210525100096682.png 0 292 | test/n02105251/n0210525100096497.png 0 293 | test/n02105251/n0210525100097269.png 0 294 | test/n02105251/n0210525100097138.png 0 295 | test/n02105251/n0210525100097292.png 0 296 | test/n02105251/n0210525100097380.png 0 297 | test/n02105251/n0210525100096523.png 0 298 | 
test/n02105251/n0210525100097302.png 0 299 | test/n02105251/n0210525100096494.png 0 300 | test/n02105251/n0210525100097548.png 0 301 | test/n02105251/n0210525100096843.png 0 302 | test/n02105251/n0210525100097243.png 0 303 | test/n02105251/n0210525100096594.png 0 304 | test/n02105251/n0210525100097492.png 0 305 | test/n02105251/n0210525100096775.png 0 306 | test/n02105251/n0210525100096860.png 0 307 | test/n02105251/n0210525100096891.png 0 308 | test/n02105251/n0210525100096600.png 0 309 | test/n02105251/n0210525100097369.png 0 310 | test/n02105251/n0210525100097190.png 0 311 | test/n02105251/n0210525100096369.png 0 312 | test/n02105251/n0210525100096978.png 0 313 | test/n02105251/n0210525100096751.png 0 314 | test/n02105251/n0210525100097123.png 0 315 | test/n02105251/n0210525100096491.png 0 316 | test/n02105251/n0210525100096828.png 0 317 | test/n02105251/n0210525100097020.png 0 318 | test/n02105251/n0210525100097046.png 0 319 | test/n02105251/n0210525100097127.png 0 320 | test/n02105251/n0210525100097139.png 0 321 | test/n02105251/n0210525100096650.png 0 322 | test/n02105251/n0210525100096931.png 0 323 | test/n02105251/n0210525100096294.png 0 324 | test/n02105251/n0210525100096328.png 0 325 | test/n02105251/n0210525100097089.png 0 326 | test/n02105251/n0210525100097325.png 0 327 | test/n02105251/n0210525100096758.png 0 328 | test/n02105251/n0210525100097186.png 0 329 | test/n02105251/n0210525100096312.png 0 330 | test/n02105251/n0210525100097277.png 0 331 | test/n02105251/n0210525100096433.png 0 332 | test/n02105251/n0210525100096570.png 0 333 | test/n02105251/n0210525100097110.png 0 334 | test/n02105251/n0210525100097534.png 0 335 | test/n02105251/n0210525100097206.png 0 336 | test/n02105251/n0210525100097180.png 0 337 | test/n02105251/n0210525100097213.png 0 338 | test/n02105251/n0210525100096412.png 0 339 | test/n02105251/n0210525100096614.png 0 340 | test/n02105251/n0210525100096959.png 0 341 | test/n02105251/n0210525100097087.png 0 342 | test/n02105251/n0210525100097111.png 0 343 | test/n02105251/n0210525100096336.png 0 344 | test/n02105251/n0210525100097233.png 0 345 | test/n02105251/n0210525100096596.png 0 346 | -------------------------------------------------------------------------------- /domain_adaptive_module/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import network 10 | import loss 11 | import pre_process as prep 12 | from torch.utils.data import DataLoader 13 | import lr_schedule 14 | import data_list 15 | from data_list import ImageList 16 | from torch.autograd import Variable 17 | import random 18 | import pdb 19 | import math 20 | 21 | 22 | from torchvision import models 23 | 24 | class ResNetFc(nn.Module): 25 | def __init__(self, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000): 26 | super(ResNetFc, self).__init__() 27 | model_resnet = models.resnet50(pretrained=True) 28 | self.conv1 = model_resnet.conv1 29 | self.bn1 = model_resnet.bn1 30 | self.relu = model_resnet.relu 31 | self.maxpool = model_resnet.maxpool 32 | self.layer1 = model_resnet.layer1 33 | self.layer2 = model_resnet.layer2 34 | self.layer3 = model_resnet.layer3 35 | self.layer4 = model_resnet.layer4 36 | self.avgpool = model_resnet.avgpool 37 | self.feature_layers = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool, \ 38 | self.layer1, self.layer2, self.layer3, 
self.layer4, self.avgpool) 39 | 40 | self.use_bottleneck = use_bottleneck 41 | self.new_cls = new_cls 42 | if new_cls: 43 | if self.use_bottleneck: 44 | self.bottleneck = nn.Linear(model_resnet.fc.in_features, bottleneck_dim) 45 | self.fc = nn.Linear(bottleneck_dim, class_num) 46 | self.bottleneck.apply(init_weights) 47 | self.fc.apply(init_weights) 48 | self.__in_features = bottleneck_dim 49 | else: 50 | self.fc = nn.Linear(model_resnet.fc.in_features, class_num) 51 | self.fc.apply(init_weights) 52 | self.__in_features = model_resnet.fc.in_features 53 | else: 54 | self.fc = model_resnet.fc 55 | self.__in_features = model_resnet.fc.in_features 56 | 57 | def forward(self, x): 58 | x = self.feature_layers(x) 59 | x = x.view(x.size(0), -1) 60 | if self.use_bottleneck and self.new_cls: 61 | x = self.bottleneck(x) 62 | y = self.fc(x) 63 | return x, y 64 | 65 | 66 | 67 | def image_classification_test(loader, model, test_10crop=True): 68 | start_test = True 69 | with torch.no_grad(): 70 | # if test_10crop: 71 | # iter_test = [iter(loader['test'][i]) for i in range(10)] 72 | # for i in range(len(loader['test'][0])): 73 | # data = [next(iter_test[j]) for j in range(10)] 74 | # inputs = [data[j][0] for j in range(10)] 75 | # labels = data[0][1] 76 | # for j in range(10): 77 | # inputs[j] = inputs[j].cuda() 78 | # labels = labels 79 | # outputs = [] 80 | # for j in range(10): 81 | # _, predict_out = model(inputs[j]) 82 | # outputs.append(nn.Softmax(dim=1)(predict_out)) 83 | # outputs = sum(outputs) 84 | # if start_test: 85 | # all_output = outputs.float().cpu() 86 | # all_label = labels.float() 87 | # start_test = False 88 | # else: 89 | # all_output = torch.cat((all_output, outputs.float().cpu()), 0) 90 | # all_label = torch.cat((all_label, labels.float()), 0) 91 | # else: 92 | iter_test = iter(loader["test"]) 93 | for i in range(len(loader['test'])): 94 | data = next(iter_test) # Python 3: use the built-in next(); iterators no longer have a .next() method 95 | inputs = data[0] 96 | labels = data[1] 97 | print(inputs.size(), labels.size()) 98 | inputs = inputs.cuda() 99 | labels = labels.cuda() 100 | outputs = model(inputs) 101 | print(outputs.size()) 102 | if start_test: 103 | all_output = outputs.float().cpu() 104 | all_label = labels.float() 105 | start_test = False 106 | else: 107 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 108 | all_label = torch.cat((all_label, labels.float()), 0) 109 | all_output = all_output.squeeze(2).squeeze(2).numpy() # drop the 1x1 spatial dims left by average pooling 110 | all_label = all_label.cpu().numpy() 111 | length = len(all_label) 112 | f = open('feature15_all.txt','w') 113 | for i in range(length): 114 | line = all_output[i] 115 | for item in line: 116 | print(item, end=' ', file=f) 117 | print(int(all_label[i]), file=f) 118 | f.close() 119 | return 120 | 121 | 122 | def train(config): 123 | ## set pre-process 124 | prep_dict = {} 125 | prep_config = config["prep"] 126 | prep_dict["source"] = prep.image_train(**config["prep"]['params']) 127 | prep_dict["target"] = prep.image_train(**config["prep"]['params']) 128 | if prep_config["test_10crop"]: 129 | prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params']) 130 | else: 131 | prep_dict["test"] = prep.image_test(**config["prep"]['params']) 132 | 133 | ## prepare data 134 | dsets = {} 135 | dset_loaders = {} 136 | data_config = config["data"] 137 | train_bs = data_config["source"]["batch_size"] 138 | test_bs = data_config["test"]["batch_size"] 139 | print(config, data_config) # evaluation only needs the "test" loader; the source/target training loaders below stay commented out 140 | # dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \ 141 | # transform=prep_dict["source"])
142 | # dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \ 143 | # shuffle=True, num_workers=4, drop_last=True) 144 | # dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \ 145 | # transform=prep_dict["target"]) 146 | # dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \ 147 | # shuffle=True, num_workers=4, drop_last=True) 148 | 149 | # if prep_config["test_10crop"]: 150 | # for i in range(10): 151 | # dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \ 152 | # transform=prep_dict["test"][i]) for i in range(10)] 153 | # dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \ 154 | # shuffle=False, num_workers=4) for dset in dsets['test']] 155 | # else: 156 | 157 | 158 | dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \ 159 | transform=prep_dict["test"]) 160 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \ 161 | shuffle=False, num_workers=4) 162 | 163 | print ('load data finished') 164 | class_num = config["network"]["params"]["class_num"] 165 | 166 | ## set base network 167 | net_config = config["network"] 168 | #base_network = net_config["name"](**net_config["params"]) 169 | #base_network = base_network.cuda() 170 | #print(base_network) 171 | 172 | base_network = torch.load('snapshot/iter_90000_model.pth.tar') 173 | base_network= list(base_network.children())[0] 174 | base_network = base_network.module 175 | # for key, value in base_network.state_dict().items(): 176 | # print(key) 177 | base_network = list(base_network.children())[9].cuda() 178 | # base_network= torch.nn.Sequential(*list(base_network.children())[:-1]).cuda() 179 | base_network = nn.DataParallel(base_network) 180 | print(base_network) 181 | #base_network.load_state_dict(checkpoint, strict=False) 182 | # ## add additional network for some methods 183 | # if config["loss"]["random"]: 184 | # random_layer = network.RandomLayer([base_network.output_num(), class_num], config["loss"]["random_dim"]) 185 | # ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024) 186 | # else: 187 | # random_layer = None 188 | # ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024) 189 | # if config["loss"]["random"]: 190 | # random_layer.cuda() 191 | # ad_net = ad_net.cuda() 192 | # parameter_list = base_network.get_parameters()# + ad_net.get_parameters() 193 | 194 | # ## set optimizer 195 | # optimizer_config = config["optimizer"] 196 | # optimizer = optimizer_config["type"](parameter_list, \ 197 | # **(optimizer_config["optim_params"])) 198 | # param_lr = [] 199 | # for param_group in optimizer.param_groups: 200 | # param_lr.append(param_group["lr"]) 201 | # schedule_param = optimizer_config["lr_param"] 202 | # lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]] 203 | 204 | # gpus = config['gpu'].split(',') 205 | # if len(gpus) > 1: 206 | # # ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus]) 207 | # base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus]) 208 | 209 | 210 | ## train 211 | # len_train_source = len(dset_loaders["source"]) 212 | # len_train_target = len(dset_loaders["target"]) 213 | # transfer_loss_value = classifier_loss_value = total_loss_value = 0.0 214 | # best_acc = 0.0 215 | 216 | 217 | base_network.train(False) 218 | image_classification_test(dset_loaders, \ 219 | base_network, test_10crop=prep_config["test_10crop"]) 220 | 221 | 222 | return 223 | 224 | if 
__name__ == "__main__": 225 | parser = argparse.ArgumentParser(description='Conditional Domain Adversarial Network') 226 | parser.add_argument('--method', type=str, default='CDAN+E', choices=['CDAN', 'CDAN+E', 'DANN']) 227 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run on") 228 | parser.add_argument('--net', type=str, default='ResNet50', choices=["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152", "VGG11", "VGG13", "VGG16", "VGG19", "VGG11BN", "VGG13BN", "VGG16BN", "VGG19BN", "AlexNet"]) 229 | parser.add_argument('--dset', type=str, default='office', choices=['office', 'image-clef', 'visda', 'office-home', 'imagenet'], help="The dataset (or source dataset) used") 230 | parser.add_argument('--s_dset_path', type=str, default='dataset/mini-imagenet/list/train_list.txt', help="The source dataset path list") 231 | parser.add_argument('--t_dset_path', type=str, default='dataset/mini-imagenet/list/test_transfer_20.txt', help="The target dataset path list") 232 | parser.add_argument('--test_interval', type=int, default=5000000000, help="interval between two consecutive test phases") 233 | parser.add_argument('--snapshot_interval', type=int, default=5000, help="interval between two consecutive model snapshots") 234 | parser.add_argument('--output_dir', type=str, default='san', help="output directory of our model (under the snapshot/ directory)") 235 | parser.add_argument('--lr', type=float, default=0.001, help="learning rate") 236 | parser.add_argument('--random', action='store_true', help="whether to use random projection") # action='store_true' parses correctly; type=bool would treat any non-empty string as True 237 | args = parser.parse_args() 238 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 239 | #os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3' 240 | 241 | # train config 242 | config = {} 243 | config['method'] = args.method 244 | config["gpu"] = args.gpu_id 245 | config["num_iterations"] = 100004 246 | config["test_interval"] = args.test_interval 247 | config["snapshot_interval"] = args.snapshot_interval 248 | config["output_for_test"] = True 249 | config["output_path"] = "snapshot/" + args.output_dir 250 | if not osp.exists(config["output_path"]): 251 | os.makedirs(config["output_path"]) # portable equivalent of os.system('mkdir -p ...') 252 | config["out_file"] = open(osp.join(config["output_path"], "log.txt"), "w") 253 | if not osp.exists(config["output_path"]): 254 | os.mkdir(config["output_path"]) 255 | 256 | config["prep"] = {"test_10crop":False, 'params':{"resize_size":256, "crop_size":224, 'alexnet':False}} 257 | config["loss"] = {"trade_off":1.0} 258 | if "AlexNet" in args.net: 259 | config["prep"]['params']['alexnet'] = True 260 | config["prep"]['params']['crop_size'] = 227 261 | config["network"] = {"name":network.AlexNetFc, \ 262 | "params":{"use_bottleneck":True, "bottleneck_dim":256, "new_cls":True} } 263 | elif "ResNet" in args.net: 264 | config["network"] = {"name":network.ResNetFc, \ 265 | "params":{"resnet_name":args.net, "use_bottleneck":True, "bottleneck_dim":256, "new_cls":True} } 266 | elif "VGG" in args.net: 267 | config["network"] = {"name":network.VGGFc, \ 268 | "params":{"vgg_name":args.net, "use_bottleneck":True, "bottleneck_dim":256, "new_cls":True} } 269 | config["loss"]["random"] = args.random 270 | config["loss"]["random_dim"] = 1024 271 | 272 | config["optimizer"] = {"type":optim.SGD, "optim_params":{'lr':args.lr, "momentum":0.9, \ 273 | "weight_decay":0.0005, "nesterov":True}, "lr_type":"inv", \ 274 | "lr_param":{"lr":args.lr, "gamma":0.001, "power":0.75} } 275 | 276 | config["dataset"] = args.dset 277 | config["data"] = 
{"source":{"list_path":args.s_dset_path, "batch_size":200}, \ 278 | "target":{"list_path":args.t_dset_path, "batch_size":8}, \ 279 | "test":{"list_path":args.t_dset_path, "batch_size":256}} 280 | 281 | if config["dataset"] == "office": 282 | if ("amazon" in args.s_dset_path and "webcam" in args.t_dset_path) or \ 283 | ("webcam" in args.s_dset_path and "dslr" in args.t_dset_path) or \ 284 | ("webcam" in args.s_dset_path and "amazon" in args.t_dset_path) or \ 285 | ("dslr" in args.s_dset_path and "amazon" in args.t_dset_path): 286 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 287 | elif ("amazon" in args.s_dset_path and "dslr" in args.t_dset_path) or \ 288 | ("dslr" in args.s_dset_path and "webcam" in args.t_dset_path): 289 | config["optimizer"]["lr_param"]["lr"] = 0.0003 # optimal parameters 290 | config["network"]["params"]["class_num"] = 31 291 | elif config["dataset"] == "image-clef": 292 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 293 | config["network"]["params"]["class_num"] = 12 294 | elif config["dataset"] == "visda": 295 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 296 | config["network"]["params"]["class_num"] = 12 297 | config['loss']["trade_off"] = 1.0 298 | elif config["dataset"] == "office-home": 299 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 300 | config["network"]["params"]["class_num"] = 65 301 | elif config["dataset"] == "imagenet": 302 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 303 | config["network"]["params"]["class_num"] = 1000 304 | config['loss']["trade_off"] = 1.0 305 | else: 306 | raise ValueError('Dataset cannot be recognized. Please define your own dataset here.') 307 | config["out_file"].write(str(config)) 308 | config["out_file"].flush() 309 | train(config) 310 | -------------------------------------------------------------------------------- /dataset/mini-imagenet/val_source_domain.txt: -------------------------------------------------------------------------------- 1 | val/n09256479/n0925647900009047.png 0 2 | val/n09256479/n0925647900009254.png 0 3 | val/n09256479/n0925647900009147.png 0 4 | val/n09256479/n0925647900009398.png 0 5 | val/n09256479/n0925647900009464.png 0 6 | val/n02174001/n0217400100002843.png 1 7 | val/n02174001/n0217400100002860.png 1 8 | val/n02174001/n0217400100002531.png 1 9 | val/n02174001/n0217400100002775.png 1 10 | val/n02174001/n0217400100002812.png 1 11 | val/n02091244/n0209124400000814.png 2 12 | val/n02091244/n0209124400001102.png 2 13 | val/n02091244/n0209124400000793.png 2 14 | val/n02091244/n0209124400001026.png 2 15 | val/n02091244/n0209124400000771.png 2 16 | val/n02981792/n0298179200004298.png 3 17 | val/n02981792/n0298179200004220.png 3 18 | val/n02981792/n0298179200004435.png 3 19 | val/n02981792/n0298179200004719.png 3 20 | val/n02981792/n0298179200004769.png 3 21 | val/n02114548/n0211454800001390.png 4 22 | val/n02114548/n0211454800001228.png 4 23 | val/n02114548/n0211454800001227.png 4 24 | val/n02114548/n0211454800001289.png 4 25 | val/n02114548/n0211454800001768.png 4 26 | val/n03417042/n0341704200005529.png 5 27 | val/n03417042/n0341704200005684.png 5 28 | val/n03417042/n0341704200005482.png 5 29 | val/n03417042/n0341704200005578.png 5 30 | val/n03417042/n0341704200005789.png 5 31 | val/n02950826/n0295082600003185.png 6 32 | val/n02950826/n0295082600003330.png 6 33 | val/n02950826/n0295082600003114.png 6 34 | val/n02950826/n0295082600003131.png 6 35 | val/n02950826/n0295082600003347.png 6 36 | 
val/n03773504/n0377350400008061.png 7 37 | val/n03773504/n0377350400008282.png 7 38 | val/n03773504/n0377350400007940.png 7 39 | val/n03773504/n0377350400007923.png 7 40 | val/n03773504/n0377350400008241.png 7 41 | val/n03980874/n0398087400008639.png 8 42 | val/n03980874/n0398087400008652.png 8 43 | val/n03980874/n0398087400008674.png 8 44 | val/n03980874/n0398087400008864.png 8 45 | val/n03980874/n0398087400008407.png 8 46 | val/n03535780/n0353578000006240.png 9 47 | val/n03535780/n0353578000006383.png 9 48 | val/n03535780/n0353578000006452.png 9 49 | val/n03535780/n0353578000006192.png 9 50 | val/n03535780/n0353578000006437.png 9 51 | val/n03584254/n0358425400006895.png 10 52 | val/n03584254/n0358425400006806.png 10 53 | val/n03584254/n0358425400006837.png 10 54 | val/n03584254/n0358425400006996.png 10 55 | val/n03584254/n0358425400006883.png 10 56 | val/n02971356/n0297135600003799.png 11 57 | val/n02971356/n0297135600003714.png 11 58 | val/n02971356/n0297135600004081.png 11 59 | val/n02971356/n0297135600003931.png 11 60 | val/n02971356/n0297135600004044.png 11 61 | val/n03075370/n0307537000005300.png 12 62 | val/n03075370/n0307537000005224.png 12 63 | val/n03075370/n0307537000004884.png 12 64 | val/n03075370/n0307537000004826.png 12 65 | val/n03075370/n0307537000004979.png 12 66 | val/n02138441/n0213844100002064.png 13 67 | val/n02138441/n0213844100002074.png 13 68 | val/n02138441/n0213844100001828.png 13 69 | val/n02138441/n0213844100002095.png 13 70 | val/n02138441/n0213844100002172.png 13 71 | val/n03770439/n0377043900007611.png 14 72 | val/n03770439/n0377043900007331.png 14 73 | val/n03770439/n0377043900007628.png 14 74 | val/n03770439/n0377043900007590.png 14 75 | val/n03770439/n0377043900007219.png 14 76 | val/n01855672/n0185567200000342.png 15 77 | val/n01855672/n0185567200000359.png 15 78 | val/n01855672/n0185567200000470.png 15 79 | val/n01855672/n0185567200000055.png 15 80 | val/n01855672/n0185567200000116.png 15 81 | val/n09256479/n0925647900009047.png 0 82 | val/n09256479/n0925647900009254.png 0 83 | val/n09256479/n0925647900009147.png 0 84 | val/n09256479/n0925647900009398.png 0 85 | val/n09256479/n0925647900009464.png 0 86 | val/n02174001/n0217400100002843.png 1 87 | val/n02174001/n0217400100002860.png 1 88 | val/n02174001/n0217400100002531.png 1 89 | val/n02174001/n0217400100002775.png 1 90 | val/n02174001/n0217400100002812.png 1 91 | val/n02091244/n0209124400000814.png 2 92 | val/n02091244/n0209124400001102.png 2 93 | val/n02091244/n0209124400000793.png 2 94 | val/n02091244/n0209124400001026.png 2 95 | val/n02091244/n0209124400000771.png 2 96 | val/n02981792/n0298179200004298.png 3 97 | val/n02981792/n0298179200004220.png 3 98 | val/n02981792/n0298179200004435.png 3 99 | val/n02981792/n0298179200004719.png 3 100 | val/n02981792/n0298179200004769.png 3 101 | val/n02114548/n0211454800001390.png 4 102 | val/n02114548/n0211454800001228.png 4 103 | val/n02114548/n0211454800001227.png 4 104 | val/n02114548/n0211454800001289.png 4 105 | val/n02114548/n0211454800001768.png 4 106 | val/n03417042/n0341704200005529.png 5 107 | val/n03417042/n0341704200005684.png 5 108 | val/n03417042/n0341704200005482.png 5 109 | val/n03417042/n0341704200005578.png 5 110 | val/n03417042/n0341704200005789.png 5 111 | val/n02950826/n0295082600003185.png 6 112 | val/n02950826/n0295082600003330.png 6 113 | val/n02950826/n0295082600003114.png 6 114 | val/n02950826/n0295082600003131.png 6 115 | val/n02950826/n0295082600003347.png 6 116 | val/n03773504/n0377350400008061.png 7 117 | 
val/n03773504/n0377350400008282.png 7 118 | val/n03773504/n0377350400007940.png 7 119 | val/n03773504/n0377350400007923.png 7 120 | val/n03773504/n0377350400008241.png 7 121 | val/n03980874/n0398087400008639.png 8 122 | val/n03980874/n0398087400008652.png 8 123 | val/n03980874/n0398087400008674.png 8 124 | val/n03980874/n0398087400008864.png 8 125 | val/n03980874/n0398087400008407.png 8 126 | val/n03535780/n0353578000006240.png 9 127 | val/n03535780/n0353578000006383.png 9 128 | val/n03535780/n0353578000006452.png 9 129 | val/n03535780/n0353578000006192.png 9 130 | val/n03535780/n0353578000006437.png 9 131 | val/n03584254/n0358425400006895.png 10 132 | val/n03584254/n0358425400006806.png 10 133 | val/n03584254/n0358425400006837.png 10 134 | val/n03584254/n0358425400006996.png 10 135 | val/n03584254/n0358425400006883.png 10 136 | val/n02971356/n0297135600003799.png 11 137 | val/n02971356/n0297135600003714.png 11 138 | val/n02971356/n0297135600004081.png 11 139 | val/n02971356/n0297135600003931.png 11 140 | val/n02971356/n0297135600004044.png 11 141 | val/n03075370/n0307537000005300.png 12 142 | val/n03075370/n0307537000005224.png 12 143 | val/n03075370/n0307537000004884.png 12 144 | val/n03075370/n0307537000004826.png 12 145 | val/n03075370/n0307537000004979.png 12 146 | val/n02138441/n0213844100002064.png 13 147 | val/n02138441/n0213844100002074.png 13 148 | val/n02138441/n0213844100001828.png 13 149 | val/n02138441/n0213844100002095.png 13 150 | val/n02138441/n0213844100002172.png 13 151 | val/n03770439/n0377043900007611.png 14 152 | val/n03770439/n0377043900007331.png 14 153 | val/n03770439/n0377043900007628.png 14 154 | val/n03770439/n0377043900007590.png 14 155 | val/n03770439/n0377043900007219.png 14 156 | val/n01855672/n0185567200000342.png 15 157 | val/n01855672/n0185567200000359.png 15 158 | val/n01855672/n0185567200000470.png 15 159 | val/n01855672/n0185567200000055.png 15 160 | val/n01855672/n0185567200000116.png 15 161 | val/n09256479/n0925647900009047.png 0 162 | val/n09256479/n0925647900009254.png 0 163 | val/n09256479/n0925647900009147.png 0 164 | val/n09256479/n0925647900009398.png 0 165 | val/n09256479/n0925647900009464.png 0 166 | val/n02174001/n0217400100002843.png 1 167 | val/n02174001/n0217400100002860.png 1 168 | val/n02174001/n0217400100002531.png 1 169 | val/n02174001/n0217400100002775.png 1 170 | val/n02174001/n0217400100002812.png 1 171 | val/n02091244/n0209124400000814.png 2 172 | val/n02091244/n0209124400001102.png 2 173 | val/n02091244/n0209124400000793.png 2 174 | val/n02091244/n0209124400001026.png 2 175 | val/n02091244/n0209124400000771.png 2 176 | val/n02981792/n0298179200004298.png 3 177 | val/n02981792/n0298179200004220.png 3 178 | val/n02981792/n0298179200004435.png 3 179 | val/n02981792/n0298179200004719.png 3 180 | val/n02981792/n0298179200004769.png 3 181 | val/n02114548/n0211454800001390.png 4 182 | val/n02114548/n0211454800001228.png 4 183 | val/n02114548/n0211454800001227.png 4 184 | val/n02114548/n0211454800001289.png 4 185 | val/n02114548/n0211454800001768.png 4 186 | val/n03417042/n0341704200005529.png 5 187 | val/n03417042/n0341704200005684.png 5 188 | val/n03417042/n0341704200005482.png 5 189 | val/n03417042/n0341704200005578.png 5 190 | val/n03417042/n0341704200005789.png 5 191 | val/n02950826/n0295082600003185.png 6 192 | val/n02950826/n0295082600003330.png 6 193 | val/n02950826/n0295082600003114.png 6 194 | val/n02950826/n0295082600003131.png 6 195 | val/n02950826/n0295082600003347.png 6 196 | val/n03773504/n0377350400008061.png 7 197 | 
val/n03773504/n0377350400008282.png 7 198 | val/n03773504/n0377350400007940.png 7 199 | val/n03773504/n0377350400007923.png 7 200 | val/n03773504/n0377350400008241.png 7 201 | val/n03980874/n0398087400008639.png 8 202 | val/n03980874/n0398087400008652.png 8 203 | val/n03980874/n0398087400008674.png 8 204 | val/n03980874/n0398087400008864.png 8 205 | val/n03980874/n0398087400008407.png 8 206 | val/n03535780/n0353578000006240.png 9 207 | val/n03535780/n0353578000006383.png 9 208 | val/n03535780/n0353578000006452.png 9 209 | val/n03535780/n0353578000006192.png 9 210 | val/n03535780/n0353578000006437.png 9 211 | val/n03584254/n0358425400006895.png 10 212 | val/n03584254/n0358425400006806.png 10 213 | val/n03584254/n0358425400006837.png 10 214 | val/n03584254/n0358425400006996.png 10 215 | val/n03584254/n0358425400006883.png 10 216 | val/n02971356/n0297135600003799.png 11 217 | val/n02971356/n0297135600003714.png 11 218 | val/n02971356/n0297135600004081.png 11 219 | val/n02971356/n0297135600003931.png 11 220 | val/n02971356/n0297135600004044.png 11 221 | val/n03075370/n0307537000005300.png 12 222 | val/n03075370/n0307537000005224.png 12 223 | val/n03075370/n0307537000004884.png 12 224 | val/n03075370/n0307537000004826.png 12 225 | val/n03075370/n0307537000004979.png 12 226 | val/n02138441/n0213844100002064.png 13 227 | val/n02138441/n0213844100002074.png 13 228 | val/n02138441/n0213844100001828.png 13 229 | val/n02138441/n0213844100002095.png 13 230 | val/n02138441/n0213844100002172.png 13 231 | val/n03770439/n0377043900007611.png 14 232 | val/n03770439/n0377043900007331.png 14 233 | val/n03770439/n0377043900007628.png 14 234 | val/n03770439/n0377043900007590.png 14 235 | val/n03770439/n0377043900007219.png 14 236 | val/n01855672/n0185567200000342.png 15 237 | val/n01855672/n0185567200000359.png 15 238 | val/n01855672/n0185567200000470.png 15 239 | val/n01855672/n0185567200000055.png 15 240 | val/n01855672/n0185567200000116.png 15 241 | val/n09256479/n0925647900009047.png 0 242 | val/n09256479/n0925647900009254.png 0 243 | val/n09256479/n0925647900009147.png 0 244 | val/n09256479/n0925647900009398.png 0 245 | val/n09256479/n0925647900009464.png 0 246 | val/n02174001/n0217400100002843.png 1 247 | val/n02174001/n0217400100002860.png 1 248 | val/n02174001/n0217400100002531.png 1 249 | val/n02174001/n0217400100002775.png 1 250 | val/n02174001/n0217400100002812.png 1 251 | val/n02091244/n0209124400000814.png 2 252 | val/n02091244/n0209124400001102.png 2 253 | val/n02091244/n0209124400000793.png 2 254 | val/n02091244/n0209124400001026.png 2 255 | val/n02091244/n0209124400000771.png 2 256 | val/n02981792/n0298179200004298.png 3 257 | val/n02981792/n0298179200004220.png 3 258 | val/n02981792/n0298179200004435.png 3 259 | val/n02981792/n0298179200004719.png 3 260 | val/n02981792/n0298179200004769.png 3 261 | val/n02114548/n0211454800001390.png 4 262 | val/n02114548/n0211454800001228.png 4 263 | val/n02114548/n0211454800001227.png 4 264 | val/n02114548/n0211454800001289.png 4 265 | val/n02114548/n0211454800001768.png 4 266 | val/n03417042/n0341704200005529.png 5 267 | val/n03417042/n0341704200005684.png 5 268 | val/n03417042/n0341704200005482.png 5 269 | val/n03417042/n0341704200005578.png 5 270 | val/n03417042/n0341704200005789.png 5 271 | val/n02950826/n0295082600003185.png 6 272 | val/n02950826/n0295082600003330.png 6 273 | val/n02950826/n0295082600003114.png 6 274 | val/n02950826/n0295082600003131.png 6 275 | val/n02950826/n0295082600003347.png 6 276 | val/n03773504/n0377350400008061.png 7 277 | 
val/n03773504/n0377350400008282.png 7 278 | val/n03773504/n0377350400007940.png 7 279 | val/n03773504/n0377350400007923.png 7 280 | val/n03773504/n0377350400008241.png 7 281 | val/n03980874/n0398087400008639.png 8 282 | val/n03980874/n0398087400008652.png 8 283 | val/n03980874/n0398087400008674.png 8 284 | val/n03980874/n0398087400008864.png 8 285 | val/n03980874/n0398087400008407.png 8 286 | val/n03535780/n0353578000006240.png 9 287 | val/n03535780/n0353578000006383.png 9 288 | val/n03535780/n0353578000006452.png 9 289 | val/n03535780/n0353578000006192.png 9 290 | val/n03535780/n0353578000006437.png 9 291 | val/n03584254/n0358425400006895.png 10 292 | val/n03584254/n0358425400006806.png 10 293 | val/n03584254/n0358425400006837.png 10 294 | val/n03584254/n0358425400006996.png 10 295 | val/n03584254/n0358425400006883.png 10 296 | val/n02971356/n0297135600003799.png 11 297 | val/n02971356/n0297135600003714.png 11 298 | val/n02971356/n0297135600004081.png 11 299 | val/n02971356/n0297135600003931.png 11 300 | val/n02971356/n0297135600004044.png 11 301 | val/n03075370/n0307537000005300.png 12 302 | val/n03075370/n0307537000005224.png 12 303 | val/n03075370/n0307537000004884.png 12 304 | val/n03075370/n0307537000004826.png 12 305 | val/n03075370/n0307537000004979.png 12 306 | val/n02138441/n0213844100002064.png 13 307 | val/n02138441/n0213844100002074.png 13 308 | val/n02138441/n0213844100001828.png 13 309 | val/n02138441/n0213844100002095.png 13 310 | val/n02138441/n0213844100002172.png 13 311 | val/n03770439/n0377043900007611.png 14 312 | val/n03770439/n0377043900007331.png 14 313 | val/n03770439/n0377043900007628.png 14 314 | val/n03770439/n0377043900007590.png 14 315 | val/n03770439/n0377043900007219.png 14 316 | val/n01855672/n0185567200000342.png 15 317 | val/n01855672/n0185567200000359.png 15 318 | val/n01855672/n0185567200000470.png 15 319 | val/n01855672/n0185567200000055.png 15 320 | val/n01855672/n0185567200000116.png 15 321 | val/n09256479/n0925647900009047.png 0 322 | val/n09256479/n0925647900009254.png 0 323 | val/n09256479/n0925647900009147.png 0 324 | val/n09256479/n0925647900009398.png 0 325 | val/n09256479/n0925647900009464.png 0 326 | val/n02174001/n0217400100002843.png 1 327 | val/n02174001/n0217400100002860.png 1 328 | val/n02174001/n0217400100002531.png 1 329 | val/n02174001/n0217400100002775.png 1 330 | val/n02174001/n0217400100002812.png 1 331 | val/n02091244/n0209124400000814.png 2 332 | val/n02091244/n0209124400001102.png 2 333 | val/n02091244/n0209124400000793.png 2 334 | val/n02091244/n0209124400001026.png 2 335 | val/n02091244/n0209124400000771.png 2 336 | val/n02981792/n0298179200004298.png 3 337 | val/n02981792/n0298179200004220.png 3 338 | val/n02981792/n0298179200004435.png 3 339 | val/n02981792/n0298179200004719.png 3 340 | val/n02981792/n0298179200004769.png 3 341 | val/n02114548/n0211454800001390.png 4 342 | val/n02114548/n0211454800001228.png 4 343 | val/n02114548/n0211454800001227.png 4 344 | val/n02114548/n0211454800001289.png 4 345 | val/n02114548/n0211454800001768.png 4 346 | val/n03417042/n0341704200005529.png 5 347 | val/n03417042/n0341704200005684.png 5 348 | val/n03417042/n0341704200005482.png 5 349 | val/n03417042/n0341704200005578.png 5 350 | val/n03417042/n0341704200005789.png 5 351 | val/n02950826/n0295082600003185.png 6 352 | val/n02950826/n0295082600003330.png 6 353 | val/n02950826/n0295082600003114.png 6 354 | val/n02950826/n0295082600003131.png 6 355 | val/n02950826/n0295082600003347.png 6 356 | val/n03773504/n0377350400008061.png 7 357 | 
val/n03773504/n0377350400008282.png 7 358 | val/n03773504/n0377350400007940.png 7 359 | val/n03773504/n0377350400007923.png 7 360 | val/n03773504/n0377350400008241.png 7 361 | val/n03980874/n0398087400008639.png 8 362 | val/n03980874/n0398087400008652.png 8 363 | val/n03980874/n0398087400008674.png 8 364 | val/n03980874/n0398087400008864.png 8 365 | val/n03980874/n0398087400008407.png 8 366 | val/n03535780/n0353578000006240.png 9 367 | val/n03535780/n0353578000006383.png 9 368 | val/n03535780/n0353578000006452.png 9 369 | val/n03535780/n0353578000006192.png 9 370 | val/n03535780/n0353578000006437.png 9 371 | val/n03584254/n0358425400006895.png 10 372 | val/n03584254/n0358425400006806.png 10 373 | val/n03584254/n0358425400006837.png 10 374 | val/n03584254/n0358425400006996.png 10 375 | val/n03584254/n0358425400006883.png 10 376 | val/n02971356/n0297135600003799.png 11 377 | val/n02971356/n0297135600003714.png 11 378 | val/n02971356/n0297135600004081.png 11 379 | val/n02971356/n0297135600003931.png 11 380 | val/n02971356/n0297135600004044.png 11 381 | val/n03075370/n0307537000005300.png 12 382 | val/n03075370/n0307537000005224.png 12 383 | val/n03075370/n0307537000004884.png 12 384 | val/n03075370/n0307537000004826.png 12 385 | val/n03075370/n0307537000004979.png 12 386 | val/n02138441/n0213844100002064.png 13 387 | val/n02138441/n0213844100002074.png 13 388 | val/n02138441/n0213844100001828.png 13 389 | val/n02138441/n0213844100002095.png 13 390 | val/n02138441/n0213844100002172.png 13 391 | val/n03770439/n0377043900007611.png 14 392 | val/n03770439/n0377043900007331.png 14 393 | val/n03770439/n0377043900007628.png 14 394 | val/n03770439/n0377043900007590.png 14 395 | val/n03770439/n0377043900007219.png 14 396 | val/n01855672/n0185567200000342.png 15 397 | val/n01855672/n0185567200000359.png 15 398 | val/n01855672/n0185567200000470.png 15 399 | val/n01855672/n0185567200000055.png 15 400 | val/n01855672/n0185567200000116.png 15 401 | -------------------------------------------------------------------------------- /domain_adaptive_module/network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | from torchvision import models 6 | from torch.autograd import Variable 7 | import math 8 | import pdb 9 | 10 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0): 11 | return np.float(2.0 * (high - low) / (1.0 + np.exp(-alpha*iter_num / max_iter)) - (high - low) + low) 12 | 13 | def init_weights(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1: 16 | nn.init.kaiming_uniform_(m.weight) 17 | nn.init.zeros_(m.bias) 18 | elif classname.find('BatchNorm') != -1: 19 | nn.init.normal_(m.weight, 1.0, 0.02) 20 | nn.init.zeros_(m.bias) 21 | elif classname.find('Linear') != -1: 22 | nn.init.xavier_normal_(m.weight) 23 | nn.init.zeros_(m.bias) 24 | 25 | class RandomLayer(nn.Module): 26 | def __init__(self, input_dim_list=[], output_dim=1024): 27 | super(RandomLayer, self).__init__() 28 | self.input_num = len(input_dim_list) 29 | self.output_dim = output_dim 30 | self.random_matrix = [torch.randn(input_dim_list[i], output_dim) for i in range(self.input_num)] 31 | 32 | def forward(self, input_list): 33 | return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)] 34 | return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0/len(return_list)) 35 | for 
single in return_list[1:]: 36 | return_tensor = torch.mul(return_tensor, single) # elementwise product fuses the randomly projected inputs (CDAN's randomized multilinear map) 37 | return return_tensor 38 | 39 | def cuda(self): 40 | super(RandomLayer, self).cuda() 41 | self.random_matrix = [val.cuda() for val in self.random_matrix] # the random matrices are plain tensors, not Parameters, so nn.Module.cuda() would not move them 42 | 43 | class LRN(nn.Module): 44 | def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True): 45 | super(LRN, self).__init__() 46 | self.ACROSS_CHANNELS = ACROSS_CHANNELS 47 | if ACROSS_CHANNELS: 48 | self.average=nn.AvgPool3d(kernel_size=(local_size, 1, 1), 49 | stride=1, 50 | padding=(int((local_size-1.0)/2), 0, 0)) 51 | else: 52 | self.average=nn.AvgPool2d(kernel_size=local_size, 53 | stride=1, 54 | padding=int((local_size-1.0)/2)) 55 | self.alpha = alpha 56 | self.beta = beta 57 | 58 | 59 | def forward(self, x): 60 | if self.ACROSS_CHANNELS: 61 | div = x.pow(2).unsqueeze(1) 62 | div = self.average(div).squeeze(1) 63 | div = div.mul(self.alpha).add(1.0).pow(self.beta) 64 | else: 65 | div = x.pow(2) 66 | div = self.average(div) 67 | div = div.mul(self.alpha).add(1.0).pow(self.beta) 68 | x = x.div(div) 69 | return x 70 | 71 | class AlexNet(nn.Module): 72 | 73 | def __init__(self, num_classes=1000): 74 | super(AlexNet, self).__init__() 75 | self.features = nn.Sequential( 76 | nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0), 77 | nn.ReLU(inplace=True), 78 | LRN(local_size=5, alpha=0.0001, beta=0.75), 79 | nn.MaxPool2d(kernel_size=3, stride=2), 80 | nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2), 81 | nn.ReLU(inplace=True), 82 | LRN(local_size=5, alpha=0.0001, beta=0.75), 83 | nn.MaxPool2d(kernel_size=3, stride=2), 84 | nn.Conv2d(256, 384, kernel_size=3, padding=1), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2), 87 | nn.ReLU(inplace=True), 88 | nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2), 89 | nn.ReLU(inplace=True), 90 | nn.MaxPool2d(kernel_size=3, stride=2), 91 | ) 92 | self.classifier = nn.Sequential( 93 | nn.Linear(256 * 6 * 6, 4096), 94 | nn.ReLU(inplace=True), 95 | nn.Dropout(), 96 | nn.Linear(4096, 4096), 97 | nn.ReLU(inplace=True), 98 | nn.Dropout(), 99 | nn.Linear(4096, num_classes), 100 | ) 101 | 102 | def forward(self, x): 103 | x = self.features(x) 104 | print(x.size()) 105 | x = x.view(x.size(0), 256 * 6 * 6) 106 | x = self.classifier(x) 107 | return x 108 | 109 | 110 | def alexnet(pretrained=False, **kwargs): 111 | r"""AlexNet model architecture from the 112 | `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper. 
113 | Args: 114 | pretrained (bool): If True, returns a model pre-trained on ImageNet 115 | """ 116 | model = AlexNet(**kwargs) 117 | if pretrained: 118 | model_path = './alexnet.pth.tar' 119 | pretrained_model = torch.load(model_path) 120 | model.load_state_dict(pretrained_model['state_dict']) 121 | return model 122 | 123 | # convnet without the last layer 124 | class AlexNetFc(nn.Module): 125 | def __init__(self, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000): 126 | super(AlexNetFc, self).__init__() 127 | model_alexnet = alexnet(pretrained=True) 128 | self.features = model_alexnet.features 129 | self.classifier = nn.Sequential() 130 | for i in range(6): 131 | self.classifier.add_module("classifier"+str(i), model_alexnet.classifier[i]) 132 | self.feature_layers = nn.Sequential(self.features, self.classifier) 133 | 134 | self.use_bottleneck = use_bottleneck 135 | self.new_cls = new_cls 136 | if new_cls: 137 | if self.use_bottleneck: 138 | self.bottleneck = nn.Linear(4096, bottleneck_dim) 139 | self.fc = nn.Linear(bottleneck_dim, class_num) 140 | self.bottleneck.apply(init_weights) 141 | self.fc.apply(init_weights) 142 | self.__in_features = bottleneck_dim 143 | else: 144 | self.fc = nn.Linear(4096, class_num) 145 | self.fc.apply(init_weights) 146 | self.__in_features = 4096 147 | else: 148 | self.fc = model_alexnet.classifier[6] 149 | self.__in_features = 4096 150 | 151 | def forward(self, x): 152 | x = self.features(x) 153 | x = x.view(x.size(0), -1) 154 | x = self.classifier(x) 155 | if self.use_bottleneck and self.new_cls: 156 | x = self.bottleneck(x) 157 | y = self.fc(x) 158 | return x, y 159 | 160 | def output_num(self): 161 | return self.__in_features 162 | 163 | def get_parameters(self): 164 | if self.new_cls: 165 | if self.use_bottleneck: 166 | parameter_list = [{"params":self.features.parameters(), "lr_mult":1, 'decay_mult':2}, \ 167 | {"params":self.classifier.parameters(), "lr_mult":1, 'decay_mult':2}, \ 168 | {"params":self.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \ 169 | {"params":self.fc.parameters(), "lr_mult":10, 'decay_mult':2}] 170 | else: 171 | parameter_list = [{"params":self.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \ 172 | {"params":self.classifier.parameters(), "lr_mult":1, 'decay_mult':2}, \ 173 | {"params":self.fc.parameters(), "lr_mult":10, 'decay_mult':2}] 174 | else: 175 | parameter_list = [{"params":self.parameters(), "lr_mult":1, 'decay_mult':2}] 176 | return parameter_list 177 | 178 | 179 | resnet_dict = {"ResNet18":models.resnet18, "ResNet34":models.resnet34, "ResNet50":models.resnet50, "ResNet101":models.resnet101, "ResNet152":models.resnet152} 180 | 181 | def grl_hook(coeff): 182 | def fun1(grad): 183 | return -coeff*grad.clone() 184 | return fun1 185 | 186 | class ResNetFc(nn.Module): 187 | def __init__(self, resnet_name, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000, pretrained_model='tiered_checkpoint.pth.tar'): 188 | super(ResNetFc, self).__init__() 189 | model_resnet = resnet_dict[resnet_name](pretrained=True) 190 | 191 | pretrained_dict = torch.load(pretrained_model)['state_dict'] 192 | model_dict = model_resnet.state_dict() 193 | # 1. filter out unnecessary keys 194 | pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict and not k[7:].startswith('fc')} 195 | # 2. overwrite entries in the existing state dict 196 | print(pretrained_dict) 197 | model_dict.update(pretrained_dict) 198 | # 3. 
load the new state dict 199 | model_resnet.load_state_dict(model_dict) 200 | 201 | self.conv1 = model_resnet.conv1 202 | self.bn1 = model_resnet.bn1 203 | self.relu = model_resnet.relu 204 | self.maxpool = model_resnet.maxpool 205 | self.layer1 = model_resnet.layer1 206 | self.layer2 = model_resnet.layer2 207 | self.layer3 = model_resnet.layer3 208 | self.layer4 = model_resnet.layer4 209 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 210 | self.feature_layers = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool, \ 211 | self.layer1, self.layer2, self.layer3, self.layer4, self.avgpool) 212 | 213 | self.use_bottleneck = use_bottleneck 214 | self.new_cls = new_cls 215 | if new_cls: 216 | if self.use_bottleneck: 217 | self.bottleneck = nn.Linear(model_resnet.fc.in_features, bottleneck_dim) 218 | self.fc = nn.Linear(bottleneck_dim, class_num) 219 | self.bottleneck.apply(init_weights) 220 | self.fc.apply(init_weights) 221 | self.__in_features = bottleneck_dim 222 | else: 223 | self.fc = nn.Linear(model_resnet.fc.in_features, class_num) 224 | self.fc.apply(init_weights) 225 | self.__in_features = model_resnet.fc.in_features 226 | else: 227 | self.fc = model_resnet.fc 228 | self.__in_features = model_resnet.fc.in_features 229 | 230 | def forward(self, x): 231 | x = self.feature_layers(x) 232 | x = x.view(x.size(0), -1) 233 | if self.use_bottleneck and self.new_cls: 234 | x = self.bottleneck(x) 235 | y = self.fc(x) 236 | return x, y 237 | 238 | def output_num(self): 239 | return self.__in_features 240 | 241 | def get_parameters(self): 242 | if self.new_cls: 243 | if self.use_bottleneck: 244 | parameter_list = [{"params":self.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \ 245 | {"params":self.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \ 246 | {"params":self.fc.parameters(), "lr_mult":10, 'decay_mult':2}] 247 | else: 248 | parameter_list = [{"params":self.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \ 249 | {"params":self.fc.parameters(), "lr_mult":10, 'decay_mult':2}] 250 | else: 251 | parameter_list = [{"params":self.parameters(), "lr_mult":1, 'decay_mult':2}] 252 | return parameter_list 253 | 254 | vgg_dict = {"VGG11":models.vgg11, "VGG13":models.vgg13, "VGG16":models.vgg16, "VGG19":models.vgg19, "VGG11BN":models.vgg11_bn, "VGG13BN":models.vgg13_bn, "VGG16BN":models.vgg16_bn, "VGG19BN":models.vgg19_bn} 255 | class VGGFc(nn.Module): 256 | def __init__(self, vgg_name, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000): 257 | super(VGGFc, self).__init__() 258 | model_vgg = vgg_dict[vgg_name](pretrained=True) 259 | self.features = model_vgg.features 260 | self.classifier = nn.Sequential() 261 | for i in range(6): 262 | self.classifier.add_module("classifier"+str(i), model_vgg.classifier[i]) 263 | self.feature_layers = nn.Sequential(self.features, self.classifier) 264 | 265 | self.use_bottleneck = use_bottleneck 266 | self.new_cls = new_cls 267 | if new_cls: 268 | if self.use_bottleneck: 269 | self.bottleneck = nn.Linear(4096, bottleneck_dim) 270 | self.fc = nn.Linear(bottleneck_dim, class_num) 271 | self.bottleneck.apply(init_weights) 272 | self.fc.apply(init_weights) 273 | self.__in_features = bottleneck_dim 274 | else: 275 | self.fc = nn.Linear(4096, class_num) 276 | self.fc.apply(init_weights) 277 | self.__in_features = 4096 278 | else: 279 | self.fc = model_vgg.classifier[6] 280 | self.__in_features = 4096 281 | 282 | def forward(self, x): 283 | x = self.features(x) 284 | x = x.view(x.size(0), -1) 285 | x = self.classifier(x) 
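# Note: like AlexNetFc and ResNetFc above, VGGFc.forward returns a (features, logits) pair; callers that only want class scores unpack it as `_, predict_out = model(inputs)` (see the commented-out 10-crop branch in domain_adaptive_module/eval.py).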
286 | if self.use_bottleneck and self.new_cls: 287 | x = self.bottleneck(x) 288 | y = self.fc(x) 289 | return x, y 290 | 291 | def output_num(self): 292 | return self.__in_features 293 | 294 | def get_parameters(self): 295 | if self.new_cls: 296 | if self.use_bottleneck: 297 | parameter_list = [{"params":self.features.parameters(), "lr_mult":1, 'decay_mult':2}, \ 298 | {"params":self.classifier.parameters(), "lr_mult":1, 'decay_mult':2}, \ 299 | {"params":self.bottleneck.parameters(), "lr_mult":10, 'decay_mult':2}, \ 300 | {"params":self.fc.parameters(), "lr_mult":10, 'decay_mult':2}] 301 | else: 302 | parameter_list = [{"params":self.feature_layers.parameters(), "lr_mult":1, 'decay_mult':2}, \ 303 | {"params":self.classifier.parameters(), "lr_mult":1, 'decay_mult':2}, \ 304 | {"params":self.fc.parameters(), "lr_mult":10, 'decay_mult':2}] 305 | else: 306 | parameter_list = [{"params":self.parameters(), "lr_mult":1, 'decay_mult':2}] 307 | return parameter_list 308 | 309 | # For SVHN dataset 310 | class DTN(nn.Module): 311 | def __init__(self): 312 | super(DTN, self).__init__() 313 | self.conv_params = nn.Sequential ( 314 | nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2), 315 | nn.BatchNorm2d(64), 316 | nn.Dropout2d(0.1), 317 | nn.ReLU(), 318 | nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), 319 | nn.BatchNorm2d(128), 320 | nn.Dropout2d(0.3), 321 | nn.ReLU(), 322 | nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), 323 | nn.BatchNorm2d(256), 324 | nn.Dropout2d(0.5), 325 | nn.ReLU() 326 | ) 327 | 328 | self.fc_params = nn.Sequential ( 329 | nn.Linear(256*4*4, 512), 330 | nn.BatchNorm1d(512), 331 | nn.ReLU(), 332 | nn.Dropout() 333 | ) 334 | 335 | self.classifier = nn.Linear(512, 10) 336 | self.__in_features = 512 337 | 338 | def forward(self, x): 339 | x = self.conv_params(x) 340 | x = x.view(x.size(0), -1) 341 | x = self.fc_params(x) 342 | y = self.classifier(x) 343 | return x, y 344 | 345 | def output_num(self): 346 | return self.__in_features 347 | 348 | class LeNet(nn.Module): 349 | def __init__(self): 350 | super(LeNet, self).__init__() 351 | self.conv_params = nn.Sequential( 352 | nn.Conv2d(1, 20, kernel_size=5), 353 | nn.MaxPool2d(2), 354 | nn.ReLU(), 355 | nn.Conv2d(20, 50, kernel_size=5), 356 | nn.Dropout2d(p=0.5), 357 | nn.MaxPool2d(2), 358 | nn.ReLU(), 359 | ) 360 | 361 | self.fc_params = nn.Sequential(nn.Linear(50*4*4, 500), nn.ReLU(), nn.Dropout(p=0.5)) 362 | self.classifier = nn.Linear(500, 10) 363 | self.__in_features = 500 364 | 365 | 366 | def forward(self, x): 367 | x = self.conv_params(x) 368 | x = x.view(x.size(0), -1) 369 | x = self.fc_params(x) 370 | y = self.classifier(x) 371 | return x, y 372 | 373 | def output_num(self): 374 | return self.__in_features 375 | 376 | class AdversarialNetwork(nn.Module): 377 | def __init__(self, in_feature, hidden_size): 378 | super(AdversarialNetwork, self).__init__() 379 | self.ad_layer1 = nn.Linear(in_feature, hidden_size) 380 | self.ad_layer2 = nn.Linear(hidden_size, hidden_size) 381 | self.ad_layer3 = nn.Linear(hidden_size, 1) 382 | self.relu1 = nn.ReLU() 383 | self.relu2 = nn.ReLU() 384 | self.dropout1 = nn.Dropout(0.5) 385 | self.dropout2 = nn.Dropout(0.5) 386 | self.sigmoid = nn.Sigmoid() 387 | self.apply(init_weights) 388 | self.iter_num = 0 389 | self.alpha = 10 390 | self.low = 0.0 391 | self.high = 1.0 392 | self.max_iter = 10000.0 393 | 394 | def forward(self, x): 395 | if self.training: 396 | self.iter_num += 1 397 | coeff = calc_coeff(self.iter_num, self.high, self.low, self.alpha, self.max_iter) 398 | x 
= x * 1.0 399 | x.register_hook(grl_hook(coeff)) 400 | x = self.ad_layer1(x) 401 | x = self.relu1(x) 402 | x = self.dropout1(x) 403 | x = self.ad_layer2(x) 404 | x = self.relu2(x) 405 | x = self.dropout2(x) 406 | y = self.ad_layer3(x) 407 | y = self.sigmoid(y) 408 | return y 409 | 410 | def output_num(self): 411 | return 1 412 | def get_parameters(self): 413 | return [{"params":self.parameters(), "lr_mult":10, 'decay_mult':2}] 414 | -------------------------------------------------------------------------------- /train_lambda.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | import domain_adaptive_module.network as network 11 | import domain_adaptive_module.loss as loss 12 | import domain_adaptive_module.pre_process as prep 13 | from torch.utils.data import DataLoader 14 | import domain_adaptive_module.lr_schedule as lr_schedule 15 | from torch.autograd import Variable 16 | from data_loader import * 17 | import random 18 | import pdb 19 | import math 20 | from prototypical_module.utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric 21 | 22 | 23 | class LambdaLearner(nn.Module): 24 | def __init__(self, feature_dim): 25 | super(LambdaLearner, self).__init__() 26 | self.fc = nn.Linear(feature_dim, 1) 27 | self.sigmoid = nn.Sigmoid() 28 | def forward(self, x): 29 | out = self.fc(x) 30 | out = self.sigmoid(out) 31 | return out 32 | 33 | def image_classification_test(loader, model, test_10crop=True): 34 | start_test = True 35 | with torch.no_grad(): 36 | if test_10crop: 37 | iter_test = [iter(loader['test'][i]) for i in range(10)] 38 | for i in range(len(loader['test'][0])): 39 | data = [iter_test[j].next() for j in range(10)] 40 | inputs = [data[j][0] for j in range(10)] 41 | labels = data[0][1] 42 | for j in range(10): 43 | inputs[j] = inputs[j].cuda() 44 | labels = labels 45 | outputs = [] 46 | for j in range(10): 47 | _, predict_out = model(inputs[j]) 48 | outputs.append(nn.Softmax(dim=1)(predict_out)) 49 | outputs = sum(outputs) 50 | if start_test: 51 | all_output = outputs.float().cpu() 52 | all_label = labels.float() 53 | start_test = False 54 | else: 55 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 56 | all_label = torch.cat((all_label, labels.float()), 0) 57 | else: 58 | iter_test = iter(loader["test"]) 59 | for i in range(len(loader['test'])): 60 | data = iter_test.next() 61 | inputs = data[0] 62 | labels = data[1] 63 | inputs = inputs.cuda() 64 | labels = labels.cuda() 65 | _, outputs = model(inputs) 66 | if start_test: 67 | all_output = outputs.float().cpu() 68 | all_label = labels.float() 69 | start_test = False 70 | else: 71 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 72 | all_label = torch.cat((all_label, labels.float()), 0) 73 | _, predict = torch.max(all_output, 1) 74 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 75 | return accuracy 76 | 77 | 78 | def train(config): 79 | ## set pre-process 80 | prep_dict = {} 81 | prep_config = config["prep"] 82 | prep_dict["source"] = prep.image_train(**config["prep"]['params']) 83 | prep_dict["target"] = prep.image_train(**config["prep"]['params']) 84 | if prep_config["test_10crop"]: 85 | prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params']) 86 | else: 87 | 
prep_dict["test"] = prep.image_test(**config["prep"]['params']) 88 | 89 | ## prepare data 90 | dsets = {} 91 | dset_loaders = {} 92 | data_config = config["data"] 93 | train_bs = data_config["source"]["batch_size"] 94 | test_bs = data_config["test"]["batch_size"] 95 | 96 | dsets["target"] = MiniImageNet(root=data_config["target"]["root"], dataset=config["dataset"], mode=data_config["target"]["split"]) 97 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=config["shot"] * config["train_way"], \ 98 | shuffle=True, num_workers=4, drop_last=True) 99 | dsets["source"] = MiniImageNet(root=data_config["source"]["root"], dataset=config["dataset"], mode=data_config["source"]["split"]) 100 | fsl_train_sampler = CategoriesSampler(dsets["source"].label, 100, 101 | config["train_way"], config["shot"] + config["query"]) 102 | dset_loaders["source"] = DataLoader(dataset=dsets["source"], batch_sampler=fsl_train_sampler, 103 | num_workers=4, pin_memory=True) 104 | fsl_valset = MiniImageNet(root=data_config["fsl_test"]["root"], dataset=config["dataset"], mode=data_config["fsl_test"]["split"]) 105 | fsl_val_sampler = CategoriesSampler(fsl_valset.label, 400, 106 | config["test_way"], config["shot"] + config["query"]) 107 | fsl_val_loader = DataLoader(dataset=fsl_valset, batch_sampler=fsl_val_sampler, 108 | num_workers=4, pin_memory=True) 109 | 110 | class_num = config["network"]["params"]["class_num"] 111 | 112 | ## set base network 113 | net_config = config["network"] 114 | base_network = net_config["name"](**net_config["params"]) 115 | base_network = base_network.cuda() 116 | 117 | ## add additional network for some methods 118 | if config["loss"]["random"]: 119 | random_layer = network.RandomLayer([base_network.output_num(), class_num], config["loss"]["random_dim"]) 120 | ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024) 121 | else: 122 | random_layer = None 123 | ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024) 124 | if config["loss"]["random"]: 125 | random_layer.cuda() 126 | ad_net = ad_net.cuda() 127 | parameter_list = base_network.get_parameters() + ad_net.get_parameters() 128 | 129 | ## set optimizer 130 | optimizer_config = config["optimizer"] 131 | optimizer = optimizer_config["type"](parameter_list, \ 132 | **(optimizer_config["optim_params"])) 133 | param_lr = [] 134 | for param_group in optimizer.param_groups: 135 | param_lr.append(param_group["lr"]) 136 | schedule_param = optimizer_config["lr_param"] 137 | lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]] 138 | 139 | gpus = config['gpu'].split(',') 140 | if len(gpus) > 1: 141 | ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus]) 142 | base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus]) 143 | 144 | 145 | ## train 146 | len_train_source = len(dset_loaders["source"]) 147 | len_train_target = len(dset_loaders["target"]) 148 | len_train_source_target = len(dset_loaders["source_target"]) 149 | transfer_loss_value = classifier_loss_value = total_loss_value = 0.0 150 | best_acc = 0.0 151 | start = 0 152 | for i in range(config["num_iterations"]): 153 | if i % config["test_interval"] == config["test_interval"] - 1: 154 | base_network.train(False) 155 | # temp_acc = image_classification_test(dset_loaders, \ 156 | # base_network, test_10crop=prep_config["test_10crop"]) 157 | # temp_model = nn.Sequential(base_network) 158 | # if temp_acc > best_acc: 159 | # best_acc = temp_acc 160 | # best_model = temp_model 161 | # log_str = 
"iter: {:05d}, precision: {:.5f}".format(i, temp_acc) 162 | # config["out_file"].write(log_str+"\n") 163 | # config["out_file"].flush() 164 | # print(log_str) 165 | for i, batch in enumerate(fsl_val_loader, 1): 166 | data, _ = [_.cuda() for _ in batch] 167 | k = config["test_way"] * config["shot"] 168 | data_shot, data_query = data[:k], data[k:] 169 | 170 | x, _ = base_network(data_shot) 171 | x = x.reshape(config["shot"], config["test_way"], -1).mean(dim=0) 172 | p = x 173 | proto_query, _ = base_network(data_query) 174 | proto_query = proto_query.reshape(config["shot"], config["train_way"], -1).mean(dim=0) 175 | logits = euclidean_metric(proto_query, p) 176 | 177 | label = torch.arange(config["test_way"]).repeat(config["query"]) 178 | label = label.type(torch.cuda.LongTensor) 179 | 180 | acc = count_acc(logits, label) 181 | ave_acc.add(acc) 182 | print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100)) 183 | 184 | x = None; p = None; logits = None 185 | if i % config["snapshot_interval"] == 0: 186 | torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \ 187 | "iter_{:05d}_model.pth.tar".format(i))) 188 | 189 | loss_params = config["loss"] 190 | ## train one iter 191 | base_network.train(True) 192 | ad_net.train(True) 193 | optimizer = lr_scheduler(optimizer, i, **schedule_param) 194 | optimizer.zero_grad() 195 | 196 | if i % len_train_source == 0: 197 | iter_fsl_train = iter(dset_loaders["source"]) 198 | if i % len_train_target == 0: 199 | iter_target = iter(dset_loaders["target"]) 200 | 201 | inputs_target, labels_target = iter_target.next() 202 | inputs_fsl, labels_fsl = iter_fsl_train.next() 203 | inputs_target = inputs_target.cuda() 204 | inputs_fsl, labels_fsl = inputs_fsl.cuda(), labels_fsl.cuda() 205 | 206 | p = config["shot"] * config["train_way"] 207 | data_shot, data_query = inputs_fsl[:p], inputs_fsl[p:] 208 | labels_source = labels_fsl[:p] 209 | 210 | proto_source, outputs_source = base_network(data_shot) 211 | features_target, outputs_target = base_network(inputs_target) 212 | proto = proto_source.reshape(config["shot"], config["train_way"], -1).mean(dim=0) 213 | 214 | label = torch.arange(config["train_way"]).repeat(config["query"]) 215 | label = label.type(torch.cuda.LongTensor) 216 | query_proto, _ = base_network(data_query) 217 | logits = euclidean_metric(query_proto, proto) 218 | # fsl_loss = F.cross_entropy(logits, label) 219 | fsl_loss = nn.CrossEntropyLoss()(logits, label) 220 | fsl_acc = count_acc(logits, label) 221 | 222 | features = torch.cat((proto_source, features_target), dim=0) 223 | outputs = torch.cat((outputs_source, outputs_target), dim=0) 224 | softmax_out = nn.Softmax(dim=1)(outputs) 225 | if config['method'] == 'CDAN+E': 226 | entropy = loss.Entropy(softmax_out) 227 | transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer) 228 | elif config['method'] == 'CDAN': 229 | transfer_loss = loss.CDAN([features, softmax_out], ad_net, None, None, random_layer) 230 | elif config['method'] == 'DANN': 231 | transfer_loss = loss.DANN(features, ad_net) 232 | else: 233 | raise ValueError('Method cannot be recognized.') 234 | # classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source) 235 | 236 | 237 | 238 | if i % 1 == 0: 239 | print('iter: ', i, 'transfer_loss: ', transfer_loss.data, 'fsl_loss: ', fsl_loss.data, 'fsl_acc: ', fsl_acc) 240 | total_loss = loss_params["trade_off"] * transfer_loss + 0.2 * fsl_loss 241 | print total_loss 242 | total_loss.backward() 243 | 
243 |         optimizer.step()
244 |     torch.save(nn.Sequential(base_network), osp.join(config["output_path"], "best_model.pth.tar"))  # best-model tracking is commented out above, so the final network is saved here
245 |     return best_acc
246 | 
247 | 
248 | if __name__ == "__main__":
249 |     parser = argparse.ArgumentParser(description='Conditional Domain Adversarial Network')
250 |     parser.add_argument('--method', type=str, default='CDAN+E', choices=['CDAN', 'CDAN+E', 'DANN'])
251 |     parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
252 |     parser.add_argument('--net', type=str, default='ResNet50', choices=["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152", "VGG11", "VGG13", "VGG16", "VGG19", "VGG11BN", "VGG13BN", "VGG16BN", "VGG19BN", "AlexNet"])
253 |     parser.add_argument('--dset', type=str, default='office', choices=['office', 'image-clef', 'visda', 'office-home', 'mini-imagenet', 'tiered-imagenet'], help="The dataset or source dataset used")  # note: the office branch below also reads args.t_dset_path, which this script never defines
254 |     parser.add_argument('--s_dset_path', type=str, default='dataset/mini-imagenet/train', help="The dataset path")
255 |     parser.add_argument('--fsl_test_path', type=str, default='dataset/mini-imagenet/test_new_domain', help="The dataset path")
256 |     parser.add_argument('--test_interval', type=int, default=10000, help="interval between two successive test phases")
257 |     parser.add_argument('--snapshot_interval', type=int, default=500, help="interval between two successive model snapshots")
258 |     parser.add_argument('--output_dir', type=str, default='san', help="output directory of our model (in ../snapshot directory)")
259 |     parser.add_argument('--lr', type=float, default=0.0005, help="learning rate")
260 |     parser.add_argument('--random', type=bool, default=False, help="whether to use random projection")
261 |     parser.add_argument('--shot', type=int, default=1)
262 |     parser.add_argument('--query', type=int, default=15)
263 |     parser.add_argument('--train-way', type=int, default=30)
264 |     parser.add_argument('--test-way', type=int, default=5)
265 |     args = parser.parse_args()
266 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
267 |     #os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
268 | 
269 |     # train config
270 |     config = {}
271 |     config['method'] = args.method
272 |     config["gpu"] = args.gpu_id
273 |     config["num_iterations"] = 100004
274 |     config["test_interval"] = args.test_interval
275 |     config["snapshot_interval"] = args.snapshot_interval
276 |     config["output_for_test"] = True
277 |     config["output_path"] = "snapshot/" + args.output_dir
278 |     if not osp.exists(config["output_path"]):
279 |         os.system('mkdir -p '+config["output_path"])
280 |     config["out_file"] = open(osp.join(config["output_path"], "log.txt"), "w")
281 |     if not osp.exists(config["output_path"]):
282 |         os.mkdir(config["output_path"])
283 | 
284 |     config["prep"] = {"test_10crop":True, 'params':{"resize_size":256, "crop_size":224, 'alexnet':False}}
285 |     config["loss"] = {"trade_off":1.0}
286 |     if "AlexNet" in args.net:
287 |         config["prep"]['params']['alexnet'] = True
288 |         config["prep"]['params']['crop_size'] = 227
289 |         config["network"] = {"name":network.AlexNetFc, \
290 |             "params":{"use_bottleneck":True, "bottleneck_dim":256, "new_cls":True} }
291 |     elif "ResNet" in args.net:
292 |         config["network"] = {"name":network.ResNetFc, \
293 |             "params":{"resnet_name":args.net, "use_bottleneck":True, "bottleneck_dim":256, "new_cls":True} }
294 |     elif "VGG" in args.net:
295 |         config["network"] = {"name":network.VGGFc, \
296 |             "params":{"vgg_name":args.net, "use_bottleneck":True, "bottleneck_dim":256, "new_cls":True} }
297 |     config["loss"]["random"] = args.random
298 | 
config["loss"]["random_dim"] = 1024 299 | 300 | config["optimizer"] = {"type":optim.SGD, "optim_params":{'lr':args.lr, "momentum":0.9, \ 301 | "weight_decay":0.0005, "nesterov":True}, "lr_type":"inv", \ 302 | "lr_param":{"lr":args.lr, "gamma":0.001, "power":0.75} } 303 | 304 | config["dataset"] = args.dset 305 | config["data"] = {"source":{"root":args.s_dset_path, "split":"train", "batch_size":50}, \ 306 | "target":{"root":args.s_dset_path, "split":"val_new_domain", "batch_size":8}, \ 307 | "test":{"root":args.s_dset_path, "split":"val_new_domain", "batch_size":4}, \ 308 | "fsl_test":{"root":args.fsl_test_path, "split":"val_new_domain_fsl", "batch_size":4}} 309 | 310 | if config["dataset"] == "office": 311 | if ("amazon" in args.s_dset_path and "webcam" in args.t_dset_path) or \ 312 | ("webcam" in args.s_dset_path and "dslr" in args.t_dset_path) or \ 313 | ("webcam" in args.s_dset_path and "amazon" in args.t_dset_path) or \ 314 | ("dslr" in args.s_dset_path and "amazon" in args.t_dset_path): 315 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 316 | elif ("amazon" in args.s_dset_path and "dslr" in args.t_dset_path) or \ 317 | ("dslr" in args.s_dset_path and "webcam" in args.t_dset_path): 318 | config["optimizer"]["lr_param"]["lr"] = 0.0003 # optimal parameters 319 | config["network"]["params"]["class_num"] = 31 320 | elif config["dataset"] == "image-clef": 321 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 322 | config["network"]["params"]["class_num"] = 12 323 | elif config["dataset"] == "visda": 324 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 325 | config["network"]["params"]["class_num"] = 12 326 | config['loss']["trade_off"] = 1.0 327 | elif config["dataset"] == "office-home": 328 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 329 | config["network"]["params"]["class_num"] = 65 330 | elif config["dataset"] == "mini-imagenet": 331 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 332 | config["network"]["params"]["class_num"] = 64 333 | config['loss']["trade_off"] = 1.0 334 | elif config["dataset"] == "tiered-imagenet": 335 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 336 | config["network"]["params"]["class_num"] = 351 337 | config['loss']["trade_off"] = 1.0 338 | else: 339 | raise ValueError('Dataset cannot be recognized. 
Please define your own dataset here.') 340 | config["out_file"].write(str(config)) 341 | config["out_file"].flush() 342 | 343 | config["shot"] = args.shot 344 | config["query"] = args.query 345 | config["train_way"] = args.train_way 346 | config["test_way"] = args.test_way 347 | train(config) 348 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | import domain_adaptive_module.network as network 11 | import domain_adaptive_module.loss as loss 12 | import domain_adaptive_module.pre_process as prep 13 | from torch.utils.data import DataLoader 14 | import domain_adaptive_module.lr_schedule as lr_schedule 15 | from torch.autograd import Variable 16 | from data_loader import * 17 | import random 18 | import pdb 19 | import math 20 | from prototypical_module.utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric 21 | from torch.autograd import Variable 22 | from torch.nn.parameter import Parameter 23 | 24 | 25 | class learnedweight(nn.Module): 26 | def __init__(self): 27 | super(learnedweight, self).__init__() 28 | self.fsl_weight = Parameter(torch.ones(1), requires_grad=True) 29 | self.da_weight = Parameter(torch.ones(1), requires_grad=True) 30 | 31 | def forward(self, fsl_loss, da_loss): 32 | final_loss = self.fsl_weight + torch.exp(-1 * self.fsl_weight)*fsl_loss + self.da_weight + torch.exp(-1 * self.da_weight)*da_loss 33 | return final_loss 34 | 35 | def image_classification_test(loader, model, test_10crop=True): 36 | start_test = True 37 | with torch.no_grad(): 38 | if test_10crop: 39 | iter_test = [iter(loader['test'][i]) for i in range(10)] 40 | for i in range(len(loader['test'][0])): 41 | data = [iter_test[j].next() for j in range(10)] 42 | inputs = [data[j][0] for j in range(10)] 43 | labels = data[0][1] 44 | for j in range(10): 45 | inputs[j] = inputs[j].cuda() 46 | labels = labels 47 | outputs = [] 48 | for j in range(10): 49 | _, predict_out = model(inputs[j]) 50 | outputs.append(nn.Softmax(dim=1)(predict_out)) 51 | outputs = sum(outputs) 52 | if start_test: 53 | all_output = outputs.float().cpu() 54 | all_label = labels.float() 55 | start_test = False 56 | else: 57 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 58 | all_label = torch.cat((all_label, labels.float()), 0) 59 | else: 60 | iter_test = iter(loader["test"]) 61 | for i in range(len(loader['test'])): 62 | data = iter_test.next() 63 | inputs = data[0] 64 | labels = data[1] 65 | inputs = inputs.cuda() 66 | labels = labels.cuda() 67 | _, outputs = model(inputs) 68 | if start_test: 69 | all_output = outputs.float().cpu() 70 | all_label = labels.float() 71 | start_test = False 72 | else: 73 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 74 | all_label = torch.cat((all_label, labels.float()), 0) 75 | _, predict = torch.max(all_output, 1) 76 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 77 | return accuracy 78 | 79 | 80 | def train(config): 81 | ## set pre-process 82 | prep_dict = {} 83 | prep_config = config["prep"] 84 | prep_dict["source"] = prep.image_train(**config["prep"]['params']) 85 | prep_dict["target"] = prep.image_train(**config["prep"]['params']) 86 | if 
prep_config["test_10crop"]: 87 | prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params']) 88 | else: 89 | prep_dict["test"] = prep.image_test(**config["prep"]['params']) 90 | 91 | ## prepare data 92 | dsets = {} 93 | dset_loaders = {} 94 | data_config = config["data"] 95 | train_bs = data_config["source"]["batch_size"] 96 | test_bs = data_config["test"]["batch_size"] 97 | 98 | dsets["target"] = MiniImageNet(root=data_config["target"]["root"], dataset=config["dataset"], mode=data_config["target"]["split"]) 99 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=config["shot"] * config["train_way"], \ 100 | shuffle=True, num_workers=4, drop_last=True) 101 | dsets["source"] = MiniImageNet(root=data_config["source"]["root"], dataset=config["dataset"], mode=data_config["source"]["split"]) 102 | fsl_train_sampler = CategoriesSampler(dsets["source"].label, 100, 103 | config["train_way"], config["shot"] + config["query"]) 104 | dset_loaders["source"] = DataLoader(dataset=dsets["source"], batch_sampler=fsl_train_sampler, 105 | num_workers=4, pin_memory=True) 106 | fsl_valset = MiniImageNet(root=data_config["fsl_test"]["root"], dataset=config["dataset"], mode=data_config["fsl_test"]["split"]) 107 | fsl_val_sampler = CategoriesSampler(fsl_valset.label, 400, 108 | config["test_way"], config["shot"] + config["query"]) 109 | fsl_val_loader = DataLoader(dataset=fsl_valset, batch_sampler=fsl_val_sampler, 110 | num_workers=4, pin_memory=True) 111 | 112 | class_num = config["network"]["params"]["class_num"] 113 | 114 | ## set base network 115 | net_config = config["network"] 116 | base_network = net_config["name"](**net_config["params"]) 117 | base_network = base_network.cuda() 118 | 119 | ## add additional network for some methods 120 | if config["loss"]["random"]: 121 | random_layer = network.RandomLayer([base_network.output_num(), class_num], config["loss"]["random_dim"]) 122 | ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024) 123 | else: 124 | random_layer = None 125 | ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024) 126 | if config["loss"]["random"]: 127 | random_layer.cuda() 128 | ad_net = ad_net.cuda() 129 | autoweight = learnedweight().cuda() 130 | #print(base_network.get_parameters()) 131 | #print([{'params': autoweight.parameters(), 'lr_mult': 1, 'decay_mult': 2}]) 132 | parameter_list = base_network.get_parameters() + ad_net.get_parameters() + [{'params': autoweight.parameters(), 'lr_mult': 1, 'decay_mult': 2}] 133 | 134 | ## set optimizer 135 | optimizer_config = config["optimizer"] 136 | optimizer = optimizer_config["type"](parameter_list, \ 137 | **(optimizer_config["optim_params"])) 138 | param_lr = [] 139 | for param_group in optimizer.param_groups: 140 | param_lr.append(param_group["lr"]) 141 | schedule_param = optimizer_config["lr_param"] 142 | lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]] 143 | 144 | gpus = config['gpu'].split(',') 145 | if len(gpus) > 1: 146 | ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus]) 147 | base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus]) 148 | autoweight = nn.DataParallel(autoweight, device_ids=[int(i) for i in gpus]) 149 | 150 | 151 | ## train 152 | len_train_source = len(dset_loaders["source"]) 153 | len_train_target = len(dset_loaders["target"]) 154 | transfer_loss_value = classifier_loss_value = total_loss_value = 0.0 155 | best_acc = 0.0 156 | start = 0 157 | for i in range(config["num_iterations"]): 158 | 
if i % config["test_interval"] == config["test_interval"] - 1: 159 | base_network.train(False) 160 | # temp_acc = image_classification_test(dset_loaders, \ 161 | # base_network, test_10crop=prep_config["test_10crop"]) 162 | # temp_model = nn.Sequential(base_network) 163 | # if temp_acc > best_acc: 164 | # best_acc = temp_acc 165 | # best_model = temp_model 166 | # log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc) 167 | # config["out_file"].write(log_str+"\n") 168 | # config["out_file"].flush() 169 | # print(log_str) 170 | for i, batch in enumerate(fsl_val_loader, 1): 171 | data, _ = [_.cuda() for _ in batch] 172 | k = config["test_way"] * config["shot"] 173 | data_shot, data_query = data[:k], data[k:] 174 | 175 | x, _ = base_network(data_shot) 176 | x = x.reshape(config["shot"], config["test_way"], -1).mean(dim=0) 177 | p = x 178 | proto_query, _ = base_network(data_query) 179 | proto_query = proto_query.reshape(config["shot"], config["train_way"], -1).mean(dim=0) 180 | logits = euclidean_metric(proto_query, p) 181 | 182 | label = torch.arange(config["test_way"]).repeat(config["query"]) 183 | label = label.type(torch.cuda.LongTensor) 184 | 185 | acc = count_acc(logits, label) 186 | ave_acc.add(acc) 187 | print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100)) 188 | 189 | x = None; p = None; logits = None 190 | if i % config["snapshot_interval"] == 0: 191 | torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \ 192 | "iter_{:05d}_model.pth.tar".format(i))) 193 | 194 | loss_params = config["loss"] 195 | ## train one iter 196 | base_network.train(True) 197 | ad_net.train(True) 198 | optimizer = lr_scheduler(optimizer, i, **schedule_param) 199 | optimizer.zero_grad() 200 | 201 | if i % len_train_source == 0: 202 | iter_fsl_train = iter(dset_loaders["source"]) 203 | if i % len_train_target == 0: 204 | iter_target = iter(dset_loaders["target"]) 205 | 206 | inputs_target, labels_target = iter_target.next() 207 | inputs_fsl, labels_fsl = iter_fsl_train.next() 208 | inputs_target = inputs_target.cuda() 209 | inputs_fsl, labels_fsl = inputs_fsl.cuda(), labels_fsl.cuda() 210 | 211 | p = config["shot"] * config["train_way"] 212 | data_shot, data_query = inputs_fsl[:p], inputs_fsl[p:] 213 | labels_source = labels_fsl[:p] 214 | 215 | proto_source, outputs_source = base_network(data_shot) 216 | features_target, outputs_target = base_network(inputs_target) 217 | proto = proto_source.reshape(config["shot"], config["train_way"], -1).mean(dim=0) 218 | 219 | label = torch.arange(config["train_way"]).repeat(config["query"]) 220 | label = label.type(torch.cuda.LongTensor) 221 | query_proto, _ = base_network(data_query) 222 | logits = euclidean_metric(query_proto, proto) 223 | # fsl_loss = F.cross_entropy(logits, label) 224 | fsl_loss = nn.CrossEntropyLoss()(logits, label) 225 | fsl_acc = count_acc(logits, label) 226 | 227 | features = torch.cat((proto_source, features_target), dim=0) 228 | outputs = torch.cat((outputs_source, outputs_target), dim=0) 229 | softmax_out = nn.Softmax(dim=1)(outputs) 230 | if config['method'] == 'CDAN+E': 231 | entropy = loss.Entropy(softmax_out) 232 | transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer) 233 | elif config['method'] == 'CDAN': 234 | transfer_loss = loss.CDAN([features, softmax_out], ad_net, None, None, random_layer) 235 | elif config['method'] == 'DANN': 236 | transfer_loss = loss.DANN(features, ad_net) 237 | else: 238 | raise ValueError('Method cannot be 
recognized.') 239 | # classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source) 240 | 241 | 242 | 243 | if i % 1 == 0: 244 | print('iter: ', i, 'transfer_loss: ', transfer_loss.data, 'fsl_loss: ', fsl_loss.data, 'fsl_acc: ', fsl_acc) 245 | # total_loss = loss_params["trade_off"] * transfer_loss + 0.2 * fsl_loss 246 | total_loss = autoweight(fsl_loss, transfer_loss)/10 247 | print(total_loss) 248 | total_loss.backward() 249 | optimizer.step() 250 | torch.save(best_model, osp.join(config["output_path"], "best_model.pth.tar")) 251 | return best_acc 252 | 253 | 254 | if __name__ == "__main__": 255 | parser = argparse.ArgumentParser(description='Conditional Domain Adversarial Network') 256 | parser.add_argument('--method', type=str, default='CDAN+E', choices=['CDAN', 'CDAN+E', 'DANN']) 257 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 258 | parser.add_argument('--net', type=str, default='ResNet50', choices=["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152", "VGG11", "VGG13", "VGG16", "VGG19", "VGG11BN", "VGG13BN", "VGG16BN", "VGG19BN", "AlexNet"]) 259 | parser.add_argument('--dset', type=str, default='office', choices=['office', 'image-clef', 'visda', 'office-home', 'mini-imagenet', 'tiered-imagenet'], help="The dataset or source dataset used") 260 | parser.add_argument('--s_dset_path', type=str, default='dataset/mini-imagenet/train', help="The dataset path") 261 | parser.add_argument('--fsl_test_path', type=str, default='dataset/mini-imagenet/test_new_domain', help="The dataset path") 262 | parser.add_argument('--test_interval', type=int, default=10000, help="interval of two continuous test phase") 263 | parser.add_argument('--snapshot_interval', type=int, default=500, help="interval of two continuous output model") 264 | parser.add_argument('--output_dir', type=str, default='mini_auto_weight10', help="output directory of our model (in ../snapshot directory)") 265 | parser.add_argument('--lr', type=float, default=0.0005, help="learning rate") 266 | parser.add_argument('--random', type=bool, default=False, help="whether use random projection") 267 | parser.add_argument('--shot', type=int, default=1) 268 | parser.add_argument('--query', type=int, default=15) 269 | parser.add_argument('--train-way', type=int, default=30) 270 | parser.add_argument('--test-way', type=int, default=5) 271 | parser.add_argument('--pretrained', type=str, default='tiered_checkpoint.pth.tar') 272 | args = parser.parse_args() 273 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 274 | #os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3' 275 | 276 | # train config 277 | config = {} 278 | config['method'] = args.method 279 | config["gpu"] = args.gpu_id 280 | config["num_iterations"] = 100004 281 | config["test_interval"] = args.test_interval 282 | config["snapshot_interval"] = args.snapshot_interval 283 | config["output_for_test"] = True 284 | config["output_path"] = "snapshot/" + args.output_dir 285 | if not osp.exists(config["output_path"]): 286 | os.system('mkdir -p '+config["output_path"]) 287 | config["out_file"] = open(osp.join(config["output_path"], "log.txt"), "w") 288 | if not osp.exists(config["output_path"]): 289 | os.mkdir(config["output_path"]) 290 | 291 | config["prep"] = {"test_10crop":True, 'params':{"resize_size":256, "crop_size":224, 'alexnet':False}} 292 | config["loss"] = {"trade_off":1.0} 293 | if "AlexNet" in args.net: 294 | config["prep"]['params']['alexnet'] = True 295 | config["prep"]['params']['crop_size'] = 227 296 | 
config["network"] = {"name":network.AlexNetFc, \ 297 | "params":{"use_bottleneck":True, "bottleneck_dim":256, "new_cls":True} } 298 | elif "ResNet" in args.net: 299 | config["network"] = {"name":network.ResNetFc, \ 300 | "params":{"resnet_name":args.net, "use_bottleneck":True, "bottleneck_dim":256, "new_cls":True, "pretrained_model":args.pretrained} } 301 | elif "VGG" in args.net: 302 | config["network"] = {"name":network.VGGFc, \ 303 | "params":{"vgg_name":args.net, "use_bottleneck":True, "bottleneck_dim":256, "new_cls":True} } 304 | config["loss"]["random"] = args.random 305 | config["loss"]["random_dim"] = 1024 306 | 307 | config["optimizer"] = {"type":optim.SGD, "optim_params":{'lr':args.lr, "momentum":0.9, \ 308 | "weight_decay":0.0005, "nesterov":True}, "lr_type":"inv", \ 309 | "lr_param":{"lr":args.lr, "gamma":0.001, "power":0.75} } 310 | 311 | config["dataset"] = args.dset 312 | config["data"] = {"source":{"root":args.s_dset_path, "split":"train", "batch_size":50}, \ 313 | "target":{"root":args.s_dset_path, "split":"val_new_domain", "batch_size":8}, \ 314 | "test":{"root":args.s_dset_path, "split":"val_new_domain", "batch_size":4}, \ 315 | "fsl_test":{"root":args.fsl_test_path, "split":"val_new_domain_fsl", "batch_size":4}} 316 | 317 | if config["dataset"] == "office": 318 | if ("amazon" in args.s_dset_path and "webcam" in args.t_dset_path) or \ 319 | ("webcam" in args.s_dset_path and "dslr" in args.t_dset_path) or \ 320 | ("webcam" in args.s_dset_path and "amazon" in args.t_dset_path) or \ 321 | ("dslr" in args.s_dset_path and "amazon" in args.t_dset_path): 322 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 323 | elif ("amazon" in args.s_dset_path and "dslr" in args.t_dset_path) or \ 324 | ("dslr" in args.s_dset_path and "webcam" in args.t_dset_path): 325 | config["optimizer"]["lr_param"]["lr"] = 0.0003 # optimal parameters 326 | config["network"]["params"]["class_num"] = 31 327 | elif config["dataset"] == "image-clef": 328 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 329 | config["network"]["params"]["class_num"] = 12 330 | elif config["dataset"] == "visda": 331 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 332 | config["network"]["params"]["class_num"] = 12 333 | config['loss']["trade_off"] = 1.0 334 | elif config["dataset"] == "office-home": 335 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 336 | config["network"]["params"]["class_num"] = 65 337 | elif config["dataset"] == "mini-imagenet": 338 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 339 | config["network"]["params"]["class_num"] = 64 340 | config['loss']["trade_off"] = 1.0 341 | elif config["dataset"] == "tiered-imagenet": 342 | config["optimizer"]["lr_param"]["lr"] = 0.001 # optimal parameters 343 | config["network"]["params"]["class_num"] = 351 344 | config['loss']["trade_off"] = 1.0 345 | else: 346 | raise ValueError('Dataset cannot be recognized. Please define your own dataset here.') 347 | config["out_file"].write(str(config)) 348 | config["out_file"].flush() 349 | 350 | config["shot"] = args.shot 351 | config["query"] = args.query 352 | config["train_way"] = args.train_way 353 | config["test_way"] = args.test_way 354 | train(config) 355 | --------------------------------------------------------------------------------