├── .DS_Store
├── LICENSE
├── README.md
├── classification.py
├── datasets
│   ├── .DS_Store
│   ├── __init__.py
│   ├── oneShotBaseCls.py
│   └── softRandom.py
├── datasplit
│   ├── test.csv
│   ├── train.csv
│   └── val.csv
├── logger.py
├── models
│   ├── .DS_Store
│   ├── metaNet_1shot.t7
│   ├── metaNet_5shot.t7
│   └── softRandom.t7
├── onlyBasetwoLoss.py
├── option.py
└── picture
    ├── .DS_Store
    ├── FuseNet.png
    ├── approach.png
    ├── deformed_images.png
    └── meta_learning.png

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Zitian Chen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Deformation Meta-Networks for One-Shot Learning
2 | 
3 | A PyTorch implementation of "Image Deformation Meta-Networks for One-Shot Learning" (CVPR 2019 Oral).
4 | 
5 | > [**Image Deformation Meta-Networks for One-Shot Learning**](),
6 | > Zitian Chen, Yanwei Fu, Yu-Xiong Wang, Lin Ma, Wei Liu, Martial Hebert
7 | 
8 | ![](picture/meta_learning.png)
9 | 
10 | 
11 | 
12 | ![](picture/deformed_images.png)
13 | 
14 | 
15 | 
16 | ![](picture/approach.png)
17 | 
18 | 
19 | 
20 | ## Installation
21 | 
22 | ```
23 | python=2.7
24 | pytorch=0.4.0
25 | ```
26 | 
27 | ## Datasets
28 | 
29 | The data split is from [**Semantic Feature Augmentation in Few-shot Learning**]().
30 | 
31 | ```
32 | Please put the data in:
33 | /home/yourusername/data/miniImagenet
34 | 
35 | The images should be placed in
36 | .../miniImagenet/images
37 | such as: miniImagenet/images/n0153282900000006.jpg
38 | We provide the data split in ./datasplit/; please put the files at
39 | .../miniImagenet/train.csv
40 | .../miniImagenet/test.csv
41 | .../miniImagenet/val.csv
42 | ```
43 | 
44 | 
45 | 
46 | ## Train & Test
47 | 
48 | Note that we train the model on **4 Titan X** GPUs; about 42,000 MB of GPU memory is required in total, otherwise you may run into CUDA out-of-memory errors.
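The commands below walk through the 5-shot pipeline. For the 1-shot setting (the repository ships a `metaNet_1shot.t7` checkpoint), a plausible counterpart — assumed by analogy with the 5-shot command below, since the README does not spell it out — would be:

```
# Hypothetical 1-shot training command; --augnum and --chooseNum values are assumptions.
CUDA_VISIBLE_DEVICES=0,1,2,3 python onlyBasetwoLoss.py --network softRandom --shots 1 --augnum 5 --fixCls 1 --tensorname metaNet_1shot --chooseNum 30
```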
49 | 
50 | ```
51 | # First, we fix the deformation sub-network and train the embedding sub-network with randomly deformed images
52 | 
53 | # We provide softRandom.t7 as the embedding sub-network
54 | # If you want to train your own, run: python classification.py --tensorname yournetworkname
55 | 
56 | 
57 | # Then, we fix the embedding sub-network and learn the deformation sub-network
58 | 
59 | CUDA_VISIBLE_DEVICES=0,1,2,3 python onlyBasetwoLoss.py --network softRandom --shots 5 --augnum 5 --fixCls 1 --tensorname metaNet_5shot --chooseNum 30
60 | 
61 | # If you want to improve the model further, fix one sub-network and train the other, iterating between the two.
62 | 
63 | # update the classifier (embedding sub-network)
64 | CUDA_VISIBLE_DEVICES=0,1,2,3 python onlyBasetwoLoss.py --network softRandom --shots 5 --augnum 5 --fixCls 0 --fixAttention 1 --tensorname metaNet_5shot_round2 --chooseNum 30 --GNet metaNet_5shot
65 | 
66 | # We also provide our trained models metaNet_1shot.t7 and metaNet_5shot.t7 in ./models
67 | 
68 | # You can use --GNet metaNet_1shot to load a model.
69 | 
70 | 
71 | ```
72 | 
73 | 
74 | 
75 | 
76 | ## License
77 | 
78 | IDeMe-Net is released under the MIT License (refer to the LICENSE file for details).
79 | 
80 | 
81 | ## Citation
82 | 
83 | If you find this project useful for your research, please use the following BibTeX entry.
84 | ```
85 | @inproceedings{chen2019image,
86 |   title={Image deformation meta-networks for one-shot learning},
87 |   author={Chen, Zitian and Fu, Yanwei and Wang, Yu-Xiong and Ma, Lin and Liu, Wei and Hebert, Martial},
88 |   booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
89 |   pages={8680--8689},
90 |   year={2019}
91 | }
92 | ```
93 | 
94 | 
--------------------------------------------------------------------------------
/classification.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | import torch
5 | import torch.optim as optim
6 | from tqdm import *
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 | import torch.utils.data
11 | import torchvision.transforms as transforms
12 | import torchvision.datasets as datasets
13 | import torchvision.models as models
14 | import torchvision
15 | import matplotlib.pyplot as plt
16 | from option import Options
17 | from datasets import softRandom
18 | from torch.optim import lr_scheduler
19 | import copy
20 | import time
21 | rootdir = os.getcwd()
22 | 
23 | args = Options().parse()
24 | 
25 | 
26 | image_datasets = {x: softRandom.miniImagenetEmbeddingDataset(type=x)
27 |                   for x in ['train', 'val','test']}
28 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=args.batchSize,
29 |                                               shuffle=True, num_workers=args.nthreads)
30 |                for x in ['train', 'val','test']}
31 | dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val','test']}
32 | 
33 | 
34 | ######################################################################
35 | # Define the Embedding Network
36 | 
37 | class ClassificationNetwork(nn.Module):
38 |     def __init__(self):
39 |         super(ClassificationNetwork, self).__init__()
40 |         self.convnet = torchvision.models.resnet18(pretrained=False)
41 |         num_ftrs = self.convnet.fc.in_features
42 |         self.convnet.fc = nn.Linear(num_ftrs,64)
43 | 
44 |     def forward(self,inputs):
45 |         outputs = self.convnet(inputs)
46 | 
47 |         return outputs
48 | 
49 | classificationNetwork = ClassificationNetwork()
50 | if args.network!='None':
51 |     classificationNetwork.load_state_dict(torch.load('models/'+str(args.network)+'.t7', map_location=lambda storage, loc: storage))
52 |     print('loading ',str(args.network))
53 | classificationNetwork = classificationNetwork.cuda()
54 | 
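# --- Editor's sketch (not part of the original script): a quick shape check of the
# embedding sub-network. miniImageNet has 64 base classes, so a (B, 3, 224, 224)
# input batch should map to (B, 64) logits.
if False:  # flip to True to run the sanity check
    _dummy = torch.randn(2, 3, 224, 224).cuda()
    print(classificationNetwork(_dummy).size())  # expected: (2, 64)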
55 | my_list = ['convnet.fc.weight', 'convnet.fc.bias']
56 | params = [kv[1] for kv in classificationNetwork.named_parameters() if kv[0] in my_list]  # keep the parameter tensors only; (name, param) tuples would crash the optimizer
57 | base_params = [kv[1] for kv in classificationNetwork.named_parameters() if kv[0] not in my_list]
58 | 
59 | # print(params,base_params)
60 | 
61 | #############################################
62 | # Define the optimizer
63 | 
64 | criterion = nn.CrossEntropyLoss()
65 | 
66 | if args.network=='None':
67 |     optimizer_embedding = optim.Adam([
68 |         {'params': classificationNetwork.parameters()},
69 |     ], lr=0.001)
70 | else:
71 |     optimizer_embedding = optim.Adam([
72 |         {'params': params,'lr': args.LR*0.1},
73 |         {'params': base_params, 'lr': args.LR}
74 |     ])
75 | 
76 | embedding_lr_scheduler = lr_scheduler.StepLR(optimizer_embedding, step_size=10, gamma=0.5)
77 | 
78 | 
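# --- Editor's note (not part of the original script): StepLR halves the learning
# rate every 10 epochs, i.e. lr(epoch) = base_lr * 0.5 ** (epoch // 10).
if False:
    for e in [0, 10, 20, 30]:
        print(e, 0.001 * 0.5 ** (e // 10))  # 0.001, 0.0005, 0.00025, 0.000125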
79 | ######################################################################
80 | # Train and evaluate
81 | # ^^^^^^^^^^^^^^^^^^
82 | 
83 | 
84 | def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
85 |     since = time.time()
86 | 
87 |     best_model_wts = copy.deepcopy(model.state_dict())
88 |     best_loss = 1000000000.0
89 | 
90 |     for epoch in range(num_epochs):
91 |         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
92 |         print('-' * 10)
93 | 
94 |         # Only the training phase is used in this script
95 |         for phase in ['train']:
96 | 
97 |             if phase == 'train':
98 |                 scheduler.step()
99 |                 model.train(True)  # Set model to training mode
100 |             else:
101 |                 model.train(False)  # Set model to evaluate mode
102 | 
103 | 
104 |             running_loss = 0.0
105 |             tot_dist = 0.0
106 |             running_corrects = 0
107 |             loss = 0
108 | 
109 |             # Iterate over data.
110 |             for i,(inputs,labels) in tqdm(enumerate(dataloaders[phase])):
111 | 
112 |                 #c = labels
113 |                 # wrap them in Variable
114 |                 inputs = Variable(inputs.cuda())
115 |                 labels = Variable(labels.cuda())
116 | 
117 |                 # zero the parameter gradients
118 |                 optimizer.zero_grad()
119 | 
120 |                 # forward
121 |                 outputs = model(inputs)
122 | 
123 |                 _, preds = torch.max(outputs, 1)
124 | 
125 |                 labels = labels.view(labels.size(0))
126 | 
127 |                 loss = criterion(outputs, labels)
128 | 
129 | 
130 |                 # backward + optimize only if in training phase
131 |                 if phase == 'train':
132 |                     loss.backward()
133 |                     optimizer.step()
134 | 
135 |                 # statistics
136 |                 running_loss += loss.item() * inputs.size(0)
137 |                 running_corrects += torch.sum(preds == labels.view(-1)).item()
138 |                 #print(running_corrects)
139 | 
140 | 
141 |             epoch_loss = running_loss / (dataset_sizes[phase]*1.0)
142 |             epoch_acc = running_corrects / (dataset_sizes[phase]*1.0)
143 |             info = {
144 |                 phase+'loss': running_loss,
145 |                 phase+'Accuracy': epoch_acc,
146 |             }
147 | 
148 |             print('{} Loss: {:.4f} Accuracy: {:.4f} '.format(
149 |                 phase, epoch_loss,epoch_acc))
150 | 
151 |             # deep copy the model
152 |             if phase == 'train' and epoch_loss < best_loss:
153 |                 best_loss = epoch_loss
154 |                 best_model_wts = copy.deepcopy(model.state_dict())
155 | 
156 | 
157 |         print()
158 |         # if epoch>=30 and epoch %3 ==0:
159 |         #     torch.save(best_model_wts,os.path.join(rootdir,'models/'+str(args.tensorname)+ str(epoch) + '.t7'))
160 |         #     print('save!')
161 |         if epoch % 10 ==0:
162 |             torch.save(best_model_wts,os.path.join(rootdir,'models/'+str(args.tensorname)+ '.t7'))
163 |             print('save!')
164 | 
165 | 
166 |     time_elapsed = time.time() - since
167 |     print('Training complete in {:.0f}m {:.0f}s'.format(
168 |         time_elapsed // 60, time_elapsed % 60))
169 |     print('Best training Loss: {:.4f}'.format(best_loss))
170 | 
171 | 
172 |     # load best model weights
173 |     model.load_state_dict(best_model_wts)
174 |     return model
175 | 
176 | 
177 | classificationNetwork = train_model(classificationNetwork, criterion, optimizer_embedding,
178 |                                     embedding_lr_scheduler, num_epochs=35)
179 | 
180 | 
181 | torch.save(classificationNetwork.state_dict(),os.path.join(rootdir,'models/'+str(args.tensorname)+'.t7'))
182 | 
183 | 
184 | 
--------------------------------------------------------------------------------
/datasets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/datasets/.DS_Store
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/datasets/__init__.py
--------------------------------------------------------------------------------
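The next file, datasets/oneShotBaseCls.py, builds N-way K-shot episodes from a class-to-filenames split loaded from the CSVs. As a minimal sketch of that episode-construction idea (hypothetical names, independent of the actual class below):

```python
import numpy as np

def sample_episode(dict_labels, ways=5, shots=1, test_num=15):
    """Pick `ways` classes, then `shots` support and `test_num` query files per class."""
    classes = np.random.choice(sorted(dict_labels.keys()), ways, replace=False)
    support, query = [], []
    for label, c in enumerate(classes):
        files = np.random.choice(dict_labels[c], shots + test_num, replace=False)
        support += [(f, label) for f in files[:shots]]
        query += [(f, label) for f in files[shots:]]
    return support, query
```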
/datasets/oneShotBaseCls.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import torchvision.transforms as transforms
4 | from torch.autograd import Variable
5 | from PIL import Image
6 | import os.path
7 | import csv
8 | import math
9 | import collections
10 | from tqdm import tqdm
11 | import datetime
12 | 
13 | import numpy as np
14 | import numpy
15 | #from watch import NlabelTovector
16 | import getpass
17 | userName = getpass.getuser()
18 | 
19 | pathminiImageNet = '/home/'+userName+'/data/miniImagenet/'
20 | pathImages = os.path.join(pathminiImageNet,'images/')
21 | # LAMBDA FUNCTIONS
22 | filenameToPILImage = lambda x: Image.open(x)
23 | 
24 | np.random.seed(2191)
25 | 
26 | patch_xl = [0,0,0,74,74,74,148,148,148]
27 | patch_xr = [74,74,74,148,148,148,224,224,224]
28 | patch_yl = [0,74,148,0,74,148,0,74,148]
29 | patch_yr = [74,148,224,74,148,224,74,148,224]
30 | 
31 | class miniImagenetOneshotDataset(data.Dataset):
32 |     def __init__(self, dataroot = '/home/'+userName+'/data/miniImagenet', type = 'train',ways=5,shots=1,test_num=1,epoch=100,galleryNum = 10):
33 |         # oneShot setting
34 |         self.ways = ways
35 |         self.shots = shots
36 |         self.test_num = test_num # indicate test number of each class
37 |         self.__size = epoch
38 | 
39 |         self.transform = transforms.Compose([filenameToPILImage,
40 |                                              transforms.Resize(256),
41 |                                              transforms.CenterCrop(224),
42 |                                              transforms.ToTensor(),
43 |                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
44 |                                              ])
45 | 
46 |         self.galleryTransform = transforms.Compose([filenameToPILImage,
47 |                                              transforms.RandomHorizontalFlip(p=0.5),
48 |                                              transforms.Resize(256),
49 |                                              transforms.CenterCrop(224),
50 |                                              transforms.ToTensor(),
51 |                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
52 |                                              ])
53 | 
54 |         def loadSplit(splitFile):
55 |             dictLabels = {}
56 |             with open(splitFile) as csvfile:
57 |                 csvreader = csv.reader(csvfile, delimiter=',')
58 |                 next(csvreader, None)
59 |                 for i,row in enumerate(csvreader):
60 |                     filename = row[0]
61 |                     label = row[1]
62 | 
63 |                     if label in dictLabels.keys():
64 |                         dictLabels[label].append(filename)
65 |                     else:
66 |                         dictLabels[label] = [filename]
67 |             return dictLabels
68 | 
69 |         self.miniImagenetImagesDir = os.path.join(dataroot,'images')
70 | 
71 |         self.unData = loadSplit(splitFile = os.path.join(dataroot,'train' + '.csv'))
72 |         self.data = loadSplit(splitFile = os.path.join(dataroot,type + '.csv'))
73 | 
74 |         self.type = type
75 |         self.data = collections.OrderedDict(sorted(self.data.items()))
76 |         self.unData = collections.OrderedDict(sorted(self.unData.items()))
77 |         self.galleryNum = galleryNum
78 | 
79 |         # sample Gallery
80 |         self.Gallery = []
81 |         numpy.random.seed(2019)
82 |         for classes in range(len(self.unData.keys())):
83 |             Files = np.random.choice(self.unData[self.unData.keys()[classes]], self.galleryNum, False)
84 |             for file in Files:
85 |                 self.Gallery.append(file)
86 | 
87 |         numpy.random.seed()
88 | 
89 |         self.keyTobh = {}
90 |         for c in range(len(self.data.keys())):
91 |             self.keyTobh[self.data.keys()[c]] = c
92 | 
93 |         for c in range(len(self.unData.keys())):
94 |             self.keyTobh[self.unData.keys()[c]] = c
95 | 
96 |         #print(self.keyTobh)
97 |     def batchModel(model,AInputs,requireGrad):
98 |         Batch = (AInputs.size(0)+args.batchSize-1)//args.batchSize
99 |         First = True
100 |         Cfeatures = 1
101 | 
102 | 
103 |         for b in range(Batch):
104 |             if b<Batch-1:
--------------------------------------------------------------------------------
/datasets/softRandom.py:
--------------------------------------------------------------------------------
152 | #         t[t>1] = 1
153 | #         t[t<0] = 0
154 | #         return t
155 | 
156 | # detransform = transforms.Compose([
157 | #     Denormalize(mu, sigma),
158 | #     Clip(),
159 | #     transforms.ToPILImage(),
160 | # ])
161 | 
162 | 
163 | # def plotPicture(image,name):
164 | #     fig = plt.figure()
165 | #     ax = fig.add_subplot(111)
166 | #     A = image.clone()
167 | #     ax.imshow(detransform(A))
168 | #     fig.savefig('picture/'+str(name)+'.png')
169 | #     print('picture/'+str(name)+'.png')
170 | #     plt.close(fig)
171 | 
172 | # if __name__ == '__main__':
173 | #     dataTrain = miniImagenetEmbeddingDataset(type='train')
174 | #     print(len(dataTrain))
175 | 
176 | #     C,_ = dataTrain.__getitem__(2)
177 | #     print('Size: ',C.size())
178 | #     plotPicture(C,'origin')
179 | #     C = torch.flip(C,[2])
180 | #     plotPicture(C,'flip')
181 | 
182 | 
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
2 | import tensorflow as tf
3 | import numpy as np
4 | import scipy.misc
5 | try:
6 |     from StringIO import StringIO  # Python 2.7
7 | except ImportError:
8 |     from io import BytesIO  # Python 3.x
9 | 
10 | 
11 | class Logger(object):
12 | 
13 |     def __init__(self, log_dir):
14 |         """Create a summary writer logging to log_dir."""
15 |         self.writer = tf.summary.FileWriter(log_dir)
16 | 
17 |     def scalar_summary(self, tag, value, step):
18 |         """Log a scalar variable."""
19 |         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
20 |         self.writer.add_summary(summary, step)
21 | 
22 |     def image_summary(self, tag, images, step):
23 |         """Log a list of images."""
24 | 
25 |         img_summaries = []
26 |         for i, img in enumerate(images):
27 |             # Write the image to a string
28 |             try:
29 |                 s = StringIO()
30 |             except:
31 |                 s = BytesIO()
32 |             scipy.misc.toimage(img).save(s, format="png")
33 | 
34 |             # Create an Image object
35 |             img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
36 |                                        height=img.shape[0],
37 |                                        width=img.shape[1])
38 |             # Create a Summary value
39 |             img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
40 | 
41 |         # Create and write Summary
42 |         summary = tf.Summary(value=img_summaries)
43 |         self.writer.add_summary(summary, step)
44 | 
45 |     def histo_summary(self, tag, values, step, bins=1000):
46 |         """Log a histogram of the tensor of values."""
47 | 
48 |         # Create a histogram using numpy
49 |         counts, bin_edges = np.histogram(values, bins=bins)
50 | 
51 |         # Fill the fields of the histogram proto
52 |         hist = tf.HistogramProto()
53 |         hist.min = float(np.min(values))
54 |         hist.max = float(np.max(values))
55 |         hist.num = int(np.prod(values.shape))
56 |         hist.sum = float(np.sum(values))
57 |         hist.sum_squares = float(np.sum(values**2))
58 | 
59 |         # Drop the start of the first bin
60 |         bin_edges = bin_edges[1:]
61 | 
62 |         # Add bin edges and counts
63 |         for edge in bin_edges:
64 |             hist.bucket_limit.append(edge)
65 |         for c in counts:
66 |             hist.bucket.append(c)
67 | 
68 |         # Create and write Summary
69 |         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
70 |         self.writer.add_summary(summary, step)
71 |         self.writer.flush()
--------------------------------------------------------------------------------
/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/models/.DS_Store
--------------------------------------------------------------------------------
/models/metaNet_1shot.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/models/metaNet_1shot.t7
--------------------------------------------------------------------------------
/models/metaNet_5shot.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/models/metaNet_5shot.t7
--------------------------------------------------------------------------------
/models/softRandom.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/models/softRandom.t7
--------------------------------------------------------------------------------
/onlyBasetwoLoss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | import torch
5 | import torch.optim as optim
6 | from tqdm import *
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 | import torch.utils.data
11 | import torchvision.transforms as transforms
12 | import torchvision.datasets as datasets
13 | import torchvision.models as models
14 | import torchvision
15 | # import matplotlib.pyplot as plt
16 | from option import Options
17 | from datasets import oneShotBaseCls
18 | # from datasets import oneShotUnsuperviseCls  # this module is not included in the repository; importing it would fail
19 | 
20 | from torch.optim import lr_scheduler
21 | import copy
22 | import time
23 | rootdir = os.getcwd()
24 | 
25 | args = Options().parse()
26 | 
27 | from logger import Logger
28 | logger = Logger('./logs/'+args.tensorname)
29 | 
30 | image_datasets = {}
31 | 
32 | print('sample from base!')
33 | image_datasets = {x: oneShotBaseCls.miniImagenetOneshotDataset(type=x,ways= (args.trainways if x=='train' else args.ways),shots=args.shots,test_num=args.test_num,epoch=args.epoch,galleryNum=args.galleryNum)
34 |                   for x in ['train', 'val','test']}
35 | 
36 | def worker_init_fn(worker_id):
37 |     np.random.seed(np.random.get_state()[1][0] + worker_id)
38 | 
39 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=1,
40 |                                               shuffle=(x=='train'), num_workers=args.nthreads,worker_init_fn=worker_init_fn)
41 |                for x in ['train', 'val','test']}
42 | dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val','test']}
43 | 
44 | ######################################################################
45 | # Weight matrix pre-process
46 | 
47 | patch_xl = []
48 | patch_xr = []
49 | patch_yl = []
50 | patch_yr = []
51 | 
52 | if args.Fang == 3:
53 |     point = [0,74,148,224]
54 | elif args.Fang == 5:
55 |     point = [0,44,88,132,176,224]
56 | elif args.Fang == 7:
57 |     point = [0,32,64,96,128,160,192,224]
58 | 
59 | 
60 | 
61 | for i in range(args.Fang):
62 |     for j in range(args.Fang):
63 |         patch_xl.append(point[i])
64 |         patch_xr.append(point[i+1])
65 |         patch_yl.append(point[j])
66 |         patch_yr.append(point[j+1])
67 | 
68 | fixSquare = torch.zeros(1,args.Fang*args.Fang,3,224,224).float()
69 | for i in range(args.Fang*args.Fang):
70 |     fixSquare[:,i,:,patch_xl[i]:patch_xr[i],patch_yl[i]:patch_yr[i]] = 1.00
71 | fixSquare = fixSquare.cuda()
72 | 
73 | oneSquare = torch.ones(1,3,224,224).float()
74 | oneSquare = oneSquare.cuda()
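# --- Editor's sketch (not part of the original script): how the patch masks are
# used later in GNet.forward. Each of the Fang*Fang masks in fixSquare covers one
# patch; a per-patch weight w_i blends probe image A with gallery image B:
#     C = sum_i(w_i * mask_i) * A + (1 - sum_i(w_i * mask_i)) * B
if False:
    A = torch.randn(1, 3, 224, 224).cuda()
    B = torch.randn(1, 3, 224, 224).cuda()
    w = torch.rand(1, args.Fang * args.Fang, 1, 1, 1).cuda()
    W = (w.expand(1, args.Fang * args.Fang, 3, 224, 224) * fixSquare).sum(dim=1)
    C = W * A + (oneSquare - W) * B  # deformed image, shape (1, 3, 224, 224)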
75 | ######################################################################
76 | # plot related
77 | import matplotlib
78 | matplotlib.use('agg')
79 | import matplotlib.pyplot as plt
80 | from matplotlib.pyplot import imshow
81 | ##################################################
82 | 
83 | mu = [0.485, 0.456, 0.406]
84 | sigma = [0.229, 0.224, 0.225]
85 | class Denormalize(object):
86 |     def __init__(self, mean, std):
87 |         self.mean = mean
88 |         self.std = std
89 | 
90 |     def __call__(self, tensor):
91 |         for t, m, s in zip(tensor, self.mean, self.std):
92 |             t.mul_(s).add_(m)
93 |         return tensor
94 | 
95 | 
96 | class Clip(object):
97 |     def __init__(self):
98 |         return
99 | 
100 |     def __call__(self, tensor):
101 |         t = tensor.clone()
102 |         t[t>1] = 1
103 |         t[t<0] = 0
104 |         return t
105 | 
106 | detransform = transforms.Compose([
107 |     Denormalize(mu, sigma),
108 |     Clip(),
109 |     transforms.ToPILImage(),
110 | ])
111 | 
112 | 
113 | def plotPicture(image,name):
114 |     fig = plt.figure()
115 |     ax = fig.add_subplot(111)
116 |     A = image.clone()
117 |     ax.imshow(detransform(A))
118 |     fig.savefig('picture/'+str(name)+'.png')
119 |     plt.close(fig)
120 | 
121 | ######################################################################
122 | # Define the Embedding Network
123 | class Flatten(nn.Module):
124 |     def __init__(self):
125 |         super(Flatten, self).__init__()
126 | 
127 |     def forward(self, x):
128 |         return x.view(x.size(0), -1)
129 | class ClassificationNetwork(nn.Module):
130 |     def __init__(self):
131 |         super(ClassificationNetwork, self).__init__()
132 |         self.convnet = torchvision.models.resnet18(pretrained=False)
133 |         num_ftrs = self.convnet.fc.in_features
134 |         self.convnet.fc = nn.Linear(num_ftrs,64)
135 |         #print(self.convnet)
136 | 
137 |     def forward(self,inputs):
138 |         outputs = self.convnet(inputs)
139 | 
140 |         return outputs
141 | 
142 | # resnet18 without fc layer
143 | class weightNet(nn.Module):
144 |     def __init__(self):
145 |         super(weightNet, self).__init__()
146 |         self.resnet = ClassificationNetwork()
147 |         self.resnet.load_state_dict(torch.load('models/'+str(args.network)+'.t7', map_location=lambda storage, loc: storage))
148 |         print('loading ',str(args.network))
149 | 
150 |         self.conv1 = self.resnet.convnet.conv1
151 |         self.conv1.load_state_dict(self.resnet.convnet.conv1.state_dict())
152 |         self.bn1 = self.resnet.convnet.bn1
153 |         self.bn1.load_state_dict(self.resnet.convnet.bn1.state_dict())
154 |         self.relu = self.resnet.convnet.relu
155 |         self.maxpool = self.resnet.convnet.maxpool
156 |         self.layer1 = self.resnet.convnet.layer1
157 |         self.layer1.load_state_dict(self.resnet.convnet.layer1.state_dict())
158 |         self.layer2 = self.resnet.convnet.layer2
159 |         self.layer2.load_state_dict(self.resnet.convnet.layer2.state_dict())
160 |         self.layer3 = self.resnet.convnet.layer3
161 |         self.layer3.load_state_dict(self.resnet.convnet.layer3.state_dict())
162 |         self.layer4 = self.resnet.convnet.layer4
163 |         self.layer4.load_state_dict(self.resnet.convnet.layer4.state_dict())
166 |         self.avgpool = self.resnet.convnet.avgpool
167 | 
168 |     def forward(self,x):
169 | 
170 |         x = self.conv1(x)
171 |         x = self.bn1(x)
172 |         x = self.relu(x)
173 |         x = self.maxpool(x)
174 |         layer1 = self.layer1(x) # (, 64L, 56L, 56L)
175 |         layer2 = self.layer2(layer1) # (, 128L, 28L, 28L)
176 |         layer3 = self.layer3(layer2) # (, 256L, 14L, 14L)
177 |         layer4 = self.layer4(layer3) # (,512,7,7)
178 |         x = self.avgpool(layer4) # (,512,1,1)
179 |         x = x.view(x.size(0), -1)
180 |         return x
181 | 
182 | class smallNet(nn.Module):
183 |     def __init__(self):
184 |         super(smallNet, self).__init__()
185 |         def conv_block(in_channels, out_channels):
186 |             return nn.Sequential(
187 |                 nn.Conv2d(in_channels, out_channels, 3, padding=1),
188 |                 nn.BatchNorm2d(out_channels),
189 |                 nn.ReLU(),
190 |                 nn.MaxPool2d(2)
191 |             )
192 | 
193 |         self.encoder = nn.Sequential( # 6*224*224
194 |             conv_block(6, 32),  # 32*112*112
195 |             conv_block(32, 64), # 64*56*56
196 |             conv_block(64, 64), # 64*28*28
197 |             conv_block(64, 32), # 32*14*14
198 |             conv_block(32, 16), # 16*7*7
199 |             Flatten() # 784
200 |         )
201 |         print(self.encoder)
202 | 
203 |     def forward(self,inputs):
204 | 
205 |         """
206 |         inputs: Batchsize*6*224*224
207 |         outputs: Batchsize*784
208 |         """
209 |         outputs = self.encoder(inputs)
210 | 
211 |         return outputs
212 | 
213 | 
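# --- Editor's note (not part of the original script): shape walkthrough of the
# encoder above for a 6-channel 224x224 input (images A and B concatenated):
# 6x224x224 -> 32x112x112 -> 64x56x56 -> 64x28x28 -> 32x14x14 -> 16x7x7 -> 784,
# which matches the nn.Linear(784, Fang*Fang) head in GNet.toWeight below.
if False:
    print(smallNet()(torch.randn(2, 6, 224, 224)).size())  # expected: (2, 784)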
214 | class GNet(nn.Module):
215 |     '''
216 |     The two-branch variant performs about the same as a single branch,
217 |     so we use one branch here.
218 |     A deeper attention network does not bring benefits,
219 |     so we use a small network here.
220 |     '''
221 |     def __init__(self):
222 |         super(GNet, self).__init__()
223 |         # self.ANet = weightNet()
224 |         # self.BNet = weightNet()
225 |         self.attentionNet = smallNet()
226 | 
227 |         self.toWeight = nn.Sequential(
228 |             nn.Linear(784,args.Fang*args.Fang),
229 |             # nn.ReLU(),
230 |             # nn.Linear(100,args.Fang*args.Fang),
231 |             # nn.Linear(1024,9),
232 |             # nn.Tanh(),
233 |             # nn.ReLU(),
234 |         )
235 | 
236 |         self.CNet = weightNet()
237 |         self.fc = nn.Linear(512,64)
238 | 
239 |         resnet = ClassificationNetwork()
240 |         resnet.load_state_dict(torch.load('models/'+str(args.network)+'.t7', map_location=lambda storage, loc: storage))
241 | 
242 |         self.fc.load_state_dict(resnet.convnet.fc.state_dict())
243 | 
244 |         self.scale = nn.Parameter(torch.FloatTensor(1).fill_(1.0), requires_grad=True)
245 | 
246 |     def forward(self,A,B=1,fixSquare=1,oneSquare=1,mode='one'):
247 |         # A,B :[batch,3,224,224] fixSquare:[batch,9,3,224,224] oneSquare:[batch,3,224,224]
248 |         if mode == 'two':
249 |             # Calculate the Fang*Fang weight matrix
250 |             batchSize = A.size(0)
251 |             feature = self.attentionNet(torch.cat((A,B),1))
252 |             weight = self.toWeight(feature) # [batch,3*3]
253 | 
254 |             weightSquare = weight.view(batchSize,args.Fang*args.Fang,1,1,1)
255 |             weightSquare = weightSquare.expand(batchSize,args.Fang*args.Fang,3,224,224)
256 |             weightSquare = weightSquare * fixSquare # [batch,9,3,224,224]
257 |             weightSquare = torch.sum(weightSquare,dim=1) # [batch,3,224,224]
258 | 
259 |             C = weightSquare*A + (oneSquare - weightSquare) * B
260 |             Cfeature = self.CNet(C)
261 |             return Cfeature, weight, feature
262 | 
263 |         elif mode == 'one':
264 |             # Calculate feature
265 |             Cfeature = self.CNet(A)
266 |             return Cfeature
267 | 
268 |         elif mode == 'fc':
269 |             # Go through the fc layer, just for debugging
270 |             Cfeature = self.fc(A)
271 |             return Cfeature
272 | 
273 | GNet = GNet()
274 | 
275 | 
276 | if args.GNet!='none':
277 |     GNet.load_state_dict(torch.load('models/'+args.GNet+'.t7', map_location=lambda storage, loc: storage))
278 |     print('loading ',args.GNet)
279 | 
280 | if torch.cuda.device_count() > 1:
281 |     print("Let's use", torch.cuda.device_count(), "GPUs!")
282 |     # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
283 |     GNet = nn.DataParallel(GNet)
284 | 
285 | GNet = GNet.cuda()
286 | 
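# --- Editor's note (not part of the original script): after nn.DataParallel
# wrapping, the underlying network's attributes live on GNet.module, which is why
# the optimizers and euclidean_dist below reach for GNet.module.<attr>. A generic
# helper (hypothetical, unused by the script) that works with or without the wrapper:
def unwrap(net):
    return net.module if isinstance(net, nn.DataParallel) else net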
287 | #############################################
288 | # Define the optimizer
289 | 
290 | if torch.cuda.device_count() > 1:
291 |     if args.scratch == 0:
292 |         optimizer_attention = torch.optim.Adam([
293 |             {'params': GNet.module.attentionNet.parameters()},
294 |             {'params': GNet.module.toWeight.parameters(), 'lr': args.LR}
295 |         ], lr=args.LR) # 0.001
296 |         optimizer_classifier = torch.optim.Adam([
297 |             {'params': GNet.module.CNet.parameters(),'lr': args.clsLR*0.1},
298 |             {'params': GNet.module.fc.parameters(), 'lr': args.clsLR}
299 |         ]) # 0.00003
300 |         optimizer_scale = torch.optim.Adam([
301 |             {'params': GNet.module.scale}
302 |         ], lr=args.LR) # 0.001
303 |     else:
304 |         optimizer_attention = torch.optim.Adam([
305 |             {'params': GNet.module.ANet.parameters()},
306 |             {'params': GNet.module.BNet.parameters()},
307 |             {'params': GNet.module.toWeight.parameters()}
308 |         ], lr=args.LR)
309 |         optimizer_classifier = torch.optim.Adam([
310 |             {'params': GNet.module.CNet.parameters()},
311 |             {'params': GNet.module.fc.parameters()}
312 |         ], lr=args.LR)
313 | else:
314 |     optimizer_attention = torch.optim.Adam([{'params': GNet.attentionNet.parameters()},  # single-GPU fallback; the original referenced an undefined `base_params` here
315 |         {'params': GNet.toWeight.parameters(), 'lr': args.LR}], lr=args.LR)
316 |     optimizer_classifier = torch.optim.Adam([{'params': GNet.CNet.parameters(), 'lr': args.clsLR*0.1},
317 |         {'params': GNet.fc.parameters(), 'lr': args.clsLR}])
318 | 
319 | Attention_lr_scheduler = lr_scheduler.StepLR(optimizer_attention, step_size=40, gamma=0.5)
320 | Classifier_lr_scheduler = lr_scheduler.StepLR(optimizer_classifier, step_size=40, gamma=0.5)
321 | clsCriterion = nn.CrossEntropyLoss()
322 | 
323 | ######################################################################
324 | # Train and evaluate
325 | # ^^^^^^^^^^^^^^^^^^
326 | 
327 | # Gallery
328 | Gallery = image_datasets['test'].Gallery
329 | galleryFeature = image_datasets['test'].acquireFeature(GNet,args.batchSize).cpu()
330 | 
331 | 
332 | def euclidean_dist(x, y):
333 |     # x: N x D
334 |     # y: M x D
335 |     n = x.size(0)
336 |     m = y.size(0)
337 |     d = x.size(1)
338 |     assert d == y.size(1)
339 | 
340 |     x = x.unsqueeze(1).expand(n, m, d)
341 |     y = y.unsqueeze(0).expand(n, m, d)
342 | 
343 |     # Intended to accelerate training, though we observed little effect
344 |     A = GNet.module.scale if torch.cuda.device_count() > 1 else GNet.scale  # scale lives on .module only when wrapped in DataParallel
345 | 
346 |     return (torch.pow(x - y, 2)*A).sum(2)
347 | 
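# --- Editor's sketch (not part of the original script): euclidean_dist is used in
# prototypical-network style -- class centers are the mean support embedding, and a
# query is scored by its (scaled) squared distance to each center:
if False:
    feats = torch.randn(5 * args.shots, 512).cuda()   # support embeddings
    center = feats.view(5, args.shots, -1).mean(1)    # (ways, 512) prototypes
    q = torch.randn(3, 512).cuda()                    # query embeddings
    logits = -euclidean_dist(q, center)               # nearest prototype = highest logit
    pred = logits.max(1)[1]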
348 | def iterateMix(supportImages,supportFeatures,supportBelongs,supportReals,ways):
349 |     '''
350 |     Inputs:
351 |         supportImages: [ways*shots,3,224,224]
352 |     Outputs:
353 |         AImages [ways*shots*(1+augnum),3,224,224]
354 |         BImages [ways*shots*(1+augnum),3,224,224]
355 |         ABelongs: The label in [0,ways-1]
356 |         Reals: The label in [0,63] # Just for debug
357 |     '''
358 |     center = supportFeatures.view(ways,args.shots,-1).mean(1)
359 | 
360 |     # dists = euclidean_dist(galleryFeature,center) # [ways*unNum,ways]
361 |     Num = galleryFeature.size(0)/10  # integer division (Python 2): process the gallery in 10 chunks
362 |     with torch.no_grad():
363 |         dists = euclidean_dist(galleryFeature[:Num].cuda(),center)
364 |         for i in range(1,10):
365 |             _end = (i+1)*Num
366 |             if i==9:
367 |                 _end = galleryFeature.size(0)
368 |             dist = euclidean_dist(galleryFeature[i*Num:_end].cuda(),center)
369 |             dists = torch.cat((dists,dist),dim=0)
370 | 
371 |     dists = dists.transpose(1,0) # [ways,ways*unNum]
372 | 
373 |     AImages = torch.FloatTensor(ways*args.shots*(1+args.augnum),3,224,224)
374 |     ABelongs = torch.LongTensor(ways*args.shots*(1+args.augnum),1)
375 |     Reals = torch.LongTensor(ways*args.shots*(1+args.augnum),1)
376 | 
377 |     BImages = torch.FloatTensor(ways*args.shots*(1+args.augnum),3,224,224)
378 | 
379 |     _, bh = torch.topk(dists,args.chooseNum,dim=1,largest=False)
380 | 
381 |     for i in range(ways):
382 |         for j in range(args.shots):
383 | 
384 |             AImages[i*args.shots*(1+args.augnum)+j*(args.augnum+1)+0] = supportImages[i*args.shots+j]
385 |             ABelongs[i*args.shots*(1+args.augnum)+j*(args.augnum+1)+0] = supportBelongs[i*args.shots+j]
386 |             Reals[i*args.shots*(1+args.augnum)+j*(args.augnum+1)+0] = supportReals[i*args.shots+j]
387 | 
388 |             BImages[i*args.shots*(1+args.augnum)+j*(args.augnum+1)+0] = supportImages[i*args.shots+j]
389 | 
390 |             for k in range(args.augnum):
391 | 
392 |                 p = np.random.randint(0,2)
393 |                 if p==0:
394 |                     AImages[i*args.shots*(1+args.augnum)+j*(args.augnum+1)+1+k] = torch.flip(supportImages[i*args.shots+j],[2])
395 |                 else:
396 |                     AImages[i*args.shots*(1+args.augnum)+j*(args.augnum+1)+1+k] = supportImages[i*args.shots+j]
397 |                 ABelongs[i*args.shots*(1+args.augnum)+j*(args.augnum+1)+1+k] = supportBelongs[i*args.shots+j]
398 |                 Reals[i*args.shots*(1+args.augnum)+j*(args.augnum+1)+1+k] = supportReals[i*args.shots+j]
399 | 
400 |                 choose = np.random.randint(0,args.chooseNum)
401 |                 BImages[i*args.shots*(1+args.augnum)+j*(args.augnum+1)+1+k] = image_datasets['test'].get_image(Gallery[bh[i][choose]])
402 |                 # BImages[i*args.shots*(1+args.augnum)+j*(args.augnum+1)+1+k] = unImages[bh[i][choose]]
403 | 
404 |     return AImages,BImages,ABelongs,Reals
405 | 
406 | def batchModel(model,AInputs,requireGrad):
407 |     Batch = (AInputs.size(0)+args.batchSize-1)//args.batchSize
408 |     First = True
409 |     Cfeatures = 1
410 | 
411 | 
412 |     for b in range(Batch):
413 |         if b<Batch-1:
468 |             if Times>4000:
469 |                 break
470 | 
471 |             Times = Times + 1
472 | 
473 |             supportInputs = supportInputs.squeeze(0)
474 |             supportLabels = supportLabels.squeeze(0)
475 |             supportReals = supportReals.squeeze(0)
476 | 
477 |             testInputs = testInputs.squeeze(0)
478 |             testLabels = testLabels.squeeze(0).cuda()
479 | 
480 |             ways = supportInputs.size(0)/args.shots
481 | 
482 |             supportFeatures = batchModel(model,supportInputs,requireGrad=False)
483 |             testFeatures = batchModel(model,testInputs,requireGrad=True)
484 | 
485 |             AInputs, BInputs, ABLabels, ABReals = iterateMix(supportInputs,supportFeatures,supportLabels,supportReals,ways=ways)
486 | 
487 | 
488 |             Batch = (AInputs.size(0)+args.batchSize-1)//args.batchSize
489 | 
490 |             First = True
491 |             Cfeatures = 1
492 |             Ccls = 1
493 |             Weights = 0
494 | 
495 |             '''
496 |             PyTorch's DataParallel requires each input's size to be divisible
497 |             by the number of GPUs, so make sure every mini-batch size is
498 |             divisible by the number of available GPUs.
499 |             '''
500 | 
501 |             for b in range(Batch):
502 |                 if b<Batch-1:
598 |         if torch.cuda.device_count() > 1:
599 |             best_model_wts = copy.deepcopy(model.module.state_dict())
600 |         else:
601 |             best_model_wts = copy.deepcopy(model.state_dict())
602 | 
603 | 
604 |         print()
605 |         if epoch%2 == 0:
606 | 
607 |             torch.save(best_model_wts,os.path.join(rootdir,'models/'+str(args.tensorname)+'.t7'))
608 |             print('save!')
609 | 
610 | 
611 |     time_elapsed = time.time() - since
612 |     print('Training complete in {:.0f}m {:.0f}s'.format(
613 |         time_elapsed // 60, time_elapsed % 60))
614 |     print('Best test Loss: {:.4f}'.format(best_loss))
615 | 
616 |     # load best model weights
617 |     model.load_state_dict(best_model_wts)
618 |     return model
619 | 
620 | 
621 | GNet = train_model(GNet, num_epochs=120)
622 | 
623 | 
624 | # ... after training, save your model
625 | 
626 | if torch.cuda.device_count() > 1:
627 |     torch.save(GNet.module.state_dict(),os.path.join(rootdir,'models/'+str(args.tensorname)+'.t7'))
628 | else:
629 |     torch.save(GNet.state_dict(),os.path.join(rootdir,'models/'+str(args.tensorname)+'.t7'))
630 | 
631 | # ... to load your previously trained model:
632 | #model.load_state_dict(torch.load('mytraining.pt'))
633 | 
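# --- Editor's note (not part of the original script): checkpoints saved from the
# DataParallel branch above use GNet.module.state_dict(), so their keys carry no
# 'module.' prefix and load cleanly into a bare GNet. If a checkpoint was instead
# saved from the wrapped model, the prefix can be stripped like this:
if False:
    state = torch.load('models/checkpoint.t7', map_location=lambda storage, loc: storage)  # hypothetical path
    state = {k.replace('module.', '', 1): v for k, v in state.items()}
    GNet.load_state_dict(state)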
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
1 | 
2 | import argparse
3 | import os
4 | 
5 | class Options():
6 |     def __init__(self):
7 |         # Training settings
8 |         parser = argparse.ArgumentParser(description='Tank Shot')
9 |         parser.add_argument('--LR', default=0.001,type=float,
10 |                             help='Learning rate of the Encoder Network')
11 |         parser.add_argument('--clsLR', default=0.001,type=float,
12 |                             help='Learning rate of the classifier (embedding sub-network)')
13 |         parser.add_argument('--batchSize', default=128,type=int,
14 |                             help='Batch Size')
15 |         parser.add_argument('--nthreads', default=8,type=int,
16 |                             help='number of threads used to load data')
17 |         parser.add_argument('--tensorname',default='resnet18',type=str,
18 |                             help='tensorboard curve name')
19 |         parser.add_argument('--ways', default=5,type=int,
20 |                             help='number of classes per test episode')
21 |         parser.add_argument('--shots', default=1,type=int,
22 |                             help='number of support images per class')
23 |         parser.add_argument('--test_num', default=15,type=int,
24 |                             help='number of query images per class at test time')
25 |         parser.add_argument('--augnum', default=0,type=int,
26 |                             help='number of deformed (augmented) images generated per support image')
27 |         parser.add_argument('--data',default='miniImageEmbedding',type=str,
28 |                             help='data loader type')
29 |         parser.add_argument('--network',default='None',type=str,
30 |                             help='embedding sub-network checkpoint (network.t7) to load')
31 |         parser.add_argument('--galleryNum', default=30,type=int,
32 |                             help='number of gallery images sampled per base class')
33 |         parser.add_argument('--stepSize', default=10,type=int,
34 |                             help='number of epochs between learning-rate decays')
35 |         parser.add_argument('--Fang', default=3,type=int,
36 |                             help='number of patches per side (a Fang x Fang grid)')
37 |         parser.add_argument('--epoch', default=600,type=int,
38 |                             help='number of episodes per training epoch')
39 |         parser.add_argument('--trainways', default=5,type=int,
40 |                             help='number of classes per training episode')
41 |         parser.add_argument('--fixScale', default=0,type=int,
42 |                             help='1 means fix the scale parameter')
43 |         parser.add_argument('--GNet',default='none',type=str,
44 |                             help='deformation meta-network checkpoint (GNet.t7) to load')
45 |         parser.add_argument('--scratch', default=0,type=int,
46 |                             help='whether to train from scratch')
47 |         parser.add_argument('--fixAttention', default=0,type=int,
48 |                             help='whether to fix the attention (deformation) part')
49 |         parser.add_argument('--fixCls', default=0,type=int,
50 |                             help='whether to fix the classification part')
51 |         parser.add_argument('--chooseNum', default=15,type=int,
52 |                             help='number of nearest gallery images to sample from')
53 | 
54 |         self.parser = parser
55 | 
56 |     def parse(self):
57 |         return self.parser.parse_args()
58 | 
--------------------------------------------------------------------------------
/picture/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/picture/.DS_Store
--------------------------------------------------------------------------------
/picture/FuseNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/picture/FuseNet.png
--------------------------------------------------------------------------------
/picture/approach.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/picture/approach.png
--------------------------------------------------------------------------------
/picture/deformed_images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/picture/deformed_images.png
--------------------------------------------------------------------------------
/picture/meta_learning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tankche1/IDeMe-Net/6e8e212da7a9381a5d2bd011e5dc62c23452e010/picture/meta_learning.png
--------------------------------------------------------------------------------