├── README.md ├── LICENSE ├── train.py ├── test.py ├── Networks.py ├── dataset_test.py ├── dataset_train.py └── model.py /README.md: -------------------------------------------------------------------------------- 1 | # MGRL 2 | Multi-Granularity Representation Learning for Sketch-based Dynamic Face Image Retrieval 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ddw2AIGROUP2CQUPT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from model import MGRL_Model 4 | from dataset_train import get_dataloader 5 | import argparse 6 | 7 | torch.manual_seed(42) 8 | torch.cuda.manual_seed_all(42) 9 | from torch.utils.tensorboard import SummaryWriter 10 | torch.backends.cudnn.deterministic = True 11 | torch.backends.cudnn.benchmark = False 12 | 13 | if __name__ == "__main__": 14 | 15 | parser = argparse.ArgumentParser(description='MGRL Model') 16 | parser.add_argument('--dataset_name', type=str, default='Face-450', help='Face-1000 / Face-450') 17 | parser.add_argument('--root_dir', type=str, default='') 18 | parser.add_argument('--nThreads', type=int, default=4) 19 | parser.add_argument('--backbone_lr', type=float, default=0.00005) 20 | parser.add_argument('--lr', type=float, default=0.0005) 21 | parser.add_argument('--max_epoch', type=int, default=20) 22 | parser.add_argument('--print_freq_iter', type=int, default=1) 23 | parser.add_argument('--gpu_id', type=int, default=3) 24 | parser.add_argument('--feature_num', type=int, default=8) 25 | parser.add_argument('--condition', type=int, default=0) 26 | parser.add_argument('--distance_select',type=str,default='com+part4+part9') 27 | hp = parser.parse_args() 28 | tb_logdir = r"./run/" 29 | 30 | if hp.dataset_name == 'Face-1000': 31 | hp.batchsize = 32 32 | hp.eval_freq_iter = 50 33 | hp.backbone_lr = 0.0005 34 | hp.lr = 0.005 35 | elif hp.dataset_name == 'Face-450': 36 | hp.batchsize = 32 37 | hp.eval_freq_iter = 20 38 | hp.backbone_lr = 0.00005 39 | hp.lr = 0.0005 40 | 41 | if hp.condition: 42 | hp.condition_num = 10 43 | else: 44 | hp.condition_num = 0 45 | 46 | hp.device = torch.device("cuda:"+str(hp.gpu_id) if torch.cuda.is_available() 
else "cpu") 47 | dataloader_Train = get_dataloader(hp) 48 | print(hp) 49 | tb_writer = SummaryWriter(log_dir=tb_logdir) 50 | model = MGRL_Model(hp) 51 | model.to(hp.device) 52 | step_count, top1, top5, top10, top50, top100 = -1, 0, 0, 0, 0, 0 53 | mean_IOU_buffer = 0 54 | real_p = [0, 0, 0, 0, 0, 0] 55 | 56 | for i_epoch in range(hp.max_epoch): 57 | for batch_data in dataloader_Train: 58 | step_count = step_count + 1 59 | start = time.time() 60 | model.train() 61 | loss = model.train_model(batch=batch_data) 62 | tb_writer.add_scalars('loss',{'loss':loss}, step_count) 63 | if step_count % hp.eval_freq_iter==0 and int(step_count / hp.eval_freq_iter)>50: 64 | print('Epoch: {}, Iteration: {}, Loss: {:.8f}'.format(i_epoch, step_count, loss)) 65 | torch.save(model.backbone_network.state_dict(), 66 | hp.dataset_name + '_f' + str(hp.feature_num) +'_' +str(int(step_count / hp.eval_freq_iter)) + '_backbone.pth') 67 | torch.save(model.attn_network.state_dict(), 68 | hp.dataset_name + '_f' + str(hp.feature_num) +'_' +str(int(step_count / hp.eval_freq_iter)) + '_attn.pth') 69 | torch.save(model.linear_network.state_dict(), 70 | hp.dataset_name + '_f' + str(hp.feature_num) +'_'+str(int(step_count / hp.eval_freq_iter)) + '_linear.pth') 71 | torch.save(model.block.state_dict(), 72 | hp.dataset_name + '_f' + str(hp.feature_num) +'_' +str(int(step_count / hp.eval_freq_iter)) + '_block.pth') 73 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import time 4 | from model import MGRL_Model 5 | import os 6 | from dataset_test import get_dataloader 7 | import argparse 8 | import torch.nn as nn 9 | from Networks import InceptionV3_Network, Attention, Linear,residual_block 10 | from torch import optim 11 | import numpy as np 12 | 13 | 14 | import torch.nn.functional as F 15 | import math 16 | torch.manual_seed(42) 17 | torch.cuda.manual_seed_all(42) 18 | 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.benchmark = False 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser(description='MGRL Model') 24 | parser.add_argument('--dataset_name', type=str, default='Face-1000', help='Face-1000 / Face-450') 25 | parser.add_argument('--root_dir', type=str, default='') 26 | parser.add_argument('--nThreads', type=int, default=4) 27 | parser.add_argument('--backbone_lr', type=float, default=0.0005) 28 | parser.add_argument('--lr', type=float, default=0.005) 29 | parser.add_argument('--max_epoch', type=int, default=20) 30 | parser.add_argument('--print_freq_iter', type=int, default=1) 31 | parser.add_argument('--gpu_id', type=int, default=3) 32 | parser.add_argument('--feature_num', type=int, default=16) 33 | parser.add_argument('--condition', type=int, default=0) 34 | parser.add_argument('--distance_select',type=str,default='com_1+part4_1+part9_1') 35 | hp = parser.parse_args() 36 | if hp.dataset_name == 'Face-1000': 37 | hp.batchsize = 32 38 | hp.eval_freq_iter = 50 39 | hp.backbone_lr = 0.0005 40 | hp.lr = 0.005 41 | elif hp.dataset_name == 'Face-450': 42 | hp.batchsize = 32 43 | hp.eval_freq_iter = 20 44 | hp.backbone_lr = 0.00005 45 | hp.lr = 0.0005 46 | 47 | if hp.condition: 48 | hp.condition_num = 10 49 | else: 50 | hp.condition_num = 0 51 | 52 | hp.device = torch.device("cuda:"+str(hp.gpu_id) if torch.cuda.is_available() else "cpu") 53 | dataloader_Test = get_dataloader(hp) 54 | print(hp) 55 | model = MGRL_Model(hp) 56 | model.to(hp.device) 57 | mean_IOU_buffer = 0 58 | real_p = [0, 0, 0, 0, 0, 0] 59 | 
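# The four sub-networks are restored from '<dataset>_f<feature_num>_best_*.pth' files; model_root_dir must point at the directory holding them (train.py writes numbered snapshots that can be renamed to '_best').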
model_root_dir = '/model/' 60 | model.backbone_network.load_state_dict(torch.load(model_root_dir+hp.dataset_name+'_f'+str(hp.feature_num)+'_best'+'_backbone.pth')) 61 | model.attn_network.load_state_dict(torch.load(model_root_dir+hp.dataset_name+'_f'+str(hp.feature_num)+'_best'+'_attn.pth')) 62 | model.linear_network.load_state_dict(torch.load(model_root_dir+hp.dataset_name+'_f'+str(hp.feature_num)+'_best'+'_linear.pth')) 63 | model.block.load_state_dict(torch.load(model_root_dir+hp.dataset_name+'_f'+str(hp.feature_num)+'_best'+'_block.pth')) 64 | 65 | with torch.no_grad(): 66 | start_time = time.time() 67 | top1, top5, top10, mean_IOU, mean_MA, mean_OurB, mean_OurA = model.evaluate_NN(dataloader_Test, 0, 1) 68 | print("TEST A@1: {}".format(top1)) 69 | print("TEST A@5: {}".format(top5)) 70 | print("TEST A@10: {}".format(top10)) 71 | print("TEST M@B: {}".format(mean_IOU)) 72 | print("TEST M@A: {}".format(mean_MA)) 73 | print("TEST OurB: {}".format(mean_OurB)) 74 | print("TEST OurA: {}".format(mean_OurA)) 75 | print("TEST Time: {}".format(time.time()-start_time)) 76 | if mean_IOU > mean_IOU_buffer: 77 | mean_IOU_buffer = mean_IOU 78 | real_p = [top1, top5, top10, mean_MA, mean_OurB, mean_OurA] 79 | print('Model Update') 80 | print('REAL performance: Top1: {}, Top5: {}, Top10: {}, MB: {}, MA: {}, wMB: {}, wMA: {}'.format(real_p[0], real_p[1], 81 | real_p[2], 82 | mean_IOU_buffer, 83 | real_p[3], 84 | real_p[4], 85 | real_p[5])) -------------------------------------------------------------------------------- /Networks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.models as backbone_ 3 | import torch.nn.functional as F 4 | import torch 5 | torch.manual_seed(42) 6 | torch.cuda.manual_seed_all(42) 7 | 8 | torch.backends.cudnn.deterministic = True 9 | torch.backends.cudnn.benchmark = False 10 | 11 | class InceptionV3_Network(nn.Module): 12 | def __init__(self): 13 | super(InceptionV3_Network, self).__init__() 14 | backbone = backbone_.inception_v3(pretrained=True) 15 | 16 | self.Conv2d_1a_3x3 = backbone.Conv2d_1a_3x3 17 | self.Conv2d_2a_3x3 = backbone.Conv2d_2a_3x3 18 | self.Conv2d_2b_3x3 = backbone.Conv2d_2b_3x3 19 | self.Conv2d_3b_1x1 = backbone.Conv2d_3b_1x1 20 | self.Conv2d_4a_3x3 = backbone.Conv2d_4a_3x3 21 | for param in self.parameters(): 22 | param.requires_grad = False 23 | 24 | self.Mixed_5b = backbone.Mixed_5b 25 | self.Mixed_5c = backbone.Mixed_5c 26 | self.Mixed_5d = backbone.Mixed_5d 27 | self.Mixed_6a = backbone.Mixed_6a 28 | self.Mixed_6b = backbone.Mixed_6b 29 | self.Mixed_6c = backbone.Mixed_6c 30 | self.Mixed_6d = backbone.Mixed_6d 31 | self.Mixed_6e = backbone.Mixed_6e 32 | self.Mixed_7a = backbone.Mixed_7a 33 | self.Mixed_7b = backbone.Mixed_7b 34 | self.Mixed_7c = backbone.Mixed_7c 35 | 36 | 37 | def forward(self, x,size_num): 38 | # N x 3 x 299 x 299 39 | x = self.Conv2d_1a_3x3(x) 40 | # N x 32 x 149 x 149 41 | x = self.Conv2d_2a_3x3(x) 42 | # N x 32 x 147 x 147 43 | x = self.Conv2d_2b_3x3(x) 44 | # N x 64 x 147 x 147 45 | x = F.max_pool2d(x, kernel_size=3, stride=2) 46 | # N x 64 x 73 x 73 47 | x = self.Conv2d_3b_1x1(x) 48 | # N x 80 x 73 x 73 49 | x = self.Conv2d_4a_3x3(x) 50 | # N x 192 x 71 x 71 51 | x = F.max_pool2d(x, kernel_size=3, stride=2) 52 | # N x 192 x 35 x 35 53 | x = self.Mixed_5b(x) 54 | # N x 256 x 35 x 35 55 | x = self.Mixed_5c(x) 56 | # N 
x 288 x 35 x 35 57 | x = self.Mixed_5d(x) 58 | # N x 288 x 35 x 35 59 | x = self.Mixed_6a(x) 60 | # N x 768 x 17 x 17 61 | x = self.Mixed_6b(x) 62 | # N x 768 x 17 x 17 63 | x = self.Mixed_6c(x) 64 | # N x 768 x 17 x 17 65 | x = self.Mixed_6d(x) 66 | # N x 768 x 17 x 17 67 | x = self.Mixed_6e(x) 68 | if size_num==0: 69 | return F.normalize(x) 70 | elif size_num==1: 71 | # N x 768 x 17 x 17 72 | x = self.Mixed_7a(x) 73 | # N x 1280 x 8 x 8 74 | x = self.Mixed_7b(x) 75 | # N x 2048 x 8 x 8 76 | x = self.Mixed_7c(x) 77 | return F.normalize(x) 78 | 79 | class Attention(nn.Module): 80 | def __init__(self): 81 | super(Attention, self).__init__() 82 | self.net = nn.Sequential(nn.Conv2d(2048, 512, kernel_size=1), 83 | nn.BatchNorm2d(512), 84 | nn.ReLU(), 85 | nn.Conv2d(512, 1, kernel_size=1)) 86 | self.pool_method = nn.AdaptiveMaxPool2d(1) # as default 87 | 88 | def forward(self, x): 89 | attn_mask = self.net(x) 90 | attn_mask = attn_mask.view(attn_mask.size(0), -1) 91 | attn_mask = nn.Softmax(dim=1)(attn_mask) 92 | attn_mask = attn_mask.view(attn_mask.size(0), 1, x.size(2), x.size(3)) 93 | x = x + (x * attn_mask) 94 | x = self.pool_method(x).view(-1, 2048) 95 | return F.normalize(x) 96 | 97 | class Linear(nn.Module): 98 | def __init__(self, feature_num): 99 | super(Linear, self).__init__() 100 | self.head_layer = nn.Linear(2048, feature_num) 101 | 102 | 103 | def forward(self, x): 104 | return F.normalize(self.head_layer(x)) 105 | 106 | 107 | class residual_block(nn.Module): 108 | def __init__(self, strides=1, same_shape=True, bottle=True): 109 | super(residual_block, self).__init__() 110 | self.same_shape = same_shape 111 | self.bottle = bottle 112 | if not same_shape: 113 | strides = 2 114 | self.strides = strides 115 | self.block = nn.Sequential( 116 | nn.Conv2d(768, 512, kernel_size=1, bias=False), 117 | nn.BatchNorm2d(512), 118 | nn.ReLU(inplace=True), 119 | 120 | nn.Conv2d(512, 512, kernel_size=3, stride=strides, padding=1, bias=False), 121 | nn.BatchNorm2d(512), 122 | nn.ReLU(inplace=True), 123 | 124 | nn.Conv2d(512, 2048, kernel_size=1, bias=False), 125 | nn.BatchNorm2d(2048) 126 | ) 127 | 128 | self.shortcut = nn.Sequential( 129 | nn.Conv2d(768, 2048, kernel_size=1, bias=False), 130 | nn.BatchNorm2d(2048) 131 | ) 132 | self.relu = nn.ReLU() 133 | def forward(self, x): 134 | # print(x.size()) 135 | out = self.block(x) 136 | identity = self.shortcut(x) 137 | out = self.relu(out + identity) 138 | 139 | return F.normalize(out) 140 | -------------------------------------------------------------------------------- /dataset_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from glob import glob 4 | import torch.utils.data as data 5 | import torchvision.transforms as transforms 6 | import os 7 | import pandas as pd 8 | from random import randint 9 | from PIL import Image 10 | import random 11 | import torchvision.transforms.functional as F 12 | 13 | torch.manual_seed(42) 14 | torch.cuda.manual_seed_all(42) 15 | np.random.seed(42) 16 | torch.backends.cudnn.deterministic = True 17 | torch.backends.cudnn.benchmark = False 18 | 19 | def split_img_4(p_img,n_img,s_img): 20 | p_splitimg = [] 21 | n_splitimg = [] 22 | s_splitimg = [] 23 | bool_mat = [] 24 | size_img = p_img.size 25 | weight = int(size_img[0] // 2) 26 | height = int(size_img[1] // 2) 27 | for j in range(2): 28 | for k in range(2): 29 | box = (weight * k, height * j, weight * (k + 1), height * (j + 1)) 30 | p_region = p_img.crop(box) 31 | n_region = 
n_img.crop(box) 32 | s_region = s_img.crop(box) 33 | s_region_jud = np.array(s_img.crop(box)) 34 | if s_region_jud.sum()==0: 35 | bool_mat.append(0) 36 | else: 37 | bool_mat.append(1) 38 | p_splitimg.append(p_region) 39 | s_splitimg.append(s_region) 40 | n_splitimg.append(n_region) 41 | return p_splitimg,n_splitimg,s_splitimg,bool_mat 42 | 43 | def split_img_4_test(p_img,s_img): 44 | p_splitimg = [] 45 | s_splitimg = [] 46 | bool_mat = [] 47 | size_img = p_img.size 48 | weight = int(size_img[0] // 2) 49 | height = int(size_img[1] // 2) 50 | for j in range(2): 51 | for k in range(2): 52 | box = (weight * k, height * j, weight * (k + 1), height * (j + 1)) 53 | p_region = p_img.crop(box) 54 | s_region = s_img.crop(box) 55 | s_region_jud = np.array(s_img.crop(box)) 56 | if s_region_jud.sum()==0: 57 | bool_mat.append(0) 58 | else: 59 | bool_mat.append(1) 60 | p_splitimg.append(p_region) 61 | s_splitimg.append(s_region) 62 | return p_splitimg,s_splitimg,bool_mat 63 | def split_img_9(p_img,n_img,s_img): 64 | p_splitimg = [] 65 | n_splitimg = [] 66 | bool_mat = [] 67 | s_splitimg = [] 68 | size_img = p_img.size 69 | weight = int(size_img[0] // 3) 70 | height = int(size_img[1] // 3) 71 | for j in range(3): 72 | for k in range(3): 73 | box = (weight * k, height * j, weight * (k + 1), height * (j + 1)) 74 | p_region = p_img.crop(box) 75 | n_region = n_img.crop(box) 76 | s_region = s_img.crop(box) 77 | s_region_jud = np.array(s_img.crop(box)) 78 | if s_region_jud.sum()==0: 79 | bool_mat.append(0) 80 | else: 81 | bool_mat.append(1) 82 | p_splitimg.append(p_region) 83 | s_splitimg.append(s_region) 84 | n_splitimg.append(n_region) 85 | return p_splitimg,n_splitimg,s_splitimg,bool_mat 86 | 87 | def split_img_9_test(p_img,s_img): 88 | p_splitimg = [] 89 | bool_mat = [] 90 | s_splitimg = [] 91 | size_img = p_img.size 92 | weight = int(size_img[0] // 3) 93 | height = int(size_img[1] // 3) 94 | for j in range(3): 95 | for k in range(3): 96 | box = (weight * k, height * j, weight * (k + 1), height * (j + 1)) 97 | p_region = p_img.crop(box) 98 | s_region = s_img.crop(box) 99 | s_region_jud = np.array(s_region) 100 | if s_region_jud.sum()==0: 101 | bool_mat.append(0) 102 | else: 103 | bool_mat.append(1) 104 | p_splitimg.append(p_region) 105 | s_splitimg.append(s_region) 106 | return p_splitimg,s_splitimg,bool_mat 107 | 108 | def get_transform(type): 109 | transform_list = [] 110 | if type == 'Train': 111 | transform_list.extend([transforms.Resize(320), transforms.CenterCrop(299)]) 112 | elif type == 'Test': 113 | transform_list.extend([transforms.Resize(299)]) 114 | transform_list.extend( 115 | [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 116 | return transforms.Compose(transform_list) 117 | 118 | def get_part4_transform(type): 119 | transform_list = [] 120 | if type == 'Train': 121 | transform_list.extend([transforms.Resize(180), transforms.CenterCrop(170)]) 122 | elif type == 'Test': 123 | transform_list.extend([transforms.Resize(170)]) 124 | transform_list.extend( 125 | [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 126 | return transforms.Compose(transform_list) 127 | 128 | def get_part9_transform(type): 129 | transform_list = [] 130 | if type == 'Train': 131 | transform_list.extend([transforms.Resize(180), transforms.CenterCrop(170)]) 132 | elif type == 'Test': 133 | transform_list.extend([transforms.Resize(170)]) 134 | transform_list.extend( 135 | [transforms.ToTensor(), 
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 136 | return transforms.Compose(transform_list) 137 | 138 | class MGRL_Dataset(data.Dataset): 139 | def __init__(self, hp, mode): 140 | 141 | self.hp = hp 142 | self.mode = mode 143 | 144 | if hp.dataset_name == "Face-1000": 145 | self.root_dir = os.path.join(hp.root_dir, 'Dataset', '1000') 146 | elif hp.dataset_name == "Face-450": 147 | self.root_dir = os.path.join(hp.root_dir, 'Dataset', '450') 148 | 149 | self.train_photo_paths = sorted(glob(os.path.join(self.root_dir, 'comp', 'train','photo', '*'))) 150 | self.train_sketch_paths = sorted(glob(os.path.join(self.root_dir, 'sketch', 'train', '*'))) 151 | self.test_photo_paths = sorted(glob(os.path.join(self.root_dir, 'comp', 'test', 'photo', '*'))) 152 | self.test_sketch_paths = sorted(glob(os.path.join(self.root_dir, 'sketch', 'test', '*'))) 153 | 154 | self.train_transform = get_transform('Train') 155 | self.test_transform = get_transform('Test') 156 | self.train_transform_split4 = get_part4_transform('Train') 157 | self.test_transform_split4 = get_part4_transform('Test') 158 | self.train_transform_split9 = get_part9_transform('Train') 159 | self.test_transform_split9 = get_part9_transform('Test') 160 | 161 | def __getitem__(self, item): 162 | sample = {} 163 | if self.mode == 'Train': 164 | sketch_path = self.train_sketch_paths[item] 165 | positive_name = 'image' + sketch_path.split('/')[-1].split('_')[0][6:] 166 | positive_path = os.path.join(self.root_dir, 'comp', 'train', 'photo', positive_name + '.jpg') 167 | negative_path = self.train_photo_paths[randint(0, len(self.train_photo_paths) - 1)] 168 | negative_name = negative_path.split('/')[-1].split('.')[0] 169 | 170 | sketch_img = np.array(Image.open(sketch_path).convert('RGB')) 171 | sketch_img = Image.fromarray(sketch_img).convert('RGB') 172 | 173 | positive_img = Image.open(positive_path).resize((sketch_img.size[0],sketch_img.size[1])).convert('RGB') 174 | negative_img = Image.open(negative_path).resize((sketch_img.size[0],sketch_img.size[1])).convert('RGB') 175 | 176 | n_flip = random.random() 177 | if n_flip > 0.5: 178 | sketch_img = F.hflip(sketch_img) 179 | positive_img = F.hflip(positive_img) 180 | negative_img = F.hflip(negative_img) 181 | 182 | positive_split4,negative_split4,sketch_split4,bool_mat_4=split_img_4(positive_img,negative_img,sketch_img) 183 | positive_split9,negative_split9,sketch_split9,bool_mat_9=split_img_9(positive_img,negative_img,sketch_img) 184 | 185 | sketch_img = self.train_transform(sketch_img) 186 | positive_img = self.train_transform(positive_img) 187 | negative_img = self.train_transform(negative_img) 188 | 189 | sketch_part4 = [self.train_transform_split4(sketch) for sketch in sketch_split4] 190 | positive_part4 = [self.train_transform_split4(positive) for positive in positive_split4] 191 | negative_part4 = [self.train_transform_split4(negative) for negative in negative_split4] 192 | 193 | sketch_part9 = [self.train_transform_split9(sketch) for sketch in sketch_split9] 194 | positive_part9 = [self.train_transform_split9(positive) for positive in positive_split9] 195 | negative_part9 = [self.train_transform_split9(negative) for negative in negative_split9] 196 | 197 | 198 | sample = {'sketch_img': sketch_img, 'sketch_part4': sketch_part4,'sketch_part9': sketch_part9,'sketch_path': sketch_path, 199 | 'positive_img': positive_img,'positive_part4':positive_part4, 'positive_part9':positive_part9,'positive_path': positive_path, 200 | 'negative_img': 
negative_img,'negative_part4':negative_part4,'negative_part9':negative_part9, 'negative_path': negative_path, 201 | 'bool_mat_4':bool_mat_4,'bool_mat_9':bool_mat_9} 202 | 203 | elif self.mode == 'Test': 204 | 205 | sketch_path = self.test_sketch_paths[item] 206 | positive_name = 'image' + sketch_path.split('/')[-1].split('_')[0][6:] 207 | positive_path = os.path.join(self.root_dir, 'comp', 'test', 'photo', positive_name + '.jpg') 208 | sketch_img = np.array(Image.open(sketch_path).convert('RGB')) 209 | sketch_img = Image.fromarray(sketch_img).convert('RGB') 210 | positive_img = Image.open(positive_path).resize((sketch_img.size[0],sketch_img.size[1])).convert('RGB') 211 | 212 | 213 | positive_split4,sketch_split4,bool_mat_4=split_img_4_test(positive_img,sketch_img) 214 | positive_split9,sketch_split9,bool_mat_9=split_img_9_test(positive_img,sketch_img) 215 | bool_mat=[] 216 | bool_mat.extend(bool_mat_4) 217 | bool_mat.extend(bool_mat_9) 218 | # np_sketch=np.array(sketch_img) 219 | sketch_img = self.test_transform(sketch_img) 220 | positive_img = self.test_transform(positive_img) 221 | sketch_part = [] 222 | sketch_part4 = [self.test_transform_split4(sketch) for sketch in sketch_split4] 223 | positive_part4 = [self.test_transform_split4(positive) for positive in positive_split4] 224 | sketch_part9 = [self.test_transform_split9(sketch) for sketch in sketch_split9] 225 | positive_part9 = [self.test_transform_split9(positive) for positive in positive_split9] 226 | 227 | sketch_part.extend(sketch_part4) 228 | sketch_part.extend(sketch_part9) 229 | 230 | sample = {'sketch_img': sketch_img, 'sketch_part':sketch_part, 'sketch_path': sketch_path, 231 | 'positive_img': positive_img,'positive_part4':positive_part4, 'positive_part9':positive_part9,'positive_path': positive_path,'bool_mat':bool_mat} 232 | 233 | return sample 234 | 235 | def __len__(self): 236 | if self.mode == 'Train': 237 | return len(self.train_sketch_paths) 238 | elif self.mode == 'Test': 239 | return len(self.test_sketch_paths) 240 | 241 | 242 | def get_dataloader(hp): 243 | 244 | dataset_Test = MGRL_Dataset(hp, mode='Test') 245 | dataloader_Test = data.DataLoader(dataset_Test, batch_size=70, shuffle=False, num_workers=0) 246 | return dataloader_Test 247 | -------------------------------------------------------------------------------- /dataset_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from glob import glob 4 | import torch.utils.data as data 5 | import torchvision.transforms as transforms 6 | import os 7 | import pandas as pd 8 | from random import randint 9 | from PIL import Image 10 | import random 11 | import torchvision.transforms.functional as F 12 | 13 | torch.manual_seed(42) 14 | torch.cuda.manual_seed_all(42) 15 | np.random.seed(42) 16 | torch.backends.cudnn.deterministic = True 17 | torch.backends.cudnn.benchmark = False 18 | 19 | def split_img_4(p_img,n_img,s_img): 20 | p_splitimg = [] 21 | n_splitimg = [] 22 | s_splitimg = [] 23 | bool_mat = [] 24 | size_img = p_img.size 25 | weight = int(size_img[0] // 2) 26 | height = int(size_img[1] // 2) 27 | for j in range(2): 28 | for k in range(2): 29 | box = (weight * k, height * j, weight * (k + 1), height * (j + 1)) 30 | p_region = p_img.crop(box) 31 | n_region = n_img.crop(box) 32 | s_region = s_img.crop(box) 33 | s_region_jud = np.array(s_img.crop(box)) 34 | if s_region_jud.sum()==0: 35 | bool_mat.append(0) 36 | else: 37 | bool_mat.append(1) 38 | p_splitimg.append(p_region) 39 | 
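# bool_mat (filled just above) flags each crop: 1 if the sketch region contains any strokes, 0 if its pixels sum to zero (nothing drawn there yet); model.py later uses these flags to drop empty crops from the part-level triplet loss.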
s_splitimg.append(s_region) 40 | n_splitimg.append(n_region) 41 | return p_splitimg,n_splitimg,s_splitimg,bool_mat 42 | 43 | def split_img_4_test(p_img,s_img): 44 | p_splitimg = [] 45 | s_splitimg = [] 46 | bool_mat = [] 47 | size_img = p_img.size 48 | weight = int(size_img[0] // 2) 49 | height = int(size_img[1] // 2) 50 | for j in range(2): 51 | for k in range(2): 52 | box = (weight * k, height * j, weight * (k + 1), height * (j + 1)) 53 | p_region = p_img.crop(box) 54 | s_region = s_img.crop(box) 55 | s_region_jud = np.array(s_img.crop(box)) 56 | if s_region_jud.sum()==0: 57 | bool_mat.append(0) 58 | else: 59 | bool_mat.append(1) 60 | p_splitimg.append(p_region) 61 | s_splitimg.append(s_region) 62 | return p_splitimg,s_splitimg,bool_mat 63 | def split_img_9(p_img,n_img,s_img): 64 | p_splitimg = [] 65 | n_splitimg = [] 66 | bool_mat = [] 67 | s_splitimg = [] 68 | size_img = p_img.size 69 | weight = int(size_img[0] // 3) 70 | height = int(size_img[1] // 3) 71 | for j in range(3): 72 | for k in range(3): 73 | box = (weight * k, height * j, weight * (k + 1), height * (j + 1)) 74 | p_region = p_img.crop(box) 75 | n_region = n_img.crop(box) 76 | s_region = s_img.crop(box) 77 | s_region_jud = np.array(s_img.crop(box)) 78 | if s_region_jud.sum()==0: 79 | bool_mat.append(0) 80 | else: 81 | bool_mat.append(1) 82 | p_splitimg.append(p_region) 83 | s_splitimg.append(s_region) 84 | n_splitimg.append(n_region) 85 | return p_splitimg,n_splitimg,s_splitimg,bool_mat 86 | 87 | def split_img_9_test(p_img,s_img): 88 | p_splitimg = [] 89 | bool_mat = [] 90 | s_splitimg = [] 91 | size_img = p_img.size 92 | weight = int(size_img[0] // 3) 93 | height = int(size_img[1] // 3) 94 | for j in range(3): 95 | for k in range(3): 96 | box = (weight * k, height * j, weight * (k + 1), height * (j + 1)) 97 | p_region = p_img.crop(box) 98 | s_region = s_img.crop(box) 99 | s_region_jud = np.array(s_region) 100 | if s_region_jud.sum()==0: 101 | bool_mat.append(0) 102 | else: 103 | bool_mat.append(1) 104 | p_splitimg.append(p_region) 105 | s_splitimg.append(s_region) 106 | return p_splitimg,s_splitimg,bool_mat 107 | 108 | def get_transform(type): 109 | transform_list = [] 110 | if type == 'Train': 111 | transform_list.extend([transforms.Resize(320), transforms.CenterCrop(299)]) 112 | elif type == 'Test': 113 | transform_list.extend([transforms.Resize([299,299])]) 114 | transform_list.extend( 115 | [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 116 | return transforms.Compose(transform_list) 117 | 118 | def get_part4_transform(type): 119 | transform_list = [] 120 | if type == 'Train': 121 | transform_list.extend([transforms.Resize(180), transforms.CenterCrop(170)]) 122 | elif type == 'Test': 123 | transform_list.extend([transforms.Resize([170,170])]) 124 | transform_list.extend( 125 | [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 126 | return transforms.Compose(transform_list) 127 | 128 | def get_part9_transform(type): 129 | transform_list = [] 130 | if type == 'Train': 131 | transform_list.extend([transforms.Resize(180), transforms.CenterCrop(170)]) 132 | elif type == 'Test': 133 | transform_list.extend([transforms.Resize([170,170])]) 134 | transform_list.extend( 135 | [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 136 | return transforms.Compose(transform_list) 137 | 138 | class MGRL_Dataset(data.Dataset): 139 | def __init__(self, hp, mode): 140 | 141 | self.hp = 
hp 142 | self.mode = mode 143 | 144 | if hp.dataset_name == "Face-1000": 145 | self.root_dir = os.path.join(hp.root_dir, 'Dataset', '1000') 146 | elif hp.dataset_name == "Face-450": 147 | self.root_dir = os.path.join(hp.root_dir, 'Dataset', '450') 148 | 149 | self.train_photo_paths = sorted(glob(os.path.join(self.root_dir, 'comp', 'train','photo', '*'))) 150 | self.train_sketch_paths = sorted(glob(os.path.join(self.root_dir, 'sketch', 'train', '*'))) 151 | self.test_photo_paths = sorted(glob(os.path.join(self.root_dir, 'comp', 'test', 'photo', '*'))) 152 | 153 | self.test_sketch_paths = sorted(glob(os.path.join(self.root_dir, 'sketch', 'test', '*'))) 154 | 155 | self.train_transform = get_transform('Train') 156 | self.test_transform = get_transform('Test') 157 | self.train_transform_split4 = get_part4_transform('Train') 158 | self.test_transform_split4 = get_part4_transform('Test') 159 | self.train_transform_split9 = get_part9_transform('Train') 160 | self.test_transform_split9 = get_part9_transform('Test') 161 | 162 | def __getitem__(self, item): 163 | sample = {} 164 | if self.mode == 'Train': 165 | sketch_path = self.train_sketch_paths[item] 166 | positive_name = 'image' + sketch_path.split('/')[-1].split('_')[0][6:] 167 | positive_path = os.path.join(self.root_dir, 'comp', 'train', 'photo', positive_name + '.jpg') 168 | negative_path = self.train_photo_paths[randint(0, len(self.train_photo_paths) - 1)] 169 | negative_name = negative_path.split('/')[-1].split('.')[0] 170 | 171 | sketch_img = np.array(Image.open(sketch_path).convert('RGB')) 172 | sketch_img = Image.fromarray(sketch_img).convert('RGB') 173 | 174 | positive_img = Image.open(positive_path).resize((sketch_img.size[0],sketch_img.size[1])).convert('RGB') 175 | negative_img = Image.open(negative_path).resize((sketch_img.size[0],sketch_img.size[1])).convert('RGB') 176 | 177 | n_flip = random.random() 178 | if n_flip > 0.5: 179 | sketch_img = F.hflip(sketch_img) 180 | positive_img = F.hflip(positive_img) 181 | negative_img = F.hflip(negative_img) 182 | 183 | positive_split4,negative_split4,sketch_split4,bool_mat_4=split_img_4(positive_img,negative_img,sketch_img) 184 | positive_split9,negative_split9,sketch_split9,bool_mat_9=split_img_9(positive_img,negative_img,sketch_img) 185 | 186 | sketch_img = self.train_transform(sketch_img) 187 | positive_img = self.train_transform(positive_img) 188 | negative_img = self.train_transform(negative_img) 189 | 190 | sketch_part4 = [self.train_transform_split4(sketch) for sketch in sketch_split4] 191 | positive_part4 = [self.train_transform_split4(positive) for positive in positive_split4] 192 | negative_part4 = [self.train_transform_split4(negative) for negative in negative_split4] 193 | 194 | sketch_part9 = [self.train_transform_split9(sketch) for sketch in sketch_split9] 195 | positive_part9 = [self.train_transform_split9(positive) for positive in positive_split9] 196 | negative_part9 = [self.train_transform_split9(negative) for negative in negative_split9] 197 | 198 | 199 | sample = {'sketch_img': sketch_img, 'sketch_part4': sketch_part4,'sketch_part9': sketch_part9,'sketch_path': sketch_path, 200 | 'positive_img': positive_img,'positive_part4':positive_part4, 'positive_part9':positive_part9,'positive_path': positive_path, 201 | 'negative_img': negative_img,'negative_part4':negative_part4,'negative_part9':negative_part9, 'negative_path': negative_path, 202 | 'bool_mat_4':bool_mat_4,'bool_mat_9':bool_mat_9} 203 | 204 | elif self.mode == 'Test': 205 | 206 | sketch_path = 
self.test_sketch_paths[item] 207 | positive_name = 'image' + sketch_path.split('/')[-1].split('_')[0][6:] 208 | positive_path = os.path.join(self.root_dir, 'comp', 'test', 'photo', positive_name + '.jpg') 209 | 210 | sketch_img = np.array(Image.open(sketch_path).convert('RGB')) 211 | sketch_img = Image.fromarray(sketch_img).convert('RGB') 212 | positive_img = Image.open(positive_path).resize((sketch_img.size[0],sketch_img.size[1])).convert('RGB') 213 | 214 | 215 | positive_split4,sketch_split4,bool_mat_4=split_img_4_test(positive_img,sketch_img) 216 | positive_split9,sketch_split9,bool_mat_9=split_img_9_test(positive_img,sketch_img) 217 | bool_mat=[] 218 | bool_mat.extend(bool_mat_4) 219 | bool_mat.extend(bool_mat_9) 220 | 221 | sketch_img = self.test_transform(sketch_img) 222 | positive_img = self.test_transform(positive_img) 223 | sketch_part = [] 224 | sketch_part4 = [self.test_transform_split4(sketch) for sketch in sketch_split4] 225 | positive_part4 = [self.test_transform_split4(positive) for positive in positive_split4] 226 | sketch_part9 = [self.test_transform_split9(sketch) for sketch in sketch_split9] 227 | positive_part9 = [self.test_transform_split9(positive) for positive in positive_split9] 228 | 229 | sketch_part.extend(sketch_part4) 230 | sketch_part.extend(sketch_part9) 231 | 232 | sample = {'sketch_img': sketch_img, 'sketch_part':sketch_part, 'sketch_path': sketch_path, 233 | 'positive_img': positive_img,'positive_part4':positive_part4, 'positive_part9':positive_part9,'positive_path': positive_path,'bool_mat':bool_mat} 234 | 235 | return sample 236 | 237 | def __len__(self): 238 | if self.mode == 'Train': 239 | return len(self.train_sketch_paths) 240 | elif self.mode == 'Test': 241 | return len(self.test_sketch_paths) 242 | 243 | 244 | def get_dataloader(hp): 245 | dataset_Train = MGRL_Dataset(hp, mode='Train') 246 | dataloader_Train = data.DataLoader(dataset_Train, batch_size=hp.batchsize, shuffle=True, num_workers=int(hp.nThreads)) 247 | return dataloader_Train 248 | 249 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | from Networks import InceptionV3_Network, Attention, Linear,residual_block 4 | from torch import optim 5 | import numpy as np 6 | import torch 7 | import time 8 | import torch.nn.functional as F 9 | import math 10 | class MGRL_Model(nn.Module): 11 | def __init__(self, hp): 12 | super(MGRL_Model, self).__init__() 13 | 14 | self.backbone_network = InceptionV3_Network() 15 | self.backbone_train_params = self.backbone_network.parameters() 16 | 17 | def init_weights(m): 18 | if type(m) == nn.Linear or type(m) == nn.Conv2d: 19 | nn.init.kaiming_normal_(m.weight) 20 | 21 | self.attn_network = Attention() 22 | self.attn_network.apply(init_weights) 23 | self.attn_train_params = self.attn_network.parameters() 24 | 25 | self.linear_network = Linear(hp.feature_num) 26 | self.linear_network.apply(init_weights) 27 | self.linear_train_params = self.linear_network.parameters() 28 | 29 | self.block=residual_block() 30 | self.block.apply(init_weights) 31 | self.block_train_params=self.block.parameters() 32 | 
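# One Adam optimizer updates all four modules jointly: the partially frozen InceptionV3 backbone at the smaller backbone_lr, and the attention, linear-embedding and residual-block heads at lr.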
33 | self.optimizer = optim.Adam([ 34 | {'params': filter(lambda param: param.requires_grad, self.backbone_train_params), 'lr': hp.backbone_lr}, 35 | {'params': self.attn_train_params, 'lr': hp.lr}, 36 | {'params': self.linear_train_params, 'lr': hp.lr}, 37 | {'params': self.block_train_params, 'lr': hp.lr}]) 38 | 39 | self.loss = nn.TripletMarginLoss(margin=0.3) 40 | self.hp = hp 41 | 42 | def train_model(self, batch): 43 | self.train() 44 | positive_feature_local_batch=[] 45 | negative_feature_local_batch=[] 46 | sample_feature_local_batch=[] 47 | 48 | positive_feature_complete = self.linear_network(self.attn_network(self.backbone_network(batch['positive_img'].to(self.hp.device),1))) 49 | negative_feature_complete = self.linear_network(self.attn_network(self.backbone_network(batch['negative_img'].to(self.hp.device),1))) 50 | sample_feature_complete = self.linear_network(self.attn_network(self.backbone_network(batch['sketch_img'].to(self.hp.device),1))) 51 | index_4=[(torch.argwhere(data==1)).view(-1) for data in batch['bool_mat_4']] 52 | index_9=[(torch.argwhere(data==1)).view(-1) for data in batch['bool_mat_9']] 53 | 54 | positive_part4_detain=[torch.index_select(sketch_part,dim=0,index=index_4[i]) for i,sketch_part in enumerate(batch['positive_part4'])] 55 | positive_part9_detain=[torch.index_select(sketch_part,dim=0,index=index_9[i]) for i,sketch_part in enumerate(batch['positive_part9'])] 56 | negative_part4_detain=[torch.index_select(sketch_part,dim=0,index=index_4[i]) for i,sketch_part in enumerate(batch['negative_part4'])] 57 | negative_part9_detain=[torch.index_select(sketch_part,dim=0,index=index_9[i]) for i,sketch_part in enumerate(batch['negative_part9'])] 58 | sample_part4_detain=[torch.index_select(sketch_part,dim=0,index=index_4[i]) for i,sketch_part in enumerate(batch['sketch_part4'])] 59 | sample_part9_detain=[torch.index_select(sketch_part,dim=0,index=index_9[i]) for i,sketch_part in enumerate(batch['sketch_part9'])] 60 | 61 | loss_local=0 62 | loss_complete = self.loss(sample_feature_complete, positive_feature_complete, negative_feature_complete) 63 | 64 | for i in range(len(positive_part4_detain)): 65 | positive_feature_local=self.linear_network(self.attn_network(self.block(self.backbone_network(positive_part4_detain[i].to(self.hp.device),0)))) 66 | negative_feature_local=self.linear_network(self.attn_network(self.block(self.backbone_network(negative_part4_detain[i].to(self.hp.device),0)))) 67 | sample_feature_local=self.linear_network(self.attn_network(self.block(self.backbone_network(sample_part4_detain[i].to(self.hp.device),0)))) 68 | loss_part = self.loss(sample_feature_local,positive_feature_local,negative_feature_local) 69 | loss_local = loss_part+loss_local 70 | 71 | for i in range(len(positive_part9_detain)): 72 | positive_feature_local=self.linear_network(self.attn_network(self.block(self.backbone_network(positive_part9_detain[i].to(self.hp.device),0)))) 73 | negative_feature_local=self.linear_network(self.attn_network(self.block(self.backbone_network(negative_part9_detain[i].to(self.hp.device),0)))) 74 | sample_feature_local=self.linear_network(self.attn_network(self.block(self.backbone_network(sample_part9_detain[i].to(self.hp.device),0)))) 75 | loss_part = self.loss(sample_feature_local,positive_feature_local,negative_feature_local) 76 | loss_local = loss_part+loss_local 77 | loss=loss_complete+loss_local 78 | 79 | 
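# Combined objective: the complete-image triplet loss plus the accumulated triplet losses over the sketch-covered 2x2 and 3x3 crops, followed by a single joint update.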
self.optimizer.zero_grad() 80 | loss.backward() 81 | 82 | self.optimizer.step() 83 | return loss.item() 84 | 85 | def evaluate_NN(self, dataloader,a,b): 86 | self.eval() 87 | 88 | self.Sketch_Array_Test = [] 89 | self.Image_Array_Test = [] 90 | Sketch_Feature_ALL_local =[] 91 | Image_Feature_ALL_local = [] 92 | Sketch_exist_local=[] 93 | 94 | for idx, batch in enumerate(dataloader): 95 | if self.hp.condition: 96 | sketch_feature = self.cat_feature(self.attn_network( 97 | self.backbone_network(batch['sketch_img'].to(self.hp.device))), batch['condition'].to(self.hp.device)) 98 | positive_feature = self.linear_network(self.cat_feature(self.attn_network( 99 | self.backbone_network(batch['positive_img'].to(self.hp.device))), batch['condition'].to(self.hp.device))) 100 | 101 | else: 102 | sample_feature_complete = self.attn_network(self.backbone_network(batch['sketch_img'].to(self.hp.device),1)) 103 | positive_feature_complete = self.linear_network(self.attn_network(self.backbone_network(batch['positive_img'].to(self.hp.device),1)))[0] 104 | 105 | positive_feature_local_4 = [self.linear_network(self.attn_network(self.block(self.backbone_network((batch[0].view(1,batch.shape[1],batch.shape[2],batch.shape[3]).to(self.hp.device)),0)))) for batch in batch['positive_part4']] 106 | sample_feature_local = [self.attn_network(self.block(self.backbone_network((batch.to(self.hp.device)),0))) for batch in batch['sketch_part']] 107 | positive_feature_local_9 = [self.linear_network(self.attn_network(self.block(self.backbone_network((batch[0].view(1,batch.shape[1],batch.shape[2],batch.shape[3]).to(self.hp.device)),0)))) for batch in batch['positive_part9']] 108 | 109 | self.Sketch_Array_Test.append(sample_feature_complete) 110 | self.Image_Array_Test.append(positive_feature_complete) 111 | 112 | sketch_local_pool=[] 113 | sketch_local_pool.extend(sample_feature_local) 114 | sketch_local_pool=torch.stack(sketch_local_pool) 115 | Sketch_Feature_ALL_local.append(sketch_local_pool) 116 | 117 | positive_local_pool=[] 118 | positive_local_pool.extend(positive_feature_local_4) 119 | positive_local_pool.extend(positive_feature_local_9) 120 | positive_local_pool=torch.stack(positive_local_pool) 121 | 122 | Image_Feature_ALL_local.append(positive_local_pool) 123 | 124 | 125 | exist_pool=[] 126 | exist_pool.extend(batch['bool_mat']) 127 | 128 | exist_pool=torch.stack(exist_pool) 129 | Sketch_exist_local.append(exist_pool) 130 | 131 | 132 | self.Sketch_Array_Test = torch.stack(self.Sketch_Array_Test).to(self.hp.device) 133 | self.Image_Array_Test = torch.stack(self.Image_Array_Test).to(self.hp.device) 134 | Sketch_Feature_ALL_local =torch.stack(Sketch_Feature_ALL_local).to(self.hp.device) 135 | Image_Feature_ALL_local =torch.stack(Image_Feature_ALL_local).to(self.hp.device) 136 | Sketch_exist_local=torch.stack(Sketch_exist_local).to(self.hp.device) 137 | 138 | num_of_Sketch_Step = len(self.Sketch_Array_Test[0]) 139 | avererage_area = [] 140 | avererage_area_percentile = [] 141 | avererage_ourB = [] 142 | avererage_ourA = [] 143 | 144 | exps = np.linspace(1,num_of_Sketch_Step, num_of_Sketch_Step) / num_of_Sketch_Step 145 | factor = np.exp(1 - exps) / np.e 146 | rank_all = torch.zeros(len(self.Sketch_Array_Test), num_of_Sketch_Step) 147 | rank_all_percentile = torch.zeros(len(self.Sketch_Array_Test), num_of_Sketch_Step) 148 | 149 | 150 | num = list(range(70)) 151 | Xmin = np.min(num) 152 | Xmax = np.max(num) 153 | a = 0 154 | b = 1 155 | Atten_num = a + (b-a)/(Xmax-Xmin)*(num-Xmin) 156 | 157 | for i_batch, sanpled_batch in 
enumerate(self.Sketch_Array_Test): 158 | mean_rank = [] 159 | mean_rank_percentile = [] 160 | mean_rank_ourB = [] 161 | mean_rank_ourA = [] 162 | 163 | for i_sketch in range(sanpled_batch.shape[0]): 164 | 165 | sketch_feature_complete = self.linear_network(sanpled_batch[i_sketch].unsqueeze(0).to(self.hp.device)) 166 | target_distance_complete = F.pairwise_distance(sketch_feature_complete.to(self.hp.device), self.Image_Array_Test[i_batch].unsqueeze(0).to(self.hp.device)) 167 | 168 | distance_complete = F.pairwise_distance(sketch_feature_complete.to(self.hp.device), self.Image_Array_Test.to(self.hp.device)) 169 | 170 | part_exist=Sketch_exist_local[i_batch,0:13,i_sketch] 171 | part_exist_4=Sketch_exist_local[i_batch,:4,i_sketch] 172 | part_exist_9=Sketch_exist_local[i_batch,4:13,i_sketch] 173 | num_4=np.array(part_exist_4.cpu()).sum() 174 | num_9=np.array(part_exist_9.cpu()).sum() 175 | part_index=torch.argwhere(part_exist==1).view(-1).to(self.hp.device) 176 | 177 | sketch_part_feature_detain=torch.index_select(Sketch_Feature_ALL_local[i_batch,0:13,i_sketch,:],dim=0,index=part_index).to(self.hp.device) 178 | 179 | sketch_part_feature_detain=F.normalize(self.linear_network(sketch_part_feature_detain.unsqueeze(0).to(self.hp.device)).squeeze(0)) 180 | positive_part_feature_detain_target=torch.index_select(Image_Feature_ALL_local[i_batch,0:13,:],dim=0,index=part_index).squeeze(1).to(self.hp.device) 181 | target_distance_part = F.pairwise_distance(sketch_part_feature_detain, positive_part_feature_detain_target) 182 | target_distance_4_mean=torch.sum(target_distance_part[:num_4])/num_4 183 | 184 | target_distance_9_mean=torch.sum(target_distance_part[num_4:num_9+num_4])/num_9 185 | 186 | positive_part_feature_detain=torch.index_select(Image_Feature_ALL_local[:,:13,:,:],dim=1,index=part_index).to(self.hp.device).squeeze(2) 187 | distance_part=F.pairwise_distance(sketch_part_feature_detain.to(self.hp.device),positive_part_feature_detain.to(self.hp.device)) 188 | distance_4_mean=torch.sum(distance_part[:,0:num_4],dim=1)/num_4 189 | distance_9_mean=torch.sum(distance_part[:,num_4:num_9+num_4],dim=1)/num_9 190 | 191 | Attention_num = round(Atten_num[i_sketch], 2) 192 | 193 | if self.hp.distance_select =='com+part4_decned+part9_decend': 194 | target_distance = target_distance_complete+round(math.exp(-(0.1*Attention_num)), 2)*target_distance_4_mean+round(math.exp(-(0.1*Attention_num)), 2)*target_distance_9_mean 195 | distance = distance_complete+round(math.exp(-(0.1*Attention_num)), 2)*distance_4_mean+round(math.exp(-(0.1*Attention_num)), 2)*distance_9_mean 196 | elif self.hp.distance_select =='com+part4_decned': 197 | target_distance = target_distance_complete+round(math.exp(-(0.1*Attention_num)), 2)*target_distance_4_mean 198 | distance = distance_complete+round(math.exp(-(0.1*Attention_num)), 2)*distance_4_mean 199 | elif self.hp.distance_select =='com': 200 | target_distance = target_distance_complete 201 | distance = distance_complete 202 | elif self.hp.distance_select =='com_1+part4_1+part9_1': 203 | target_distance = target_distance_complete+target_distance_4_mean+target_distance_9_mean 204 | distance = distance_complete+distance_4_mean+distance_9_mean 205 | elif self.hp.distance_select =='com_1+part4_+part9_': 206 | target_distance = target_distance_complete+a*target_distance_4_mean+b*target_distance_9_mean 207 | distance = distance_complete+a*distance_4_mean+b*distance_9_mean 208 | elif self.hp.distance_select =='com_1+part4_1+part9_decend': 209 | target_distance = 
target_distance_complete+target_distance_4_mean+round(math.exp(-(0.1*Attention_num)), 2)*target_distance_9_mean 210 | distance = distance_complete+distance_4_mean+round(math.exp(-(0.1*Attention_num)), 2)*distance_9_mean 211 | elif self.hp.distance_select =='com_1+part4_decend+part9_1': 212 | target_distance = target_distance_complete+round(math.exp(-(0.1*Attention_num)), 2)*target_distance_4_mean+target_distance_9_mean 213 | distance = distance_complete+round(math.exp(-(0.1*Attention_num)), 2)*distance_4_mean+distance_9_mean 214 | 215 | rank_all[i_batch, i_sketch] = distance.le(target_distance).sum() 216 | 217 | rank_all_percentile[i_batch, i_sketch] = (len(distance) - rank_all[i_batch, i_sketch]) / (len(distance) - 1) 218 | if rank_all[i_batch, i_sketch].item() == 0: 219 | 220 | mean_rank.append(1.) 221 | else: 222 | mean_rank.append(1/rank_all[i_batch, i_sketch].item()) 223 | mean_rank_percentile.append(rank_all_percentile[i_batch, i_sketch].item()) 224 | mean_rank_ourB.append(1/rank_all[i_batch, i_sketch].item() * factor[i_sketch]) 225 | mean_rank_ourA.append(rank_all_percentile[i_batch, i_sketch].item()*factor[i_sketch]) 226 | 227 | avererage_area.append(np.sum(mean_rank)/len(mean_rank)) 228 | avererage_area_percentile.append(np.sum(mean_rank_percentile)/len(mean_rank_percentile)) 229 | avererage_ourB.append(np.sum(mean_rank_ourB)/len(mean_rank_ourB)) 230 | avererage_ourA.append(np.sum(mean_rank_ourA)/len(mean_rank_ourA)) 231 | 232 | print(rank_all) 233 | print('MB', list(np.sum(np.array(1 / rank_all), axis=0) / len(rank_all))) 234 | print('MA', list(np.sum(np.array(rank_all_percentile), axis=0)/ len(rank_all))) 235 | print('wMB', list(np.sum(np.array(1 / rank_all), axis=0) / len(rank_all)*factor)) 236 | print('wMA', list(np.sum(np.array(rank_all_percentile), axis=0)/ len(rank_all)*factor)) 237 | top1_accuracy = rank_all[:, -1].le(1).sum().numpy() / rank_all.shape[0] 238 | top5_accuracy = rank_all[:, -1].le(5).sum().numpy() / rank_all.shape[0] 239 | top10_accuracy = rank_all[:, -1].le(10).sum().numpy() / rank_all.shape[0] 240 | # A@1 A@5 A@10 241 | meanIOU = np.mean(avererage_area) 242 | meanMA = np.mean(avererage_area_percentile) 243 | meanOurB = np.mean(avererage_ourB) 244 | meanOurA = np.mean(avererage_ourA) 245 | 246 | return top1_accuracy, top5_accuracy, top10_accuracy, meanIOU, meanMA, meanOurB, meanOurA 247 | 248 | def SortNameByData(self, dataList, nameList): 249 | convertDic = {} 250 | sortedDic = {} 251 | sortedNameList = [] 252 | for index in range(len(dataList)): 253 | convertDic[index] = dataList[index] 254 | sortedDic = sorted(convertDic.items(), key=lambda item: item[1], reverse=False) 255 | for key, _ in sortedDic: 256 | sortedNameList.append(nameList[key]) 257 | return sortedNameList --------------------------------------------------------------------------------
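Usage (inferred from the argparse definitions and path handling above; the layout below is an assumption read off the dataset and test code, not a documented default): the loaders expect photos under <root_dir>/Dataset/{1000,450}/comp/{train,test}/photo/ and sketches under <root_dir>/Dataset/{1000,450}/sketch/{train,test}/, and test.py restores '<dataset>_f<feature_num>_best_*.pth' checkpoints from /model/. Typical invocations:

python train.py --dataset_name Face-450 --root_dir <dir containing Dataset/> --gpu_id 0
python test.py --dataset_name Face-1000 --feature_num 16 --distance_select com_1+part4_1+part9_1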