├── README.md
├── stage1
│   ├── render_sketch_chairv2.py
│   ├── train.py
│   ├── Networks.py
│   ├── dataset.py
│   └── model.py
├── stage2
│   ├── render_sketch_chairv2.py
│   ├── Networks.py
│   ├── train.py
│   ├── dataset.py
│   └── model.py
└── LICENSE

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SL-CMSE-SBIR
Sequence Learning based on Cross-modal Semantic Embedding for On-the-fly Sketch-Based Image Retrieval
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 ddw2AIGROUP2CQUPT

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/stage2/render_sketch_chairv2.py:
--------------------------------------------------------------------------------
import numpy as np
from bresenham import bresenham
import scipy.ndimage

def mydrawPNG(vector_images, Sample = 25, Side = 256):
    for vector_image in vector_images:
        pixel_length = 0
        # 20 evenly spaced checkpoints over the point sequence
        sample_freq = list(np.round(np.linspace(0, len(vector_image), 21)[1:]))
        Sample_len = []
        raster_images = []
        raster_image = np.zeros((int(Side), int(Side)), dtype=np.float32)
        initX, initY = int(vector_image[0, 0]), int(vector_image[0, 1])
        for i in range(0, len(vector_image)):
            if i > 0:
                if vector_image[i - 1, 2] == 1:
                    initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])

            cordList = list(bresenham(initX, initY, int(vector_image[i, 0]), int(vector_image[i, 1])))
            pixel_length += len(cordList)

            for cord in cordList:
                if (cord[0] > 0 and cord[1] > 0) and (cord[0] < Side and cord[1] < Side):
                    raster_image[cord[1], cord[0]] = 255.0
            initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])

            if i in sample_freq:
                raster_images.append(scipy.ndimage.binary_dilation(raster_image) * 255.0)
                Sample_len.append(pixel_length)

        raster_images.append(scipy.ndimage.binary_dilation(raster_image) * 255.0)
        Sample_len.append(pixel_length)

    return raster_images, Sample_len

def Preprocess_QuickDraw_redraw(vector_images, side = 256.0):
    vector_images = vector_images.astype(float)  # np.float was removed in NumPy 1.24
    vector_images[:, :2] = vector_images[:, :2] / np.array([256, 256])
    vector_images[:, :2] = vector_images[:, :2] * side
    vector_images = np.round(vector_images)
    return vector_images

def redraw_Quick2RGB(vector_images):
    vector_images_C = Preprocess_QuickDraw_redraw(vector_images)
    raster_images, Sample_len = mydrawPNG([vector_images_C])
    return raster_images, Sample_len
--------------------------------------------------------------------------------
/stage1/render_sketch_chairv2.py:
--------------------------------------------------------------------------------
import numpy as np
from bresenham import bresenham
import scipy.ndimage
import random


def mydrawPNG(vector_images, Sample = 25, Side = 256):
    for vector_image in vector_images:
        pixel_length = 0
        # 20 evenly spaced checkpoints over the point sequence
        sample_freq = list(np.round(np.linspace(0, len(vector_image), 21)[1:]))
        Sample_len = []
        raster_images = []
        raster_image = np.zeros((int(Side), int(Side)), dtype=np.float32)
        initX, initY = int(vector_image[0, 0]), int(vector_image[0, 1])
        for i in range(0, len(vector_image)):
            if i > 0:
                if vector_image[i - 1, 2] == 1:
                    initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])

            cordList = list(bresenham(initX, initY, int(vector_image[i, 0]), int(vector_image[i, 1])))
            pixel_length += len(cordList)

            for cord in cordList:
                if (cord[0] > 0 and cord[1] > 0) and (cord[0] < Side and cord[1] < Side):
                    raster_image[cord[1], cord[0]] = 255.0
            initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])

            if i in sample_freq:
                raster_images.append(scipy.ndimage.binary_dilation(raster_image) * 255.0)
                Sample_len.append(pixel_length)

        raster_images.append(scipy.ndimage.binary_dilation(raster_image) * 255.0)
        Sample_len.append(pixel_length)

    return raster_images, Sample_len


def Preprocess_QuickDraw_redraw(vector_images, side = 256.0):
    vector_images = vector_images.astype(float)  # np.float was removed in NumPy 1.24
    vector_images[:, :2] = vector_images[:, :2] / np.array([256, 256])
    vector_images[:, :2] = vector_images[:, :2] * side
    vector_images = np.round(vector_images)
    return vector_images

def redraw_Quick2RGB(vector_images):
    vector_images_C = Preprocess_QuickDraw_redraw(vector_images)
    raster_images, Sample_len = mydrawPNG([vector_images_C])
    return raster_images, Sample_len
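Both copies of render_sketch_chairv2.py expose redraw_Quick2RGB as their entry point. A minimal usage sketch, assuming bresenham and scipy are installed; the toy three-column stroke array is hypothetical (columns are x, y, and a pen-lift flag, with 1 marking the last point of a stroke):

```python
# Toy stroke sequence, not a real ChairV2 sample.
import numpy as np
from render_sketch_chairv2 import redraw_Quick2RGB

toy_sketch = np.array([
    [10,  10,  0],
    [120, 40,  0],
    [200, 180, 1],   # stroke 1 ends here
    [50,  200, 0],
    [220, 220, 1],   # stroke 2 ends here
])

raster_images, sample_len = redraw_Quick2RGB(toy_sketch)
# For very short sketches the 20 checkpoints collapse onto the few available
# points, so the number of renderings varies; each image is (256, 256).
print(len(raster_images), raster_images[0].shape, sample_len[-1])
```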
--------------------------------------------------------------------------------
/stage1/train.py:
--------------------------------------------------------------------------------
import torch
import time
from model import SLFIR_Model
from dataset import get_dataloader
import argparse
import os

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='SLFIR Model')
    parser.add_argument('--dataset_name', type=str, default='ChairV2', help='ChairV2 / ShoeV2')
    parser.add_argument('--root_dir', type=str, default="/home/ubuntu/workplace/benke-2020/chair_shoe/")
    parser.add_argument('--nThreads', type=int, default=4)
    parser.add_argument('--backbone_lr', type=float, default=0.0005)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--max_epoch', type=int, default=200)
    parser.add_argument('--print_freq_iter', type=int, default=1)
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--feature_num', type=int, default=64)
    hp = parser.parse_args()
    if hp.dataset_name == 'ShoeV2':
        hp.batchsize = 64
        hp.eval_freq_iter = 50
    elif hp.dataset_name == 'ChairV2':
        hp.batchsize = 32
        hp.eval_freq_iter = 20

    hp.device = torch.device("cuda:" + str(hp.gpu_id) if torch.cuda.is_available() else "cpu")
    dataloader_Train, dataloader_Test = get_dataloader(hp)
    print(hp)

    model = SLFIR_Model(hp)
    model.to(hp.device)
    step_count, top1, top5, top10, top50, top100 = -1, 0, 0, 0, 0, 0
    mean_IOU_buffer = 0
    real_p = [0, 0, 0, 0, 0, 0]

    os.makedirs('./models', exist_ok=True)  # checkpoint directory for torch.save below

    for i_epoch in range(hp.max_epoch):
        for batch_data in dataloader_Train:
            step_count = step_count + 1
            start = time.time()
            model.train()
            loss = model.train_model(batch=batch_data)
            if step_count % hp.print_freq_iter == 0:
                print('Epoch: {}, Iteration: {}, Loss: {:.8f}, Top1_Accuracy: {:.5f}, Top5_Accuracy: {:.5f}, Top10_Accuracy: {:.5f}, Time: {}'.format(
                    i_epoch, step_count, loss, top1, top5, top10, time.time() - start))

            if i_epoch >= 0 and step_count % hp.eval_freq_iter == 0:
                with torch.no_grad():
                    start_time = time.time()
                    top1, top5, top10, mean_IOU, mean_MA, mean_OurB, mean_OurA = model.evaluate_NN(dataloader_Test)
                    model.train()
                print('Epoch: {}, Iteration: {}:'.format(i_epoch, step_count))
                print("TEST A@1: {}".format(top1))
                print("TEST A@5: {}".format(top5))
                print("TEST A@10: {}".format(top10))
                print("TEST M@B: {}".format(mean_IOU))
                print("TEST M@A: {}".format(mean_MA))
                print("TEST OurB: {}".format(mean_OurB))
                print("TEST OurA: {}".format(mean_OurA))
                print("TEST Time: {}".format(time.time() - start_time))
                if mean_IOU > mean_IOU_buffer:
                    torch.save(model.backbone_network.state_dict(),
                               './models/' + hp.dataset_name + '_feature' + str(hp.feature_num) + '_condition' + str(0) + '_backbone_best.pth')
                    torch.save(model.attn_network.state_dict(),
                               './models/' + hp.dataset_name + '_feature' + str(hp.feature_num) + '_condition' + str(0) + '_attn_best.pth')
                    torch.save(model.linear_network.state_dict(),
                               './models/' + hp.dataset_name + '_feature' + str(hp.feature_num) + '_condition' + str(0) + '_linear_best.pth')

                    mean_IOU_buffer = mean_IOU

                    real_p = [top1, top5, top10, mean_MA, mean_OurB, mean_OurA]

                    print('Model Updated')
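Stage 2 reloads the three checkpoints this script writes. A minimal reload sketch under the naming scheme above; the ChairV2/feature-64 prefix is an assumption, adjust it to your run:

```python
# Reload the stage-1 checkpoints saved by train.py (hypothetical run prefix).
import torch
from Networks import InceptionV3_Network, Attention, Linear

prefix = './models/ChairV2_feature64_condition0'
backbone = InceptionV3_Network()
backbone.load_state_dict(torch.load(prefix + '_backbone_best.pth', map_location='cpu'))
attn = Attention()
attn.load_state_dict(torch.load(prefix + '_attn_best.pth', map_location='cpu'))
linear = Linear(64)
linear.load_state_dict(torch.load(prefix + '_linear_best.pth', map_location='cpu'))
backbone.eval(); attn.eval(); linear.eval()
```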
--------------------------------------------------------------------------------
/stage1/Networks.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torchvision.models as backbone_
import torch.nn.functional as F
import torch

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class InceptionV3_Network(nn.Module):
    def __init__(self):
        super(InceptionV3_Network, self).__init__()
        backbone = backbone_.inception_v3(pretrained=True)

        ## Extract Inception Layers ##
        self.Conv2d_1a_3x3 = backbone.Conv2d_1a_3x3
        self.Conv2d_2a_3x3 = backbone.Conv2d_2a_3x3
        self.Conv2d_2b_3x3 = backbone.Conv2d_2b_3x3
        self.Conv2d_3b_1x1 = backbone.Conv2d_3b_1x1
        self.Conv2d_4a_3x3 = backbone.Conv2d_4a_3x3
        # freeze the parameters of the layers registered so far
        for param in self.parameters():
            param.requires_grad = False
        # the layers below keep their pretrained weights but are fine-tuned with a small learning rate
        self.Mixed_5b = backbone.Mixed_5b
        self.Mixed_5c = backbone.Mixed_5c
        self.Mixed_5d = backbone.Mixed_5d
        self.Mixed_6a = backbone.Mixed_6a
        self.Mixed_6b = backbone.Mixed_6b
        self.Mixed_6c = backbone.Mixed_6c
        self.Mixed_6d = backbone.Mixed_6d
        self.Mixed_6e = backbone.Mixed_6e
        self.Mixed_7a = backbone.Mixed_7a
        self.Mixed_7b = backbone.Mixed_7b
        self.Mixed_7c = backbone.Mixed_7c


    def forward(self, x):
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        return F.normalize(x)


class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.net = nn.Sequential(nn.Conv2d(2048, 512, kernel_size=1),
                                 nn.BatchNorm2d(512),
                                 nn.ReLU(),
                                 nn.Conv2d(512, 1, kernel_size=1))
        self.pool_method = nn.AdaptiveMaxPool2d(1)  # as default

    def forward(self, x):
        attn_mask = self.net(x)
        attn_mask = attn_mask.view(attn_mask.size(0), -1)
        attn_mask = nn.Softmax(dim=1)(attn_mask)
        attn_mask = attn_mask.view(attn_mask.size(0), 1, x.size(2), x.size(3))
        x = x + (x * attn_mask)  # residual attention
        x = self.pool_method(x).view(-1, 2048)
        return F.normalize(x)

class Linear(nn.Module):
    def __init__(self, feature_num):
        super(Linear, self).__init__()
        self.head_layer = nn.Linear(2048, feature_num)

    def forward(self, x):
        return F.normalize(self.head_layer(x))

if __name__ == '__main__':
    enc = InceptionV3_Network()
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.kaiming_uniform_(m.weight)
    enc.apply(init_weights)

    X = torch.rand(size=(10, 3, 256, 256))
    Y = enc(X)
    print("Total number of parameters in the network is {}".format(sum(x.numel() for x in enc.parameters())))
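These three modules compose into the stage-1 embedding pipeline. A quick shape check, a sketch only, with random tensors standing in for transformed sketch/photo batches at the 299×299 resolution the dataset transforms produce:

```python
# Stage-1 pipeline shape check: backbone -> attention -> linear head.
import torch
from Networks import InceptionV3_Network, Attention, Linear

backbone, attn, linear = InceptionV3_Network(), Attention(), Linear(64)
x = torch.rand(2, 3, 299, 299)
feat_map = backbone(x)      # (2, 2048, 8, 8), L2-normalised over channels
pooled = attn(feat_map)     # (2, 2048), residual attention + adaptive max-pool
emb = linear(pooled)        # (2, 64), unit-length retrieval embedding
print(feat_map.shape, pooled.shape, emb.shape)
print(emb.norm(dim=1))      # ~1.0 for every row
```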
--------------------------------------------------------------------------------
/stage1/dataset.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import os
from random import randint
from PIL import Image
import random
import torchvision.transforms.functional as F
import pickle
from render_sketch_chairv2 import redraw_Quick2RGB

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
np.random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class SLFIR_Dataset(data.Dataset):
    def __init__(self, hp, mode):

        self.hp = hp
        self.mode = mode

        if hp.dataset_name == "ChairV2":
            self.root_dir = os.path.join(hp.root_dir, 'Dataset', 'ChairV2')
        elif hp.dataset_name == "ShoeV2":
            self.root_dir = os.path.join(hp.root_dir, 'Dataset', 'ShoeV2')

        with open(os.path.join(self.root_dir, hp.dataset_name + '_' + "Coordinate"), 'rb') as fp:
            self.Coordinate = pickle.load(fp)

        self.Sketch_Train_List = [x for x in self.Coordinate if 'train' in x]
        self.Sketch_Test_List = [x for x in self.Coordinate if 'test' in x]

        self.train_transform = get_transform('Train')
        self.test_transform = get_transform('Test')


    def __getitem__(self, item):
        sample = {}
        if self.mode == 'Train':
            sketch_path = self.Sketch_Train_List[item]

            positive_name = '_'.join(self.Sketch_Train_List[item].split('/')[-1].split('_')[:-1])
            positive_path = os.path.join(self.root_dir, 'photo', positive_name + '.png')

            possible_list = list(range(len(self.Sketch_Train_List)))
            possible_list.remove(item)
            negative_item = possible_list[randint(0, len(possible_list) - 1)]
            negative_name = '_'.join(self.Sketch_Train_List[negative_item].split('/')[-1].split('_')[:-1])
            negative_path = os.path.join(self.root_dir, 'photo', negative_name + '.png')

            vector_x = self.Coordinate[sketch_path]
            sketch_img, Sample_len = redraw_Quick2RGB(vector_x)

            sketch_img = Image.fromarray(sketch_img[-1]).convert('RGB')
            positive_img = Image.open(positive_path).convert('RGB')  # force 3 channels
            negative_img = Image.open(negative_path).convert('RGB')

            n_flip = random.random()
            if n_flip > 0.5:
                sketch_img = F.hflip(sketch_img)
                positive_img = F.hflip(positive_img)
                negative_img = F.hflip(negative_img)

            sketch_img = self.train_transform(sketch_img)
            positive_img = self.train_transform(positive_img)
            negative_img = self.train_transform(negative_img)

            sample = {'sketch_img': sketch_img, 'positive_img': positive_img, 'negative_img': negative_img,
                      'sketch_path': sketch_path, 'positive_path': positive_name, 'negative_path': negative_name
                      }

        elif self.mode == 'Test':

            sketch_path = self.Sketch_Test_List[item]

            positive_name = '_'.join(self.Sketch_Test_List[item].split('/')[-1].split('_')[:-1])
            positive_path = os.path.join(self.root_dir, 'photo', positive_name + '.png')

            vector_x = self.Coordinate[sketch_path]
            sketch_img, Sample_len = redraw_Quick2RGB(vector_x)

            sketch_img = self.test_transform(Image.fromarray(sketch_img[-1]).convert('RGB'))
            positive_img = self.test_transform(Image.open(positive_path).convert('RGB'))

            sample = {'sketch_img': sketch_img, 'positive_img': positive_img,
                      'sketch_path': sketch_path, 'positive_path': positive_name}

        return sample

    def __len__(self):
        if self.mode == 'Train':
            return len(self.Sketch_Train_List)
        elif self.mode == 'Test':
            return len(self.Sketch_Test_List)


def get_dataloader(hp):
    dataset_Train = SLFIR_Dataset(hp, mode='Train')
    dataloader_Train = data.DataLoader(dataset_Train, batch_size=hp.batchsize, shuffle=True, num_workers=int(hp.nThreads))

    dataset_Test = SLFIR_Dataset(hp, mode='Test')
    dataloader_Test = data.DataLoader(dataset_Test, batch_size=1, shuffle=False, num_workers=int(hp.nThreads))

    return dataloader_Train, dataloader_Test


def get_transform(split):
    transform_list = []
    if split == 'Train':
        transform_list.extend([transforms.Resize(320), transforms.RandomCrop(299)])
    elif split == 'Test':
        transform_list.extend([transforms.Resize(299)])
    transform_list.extend(
        [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    return transforms.Compose(transform_list)
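get_dataloader only needs a handful of hyper-parameter fields, so the dataset can be driven without train.py. A sketch with a plain Namespace; the root_dir path is hypothetical and must contain the Dataset/ChairV2 layout expected above:

```python
# Drive SLFIR_Dataset standalone; Namespace fields mirror train.py's argparse options.
from argparse import Namespace
from dataset import get_dataloader

hp = Namespace(dataset_name='ChairV2',
               root_dir='/data/chair_shoe/',   # hypothetical dataset root
               batchsize=32, nThreads=4)
train_loader, test_loader = get_dataloader(hp)
batch = next(iter(train_loader))
print(batch['sketch_img'].shape,     # (32, 3, 299, 299)
      batch['positive_img'].shape,   # (32, 3, 299, 299)
      batch['negative_img'].shape)
```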
--------------------------------------------------------------------------------
/stage2/Networks.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torchvision.models as backbone_
import torch.nn.functional as F
import torch

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class InceptionV3_Network(nn.Module):
    def __init__(self):
        super(InceptionV3_Network, self).__init__()
        backbone = backbone_.inception_v3(pretrained=True)

        ## Extract Inception Layers ##
        self.Conv2d_1a_3x3 = backbone.Conv2d_1a_3x3
        self.Conv2d_2a_3x3 = backbone.Conv2d_2a_3x3
        self.Conv2d_2b_3x3 = backbone.Conv2d_2b_3x3
        self.Conv2d_3b_1x1 = backbone.Conv2d_3b_1x1
        self.Conv2d_4a_3x3 = backbone.Conv2d_4a_3x3
        self.Mixed_5b = backbone.Mixed_5b
        self.Mixed_5c = backbone.Mixed_5c
        self.Mixed_5d = backbone.Mixed_5d
        self.Mixed_6a = backbone.Mixed_6a
        self.Mixed_6b = backbone.Mixed_6b
        self.Mixed_6c = backbone.Mixed_6c
        self.Mixed_6d = backbone.Mixed_6d
        self.Mixed_6e = backbone.Mixed_6e
        # # freeze the parameters of the layers above
        # for param in self.parameters():
        #     param.requires_grad = False
        # # the layers below keep their pretrained weights but are fine-tuned with a small learning rate
        self.Mixed_7a = backbone.Mixed_7a
        self.Mixed_7b = backbone.Mixed_7b
        self.Mixed_7c = backbone.Mixed_7c


    def forward(self, x):
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        return F.normalize(x)

    def fixed_param(self):
        for param in self.parameters():
            param.requires_grad = False


class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.net = nn.Sequential(nn.Conv2d(2048, 512, kernel_size=1),
                                 nn.BatchNorm2d(512),
                                 nn.ReLU(),
                                 nn.Conv2d(512, 1, kernel_size=1))
        self.pool_method = nn.AdaptiveMaxPool2d(1)  # as default

    def forward(self, x):
        attn_mask = self.net(x)
        attn_mask = attn_mask.view(attn_mask.size(0), -1)
        attn_mask = nn.Softmax(dim=1)(attn_mask)
        attn_mask = attn_mask.view(attn_mask.size(0), 1, x.size(2), x.size(3))
        x = x + (x * attn_mask)
        x = self.pool_method(x).view(-1, 2048)
        return F.normalize(x)

    def fixed_param(self):
        for param in self.parameters():
            param.requires_grad = False


class Linear(nn.Module):
    def __init__(self, feature_num):
        super(Linear, self).__init__()
        self.head_layer = nn.Linear(2048, feature_num)

    def forward(self, x):
        return F.normalize(self.head_layer(x))

    def fixed_param(self):
        for param in self.parameters():
            param.requires_grad = False

class Linear_s2(nn.Module):
    def __init__(self, feature_num, condition_num):
        super(Linear_s2, self).__init__()
        self.head_layer = nn.Linear(2048 + condition_num, feature_num)  # note: unused, forward goes through the MLP below
        self.l1 = nn.Linear(2048 + condition_num, 256)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(256, feature_num)

    def forward(self, x):
        return F.normalize(self.l2(self.relu(self.l1(x))))

    def fixed_param(self):
        for param in self.parameters():
            param.requires_grad = False


class Block_lstm(nn.Module):
    def __init__(self, opt):
        super(Block_lstm, self).__init__()
        self.opt = opt
        self.lstm_0 = nn.LSTM(input_size=2048 + opt.condition_num, hidden_size=512, bidirectional=True)
        self.lstm_1 = nn.LSTM(input_size=1024, hidden_size=int(self.opt.feature_num // 2), bidirectional=True)

    def forward(self, X):
        X = X.unsqueeze(dim=0)
        _, b, _ = X.shape
        hidden_state = torch.zeros(1 * 2, b, 512)
        cell_state = torch.zeros(1 * 2, b, 512)
        outputs, (_, _) = self.lstm_0(X, (hidden_state.to(self.opt.device), cell_state.to(self.opt.device)))
        hidden_state_1 = torch.zeros(1 * 2, b, int(self.opt.feature_num // 2))
        cell_state_1 = torch.zeros(1 * 2, b, int(self.opt.feature_num // 2))
        outputs, (_, _) = self.lstm_1(outputs, (hidden_state_1.to(self.opt.device), cell_state_1.to(self.opt.device)))
        outputs = outputs.squeeze(dim=0)
        return outputs

    def fixed_param(self):
        for param in self.parameters():
            param.requires_grad = False
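A shape sketch for Block_lstm: it consumes one row per rendering step, each row being the 2048-d attention feature concatenated with the condition one-hot (cat_feature in stage2/model.py), and emits one feature_num-d embedding per step. The Namespace here stands in for the real argparse options:

```python
# Block_lstm shape check for the ChairV2 configuration (condition_num=19).
import torch
from argparse import Namespace
from Networks import Block_lstm

opt = Namespace(condition_num=19, feature_num=64, device=torch.device('cpu'))
lstm = Block_lstm(opt)
seq = torch.rand(21, 2048 + 19)   # 21 rendering steps of cat_feature output
out = lstm(seq)
print(out.shape)                  # (21, 64): one embedding per partial sketch
```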
--------------------------------------------------------------------------------
/stage1/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from Networks import InceptionV3_Network, Attention, Linear
from torch import optim
import numpy as np
import torch
import torch.nn.functional as F

class SLFIR_Model(nn.Module):
    def __init__(self, hp):
        super(SLFIR_Model, self).__init__()

        self.backbone_network = InceptionV3_Network()
        self.backbone_train_params = self.backbone_network.parameters()

        def init_weights(m):
            if type(m) == nn.Linear or type(m) == nn.Conv2d:
                nn.init.kaiming_normal_(m.weight)

        self.attn_network = Attention()
        self.attn_network.apply(init_weights)
        self.attn_train_params = self.attn_network.parameters()

        self.linear_network = Linear(hp.feature_num)
        self.linear_network.apply(init_weights)
        self.linear_train_params = self.linear_network.parameters()

        self.optimizer = optim.Adam([
            {'params': filter(lambda param: param.requires_grad, self.backbone_train_params), 'lr': hp.backbone_lr},
            {'params': self.attn_train_params, 'lr': hp.lr},
            {'params': self.linear_train_params, 'lr': hp.lr}])
        # modules trained in stage 1
        self.loss = nn.TripletMarginLoss(margin=0.2)
        self.hp = hp

    def train_model(self, batch):
        self.train()
        positive_feature = self.linear_network(self.attn_network(
            self.backbone_network(batch['positive_img'].to(self.hp.device))))
        negative_feature = self.linear_network(self.attn_network(
            self.backbone_network(batch['negative_img'].to(self.hp.device))))
        sample_feature = self.linear_network(self.attn_network(
            self.backbone_network(batch['sketch_img'].to(self.hp.device))))

        loss = self.loss(sample_feature, positive_feature, negative_feature)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def evaluate_NN(self, dataloader):
        self.eval()

        self.Sketch_Array_Test = []
        self.Image_Array_Test = []
        self.Sketch_Path = []
        self.Image_Path = []
        for idx, batch in enumerate(dataloader):
            sketch_feature = self.attn_network(
                self.backbone_network(batch['sketch_img'].to(self.hp.device)))
            positive_feature = self.linear_network(self.attn_network(
                self.backbone_network(batch['positive_img'].to(self.hp.device))))
            self.Sketch_Array_Test.append(sketch_feature)
            self.Sketch_Path.append(batch['sketch_path'])

            for i_num, positive_path in enumerate(batch['positive_path']):
                if positive_path not in self.Image_Path:
                    self.Image_Path.append(batch['positive_path'][i_num])
                    self.Image_Array_Test.append(positive_feature[i_num])

        self.Sketch_Array_Test = torch.stack(self.Sketch_Array_Test)
        self.Image_Array_Test = torch.stack(self.Image_Array_Test)
        num_of_Sketch_Step = len(self.Sketch_Array_Test[0])
        average_area = []
        average_area_percentile = []
        average_ourB = []
        average_ourA = []
        exps = np.linspace(1, num_of_Sketch_Step, num_of_Sketch_Step) / num_of_Sketch_Step
        factor = np.exp(1 - exps) / np.e
        rank_all = torch.zeros(len(self.Sketch_Array_Test), num_of_Sketch_Step)
        rank_all_percentile = torch.zeros(len(self.Sketch_Array_Test), num_of_Sketch_Step)

        for i_batch, sampled_batch in enumerate(self.Sketch_Array_Test):
            mean_rank = []
            mean_rank_percentile = []
            mean_rank_ourB = []
            mean_rank_ourA = []

            for i_sketch in range(sampled_batch.shape[0]):
                sketch_feature = self.linear_network(sampled_batch[i_sketch].unsqueeze(0).to(self.hp.device))

                s_path = self.Sketch_Path[i_batch]
                s_path = ''.join(s_path)
                positive_path = '_'.join(s_path.split('/')[-1].split('_')[:-1])
                position_query = self.Image_Path.index(positive_path)

                target_distance = F.pairwise_distance(F.normalize(sketch_feature.to(self.hp.device)), self.Image_Array_Test[position_query].unsqueeze(0).to(self.hp.device))
                distance = F.pairwise_distance(F.normalize(sketch_feature.to(self.hp.device)), self.Image_Array_Test.to(self.hp.device))

                rank_all[i_batch, i_sketch] = distance.le(target_distance).sum()

                rank_all_percentile[i_batch, i_sketch] = (len(distance) - rank_all[i_batch, i_sketch]) / (len(distance) - 1)

                if rank_all[i_batch, i_sketch].item() == 0:
                    mean_rank.append(1.)
                else:
                    mean_rank.append(1 / rank_all[i_batch, i_sketch].item())
                    mean_rank_percentile.append(rank_all_percentile[i_batch, i_sketch].item())
                    mean_rank_ourB.append(1 / rank_all[i_batch, i_sketch].item() * factor[i_sketch])
                    mean_rank_ourA.append(rank_all_percentile[i_batch, i_sketch].item() * factor[i_sketch])

            average_area.append(np.sum(mean_rank) / len(mean_rank))
            average_area_percentile.append(np.sum(mean_rank_percentile) / len(mean_rank_percentile))
            average_ourB.append(np.sum(mean_rank_ourB) / len(mean_rank_ourB))
            average_ourA.append(np.sum(mean_rank_ourA) / len(mean_rank_ourA))

        top1_accuracy = rank_all[:, -1].le(1).sum().numpy() / rank_all.shape[0]
        top5_accuracy = rank_all[:, -1].le(5).sum().numpy() / rank_all.shape[0]
        top10_accuracy = rank_all[:, -1].le(10).sum().numpy() / rank_all.shape[0]

        meanIOU = np.mean(average_area)
        meanMA = np.mean(average_area_percentile)
        meanOurB = np.mean(average_ourB)
        meanOurA = np.mean(average_ourA)

        return top1_accuracy, top5_accuracy, top10_accuracy, meanIOU, meanMA, meanOurB, meanOurA
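The rank bookkeeping in evaluate_NN is terse; this toy fragment shows what `distance.le(target_distance).sum()` and the percentile formula compute for made-up distances:

```python
# Toy illustration of the rank statistic used in evaluate_NN: the rank of the
# true photo is the number of gallery distances <= the target distance.
import torch

distance = torch.tensor([0.9, 0.2, 0.5, 0.7])  # query-to-gallery distances
target_distance = distance[2]                  # true photo sits at index 2
rank = distance.le(target_distance).sum()      # -> 2 (0.2 and 0.5 qualify)
percentile = (len(distance) - rank) / (len(distance) - 1)
print(rank.item(), percentile.item())          # 2, 0.666...
```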
--------------------------------------------------------------------------------
/stage2/train.py:
--------------------------------------------------------------------------------
from model import SLFIR_Model
import time
import os
import torch
import numpy as np
import argparse
from dataset import *
from torch.utils.tensorboard import SummaryWriter

np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='ChairV2', help='ChairV2 / ShoeV2')
parser.add_argument('--root_dir', type=str, default="/home/ubuntu/workplace/benke-2020/chair_shoe/")
parser.add_argument('--batchsize', type=int, default=64)
parser.add_argument('--print_freq_iter', type=int, default=20)
parser.add_argument('--nThreads', type=int, default=4)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--epoches', type=int, default=300)
parser.add_argument('--feature_num', type=int, default=64)
parser.add_argument('--gpu_id', type=int, default=0)
parser.add_argument('--stage2_net', type=str, default='LSTM', help='LSTM / MLP')
hp = parser.parse_args()
hp.device = torch.device('cuda:' + str(hp.gpu_id) if torch.cuda.is_available() else 'cpu')

if hp.dataset_name == 'ShoeV2':
    hp.condition_num = 15
elif hp.dataset_name == 'ChairV2':
    hp.condition_num = 19

hp.backbone_model_dir = '../' + 'stage1/models/' + hp.dataset_name + '_feature' + str(hp.feature_num) + '_condition' + str(0) + '_backbone_best.pth'
hp.attn_model_dir = '../' + 'stage1/models/' + hp.dataset_name + '_feature' + str(hp.feature_num) + '_condition' + str(0) + '_attn_best.pth'
hp.linear_model_dir = '../' + 'stage1/models/' + hp.dataset_name + '_feature' + str(hp.feature_num) + '_condition' + str(0) + '_linear_best.pth'

print(hp)

tb_logdir = r"./run/"
slfir_model = SLFIR_Model(hp)
dataloader_sketch_train, dataloader_sketch_test = get_dataloader(hp)

def main_train():
    meanMB_buffer = 0
    real_p = [0, 0, 0, 0, 0, 0]
    loss_buffer = []
    tb_writer = SummaryWriter(log_dir=tb_logdir)
    os.makedirs('./models', exist_ok=True)  # checkpoint directory for torch.save below
    Top1_Song = [0]
    Top5_Song = [0]
    Top10_Song = [0]
    meanMB_Song = []
    meanMA_Song = []
    meanWMB_Song = []
    meanWMA_Song = []
    step_stddev = 0
    for epoch in range(hp.epoches):
        for i, sampled_batch in enumerate(dataloader_sketch_train):
            start_time = time.time()
            loss_triplet = slfir_model.train_model(sampled_batch)
            loss_buffer.append(loss_triplet)

            step_stddev += 1
            tb_writer.add_scalar('total loss', loss_triplet, step_stddev)
            print('epoch: {}, iter: {}, loss: {}, time cost: {}'.format(epoch, step_stddev, loss_triplet, time.time() - start_time))

            if epoch >= 5 and step_stddev % hp.print_freq_iter == 0:

                with torch.no_grad():
                    start_time = time.time()
                    top1, top5, top10, meanMB, meanMA, meanWMB, meanWMA = slfir_model.evaluate_NN(dataloader_sketch_test)
                    slfir_model.train()
                print('Epoch: {}, Iteration: {}:'.format(epoch, step_stddev))
                print("TEST A@1: {}".format(top1))
                print("TEST A@5: {}".format(top5))
                print("TEST A@10: {}".format(top10))
                print("TEST M@B: {}".format(meanMB))
                print("TEST M@A: {}".format(meanMA))
                print("TEST W@MB: {}".format(meanWMB))
                print("TEST W@MA: {}".format(meanWMA))
                print("TEST Time: {}".format(time.time() - start_time))
                Top1_Song.append(top1)
                Top5_Song.append(top5)
                Top10_Song.append(top10)
                meanMB_Song.append(meanMB)
                meanMA_Song.append(meanMA)
                meanWMB_Song.append(meanWMB)
                meanWMA_Song.append(meanWMA)
                tb_writer.add_scalar('TEST A@1', top1, step_stddev)
                tb_writer.add_scalar('TEST A@5', top5, step_stddev)
                tb_writer.add_scalar('TEST A@10', top10, step_stddev)
                tb_writer.add_scalar('TEST M@B', meanMB, step_stddev)
                tb_writer.add_scalar('TEST M@A', meanMA, step_stddev)
                tb_writer.add_scalar('TEST W@MB', meanWMB, step_stddev)
                tb_writer.add_scalar('TEST W@MA', meanWMA, step_stddev)

                if meanMB > meanMB_buffer:

                    torch.save(slfir_model.stage2_network.state_dict(), './models/' + hp.dataset_name + '_feature' + str(hp.feature_num) + '_condition' + str(hp.condition_num) + '_' + str(hp.stage2_net) + '.pth')

                    meanMB_buffer = meanMB

                    real_p = [top1, top5, top10, meanMA, meanWMB, meanWMA]
                    print('Model Updated')
                print('REAL performance: Top1: {}, Top5: {}, Top10: {}, MB: {}, MA: {}, WMB: {}, WMA: {}'.format(
                    real_p[0], real_p[1], real_p[2], meanMB_buffer, real_p[3], real_p[4], real_p[5]))

    print("TOP1_MAX: {}".format(max(Top1_Song)))
    print("TOP5_MAX: {}".format(max(Top5_Song)))
    print("TOP10_MAX: {}".format(max(Top10_Song)))
    print("meanMB_MAX: {}".format(max(meanMB_Song)))
    print("meanMA_MAX: {}".format(max(meanMA_Song)))
    print("meanWMB_MAX: {}".format(max(meanWMB_Song)))
    print("meanWMA_MAX: {}".format(max(meanWMA_Song)))
    print(Top1_Song)
    print(Top5_Song)
    print(Top10_Song)
    print(meanMB_Song)
    print(meanMA_Song)
    print(meanWMB_Song)
    print(meanWMA_Song)

if __name__ == "__main__":
    main_train()
--------------------------------------------------------------------------------
/stage2/dataset.py:
--------------------------------------------------------------------------------
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import os
from random import randint
from PIL import Image
import random
import pandas as pd
import numpy as np
import torchvision.transforms.functional as F
import pickle
from render_sketch_chairv2 import redraw_Quick2RGB

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class createDataset(data.Dataset):
    def __init__(self, hp, mode):

        self.hp = hp
        self.mode = mode

        if hp.dataset_name == "ChairV2":
            self.root_dir = os.path.join(hp.root_dir, 'Dataset', 'ChairV2')
            self.condition = Condition(os.path.join(self.root_dir, 'chair_condition.csv'), hp.dataset_name)
        elif hp.dataset_name == "ShoeV2":
            self.root_dir = os.path.join(hp.root_dir, 'Dataset', 'ShoeV2')
            self.condition = Condition(os.path.join(self.root_dir, 'shoe_condition.csv'), hp.dataset_name)

        with open(os.path.join(self.root_dir, hp.dataset_name + '_' + "Coordinate"), 'rb') as fp:
            self.Coordinate = pickle.load(fp)

        self.Sketch_Train_List = [x for x in self.Coordinate if 'train' in x]
        self.Sketch_Test_List = [x for x in self.Coordinate if 'test' in x]

        self.train_transform = get_transform('Train')
        self.test_transform = get_transform('Test')

    def __getitem__(self, item):
        sample = {}
        if self.mode == 'Train':
            sketch_path = self.Sketch_Train_List[item]

            positive_name = '_'.join(self.Sketch_Train_List[item].split('/')[-1].split('_')[:-1])
            positive_path = os.path.join(self.root_dir, 'photo', positive_name + '.png')

            possible_list = list(range(len(self.Sketch_Train_List)))
            possible_list.remove(item)
            negative_item = possible_list[randint(0, len(possible_list) - 1)]
            negative_name = '_'.join(self.Sketch_Train_List[negative_item].split('/')[-1].split('_')[:-1])
            negative_path = os.path.join(self.root_dir, 'photo', negative_name + '.png')

            vector_x = self.Coordinate[sketch_path]
            sketch_img, Sample_len = redraw_Quick2RGB(vector_x)

            sketch_seq = [Image.fromarray(sk_img).convert('RGB') for sk_img in sketch_img]
            positive_img = Image.open(positive_path).convert('RGB')  # force 3 channels
            negative_img = Image.open(negative_path).convert('RGB')

            n_flip = random.random()
            if n_flip > 0.5:
                sketch_seq = [F.hflip(sk_img) for sk_img in sketch_seq]
                positive_img = F.hflip(positive_img)
                negative_img = F.hflip(negative_img)

            sketch_seq = [self.train_transform(sk_img) for sk_img in sketch_seq]
            sketch_seq = torch.stack(sketch_seq)
            positive_img = self.train_transform(positive_img)
            negative_img = self.train_transform(negative_img)

            sample = {'sketch_seq': sketch_seq, 'positive_img': positive_img, 'negative_img': negative_img,
                      'sketch_seq_paths': sketch_path, 'positive_path': positive_name, 'negative_path': negative_name,
                      'condition': self.condition[positive_name], 'negative_condition': self.condition[negative_name]
                      }

        elif self.mode == 'Test':
            sketch_path = self.Sketch_Test_List[item]
            positive_name = '_'.join(self.Sketch_Test_List[item].split('/')[-1].split('_')[:-1])
            positive_path = os.path.join(self.root_dir, 'photo', positive_name + '.png')

            vector_x = self.Coordinate[sketch_path]
            sketch_img, Sample_len = redraw_Quick2RGB(vector_x)

            sketch_seq = [self.test_transform(Image.fromarray(sk_img).convert('RGB')) for sk_img in sketch_img]
            sketch_seq = torch.stack(sketch_seq)
            positive_img = Image.open(positive_path).convert('RGB')
            positive_img = self.test_transform(positive_img)

            sample = {'sketch_seq': sketch_seq, 'positive_img': positive_img,
                      'sketch_seq_paths': sketch_path, 'positive_path': positive_name, 'condition': self.condition[positive_name]}

        return sample

    def __len__(self):
        if self.mode == 'Train':
            return len(self.Sketch_Train_List)
        elif self.mode == 'Test':
            return len(self.Sketch_Test_List)

def get_dataloader(hp):
    dataset_Train = createDataset(hp, mode='Train')
    dataloader_Train = data.DataLoader(dataset_Train, batch_size=hp.batchsize, shuffle=True, num_workers=int(hp.nThreads))

    dataset_Test = createDataset(hp, mode='Test')
    dataloader_Test = data.DataLoader(dataset_Test, batch_size=1, shuffle=False, num_workers=int(hp.nThreads))

    return dataloader_Train, dataloader_Test

def get_transform(split):
    transform_list = []
    if split == 'Train':
        transform_list.extend([transforms.Resize(320), transforms.RandomCrop(299)])
    elif split == 'Test':
        transform_list.extend([transforms.Resize(299)])
    transform_list.extend(
        [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    return transforms.Compose(transform_list)

def Condition(path, mode):
    df = pd.read_csv(path, header=None)
    cond_list = df.values.tolist()
    cond_list = cond_list[1:]  # drop the header row
    cond_dict = {}
    for item in cond_list:
        name = item[0].split('.')[0]
        val = torch.from_numpy(np.array(item[1:], dtype=np.int8)).long()
        if mode == "ChairV2":
            Legnum_one_hot = torch.nn.functional.one_hot(val[0], 7)
            Back_one_hot = torch.nn.functional.one_hot(val[1], 2)
            Handrail_one_hot = torch.nn.functional.one_hot(val[2], 2)
            Shape_one_hot = torch.nn.functional.one_hot(val[3], 3)
            Bottom_one_hot = torch.nn.functional.one_hot(val[4], 3)
            Thickness_one_hot = torch.nn.functional.one_hot(val[5], 2)
            val_ = torch.cat(
                [Legnum_one_hot, Back_one_hot, Handrail_one_hot, Shape_one_hot, Bottom_one_hot, Thickness_one_hot])

        elif mode == "ShoeV2":
            Thickness_one_hot = torch.nn.functional.one_hot(val[0], 3)
            Heel_one_hot = torch.nn.functional.one_hot(val[1], 3)
            Hollow_one_hot = torch.nn.functional.one_hot(val[2], 2)
            Height_one_hot = torch.nn.functional.one_hot(val[3], 3)
            Shoelace_one_hot = torch.nn.functional.one_hot(val[4], 2)
            Button_one_hot = torch.nn.functional.one_hot(val[5], 2)
            val_ = torch.cat(
                [Thickness_one_hot, Heel_one_hot, Hollow_one_hot, Height_one_hot, Shoelace_one_hot, Button_one_hot])

        dict_i = {name: val_}
        cond_dict.update(dict_i)
    return cond_dict
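The one-hot widths in Condition are what fix hp.condition_num in stage2/train.py: 7+2+2+3+3+2 = 19 for ChairV2 and 3+3+2+3+2+2 = 15 for ShoeV2. A toy check with made-up chair attribute values:

```python
# Sanity check of the chair condition encoding width (hypothetical attributes).
import torch

chair_widths = [7, 2, 2, 3, 3, 2]          # Legnum, Back, Handrail, Shape, Bottom, Thickness
vals = torch.tensor([3, 1, 0, 2, 1, 0])    # made-up attribute indices
one_hots = [torch.nn.functional.one_hot(v, w) for v, w in zip(vals, chair_widths)]
cond = torch.cat(one_hots)
print(cond.shape)                          # torch.Size([19]) == hp.condition_num for ChairV2
```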
--------------------------------------------------------------------------------
/stage2/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from Networks import InceptionV3_Network, Attention, Block_lstm, Linear, Linear_s2
from torch import optim
import torch
import numpy as np
import torch.nn.functional as F

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class SLFIR_Model(nn.Module):
    def __init__(self, opt):
        super(SLFIR_Model, self).__init__()

        self.backbone_network = InceptionV3_Network()
        self.backbone_network.load_state_dict(torch.load(opt.backbone_model_dir, map_location=opt.device))
        self.backbone_network.to(opt.device)
        self.backbone_network.fixed_param()
        self.backbone_network.eval()

        def init_weights(m):
            if type(m) == nn.Linear or type(m) == nn.Conv2d:
                nn.init.kaiming_normal_(m.weight)

        self.attn_network = Attention()
        self.attn_network.load_state_dict(torch.load(opt.attn_model_dir, map_location=opt.device))
        self.attn_network.to(opt.device)
        self.attn_network.fixed_param()
        self.attn_network.eval()

        self.linear_network = Linear(opt.feature_num)
        self.linear_network.load_state_dict(torch.load(opt.linear_model_dir, map_location=opt.device))
        self.linear_network.to(opt.device)
        self.linear_network.fixed_param()
        self.linear_network.eval()

        if opt.stage2_net == 'LSTM':
            self.stage2_network = Block_lstm(opt)
        elif opt.stage2_net == 'MLP':
            self.stage2_network = Linear_s2(opt.feature_num, opt.condition_num)
        else:
            print("Unsupported stage2 network")
            exit(0)
        self.stage2_network.apply(init_weights)
        self.stage2_network.train()
        self.stage2_network.to(opt.device)
        self.stage2_net_train_params = self.stage2_network.parameters()

        self.optimizer = optim.Adam([
            {'params': self.stage2_net_train_params, 'lr': opt.lr}])

        self.loss = nn.TripletMarginLoss(margin=0.3, p=2)
        self.opt = opt
    def train_model(self, batch):
        self.backbone_network.eval()
        self.attn_network.eval()
        self.linear_network.eval()
        self.stage2_network.train()
        loss = 0

        for idx in range(len(batch['sketch_seq'])):
            positive_feature = self.linear_network(self.attn_network(
                self.backbone_network(batch['positive_img'][idx].unsqueeze(0).to(self.opt.device))))
            negative_feature = self.linear_network(self.attn_network(
                self.backbone_network(batch['negative_img'][idx].unsqueeze(0).to(self.opt.device))))
            sketch_seq_feature = self.stage2_network(self.cat_feature(self.attn_network(
                self.backbone_network(batch['sketch_seq'][idx].to(self.opt.device))),
                batch['condition'][idx].repeat(batch['sketch_seq'][idx].shape[0], 1).to(self.opt.device)))

            positive_feature = positive_feature.repeat(sketch_seq_feature.shape[0], 1)
            negative_feature = negative_feature.repeat(sketch_seq_feature.shape[0], 1)
            loss += self.loss(sketch_seq_feature, positive_feature, negative_feature)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
    def evaluate_NN(self, dataloader):
        self.backbone_network.eval()
        self.attn_network.eval()
        self.stage2_network.eval()

        self.Sketch_Array_Test = []
        self.Image_Array_Test = []
        self.Sketch_Path = []
        self.Image_Path = []

        for idx, batch in enumerate(dataloader):
            positive_feature = self.linear_network(self.attn_network(
                self.backbone_network(batch['positive_img'].to(self.opt.device))))

            sketch_feature = self.cat_feature(self.attn_network(
                self.backbone_network(batch['sketch_seq'].squeeze(0).to(self.opt.device))),
                batch['condition'].repeat(batch['sketch_seq'].squeeze(0).shape[0], 1).to(self.opt.device))

            self.Sketch_Array_Test.append(sketch_feature)
            self.Sketch_Path.append(batch['sketch_seq_paths'])
            for i_num, positive_path in enumerate(batch['positive_path']):
                if positive_path not in self.Image_Path:
                    self.Image_Path.append(batch['positive_path'][i_num])
                    self.Image_Array_Test.append(positive_feature[i_num])

        self.Sketch_Array_Test = torch.stack(self.Sketch_Array_Test)
        self.Image_Array_Test = torch.stack(self.Image_Array_Test)
        num_of_Sketch_Step = len(self.Sketch_Array_Test[0])
        average_area = []
        average_area_percentile = []
        average_ourB = []
        average_ourA = []

        exps = np.linspace(1, num_of_Sketch_Step, num_of_Sketch_Step) / num_of_Sketch_Step
        factor = np.exp(1 - exps) / np.e
        rank_all = torch.zeros(len(self.Sketch_Array_Test), num_of_Sketch_Step)
        rank_all_percentile = torch.zeros(len(self.Sketch_Array_Test), num_of_Sketch_Step)
        for i_batch, sampled_batch in enumerate(self.Sketch_Array_Test):
            mean_rank = []
            mean_rank_percentile = []
            mean_rank_ourB = []
            mean_rank_ourA = []

            for i_sketch in range(sampled_batch.shape[0]):

                sketch_feature = self.stage2_network(sampled_batch[:i_sketch + 1].to(self.opt.device))

                s_path = self.Sketch_Path[i_batch]
                s_path = ''.join(s_path)
                positive_name = '_'.join(s_path.split('/')[-1].split('_')[:-1])
                position_query = self.Image_Path.index(positive_name)

                target_distance = F.pairwise_distance(F.normalize(sketch_feature[-1].unsqueeze(0).to(self.opt.device)), self.Image_Array_Test[position_query].unsqueeze(0).to(self.opt.device))
                distance = F.pairwise_distance(F.normalize(sketch_feature[-1].unsqueeze(0).to(self.opt.device)), self.Image_Array_Test.to(self.opt.device))

                rank_all[i_batch, i_sketch] = distance.le(target_distance).sum()
                rank_all_percentile[i_batch, i_sketch] = (len(distance) - rank_all[i_batch, i_sketch]) / (len(distance) - 1)

                if rank_all[i_batch, i_sketch].item() == 0:
                    mean_rank.append(1.)
                else:
                    mean_rank.append(1 / rank_all[i_batch, i_sketch].item())
                    mean_rank_percentile.append(rank_all_percentile[i_batch, i_sketch].item())
                    mean_rank_ourB.append(1 / rank_all[i_batch, i_sketch].item() * factor[i_sketch])
                    mean_rank_ourA.append(rank_all_percentile[i_batch, i_sketch].item() * factor[i_sketch])

            average_area.append(np.sum(mean_rank) / len(mean_rank))
            average_area_percentile.append(np.sum(mean_rank_percentile) / len(mean_rank_percentile))
            average_ourB.append(np.sum(mean_rank_ourB) / len(mean_rank_ourB))
            average_ourA.append(np.sum(mean_rank_ourA) / len(mean_rank_ourA))

        top1_accuracy = rank_all[:, -1].le(1).sum().numpy() / rank_all.shape[0]
        top5_accuracy = rank_all[:, -1].le(5).sum().numpy() / rank_all.shape[0]
        top10_accuracy = rank_all[:, -1].le(10).sum().numpy() / rank_all.shape[0]
        meanMB = np.mean(average_area)
        meanMA = np.mean(average_area_percentile)
        meanOurB = np.mean(average_ourB)
        meanOurA = np.mean(average_ourA)

        return top1_accuracy, top5_accuracy, top10_accuracy, meanMB, meanMA, meanOurB, meanOurA

    def cat_feature(self, feature_attn, extend_feature):
        return torch.cat([feature_attn, extend_feature], dim=1)
--------------------------------------------------------------------------------
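The factor array in evaluate_NN implements the early-retrieval weighting behind the W@MB/W@MA numbers: since exps[i] = (i+1)/T, factor[i] = exp(1 - (i+1)/T)/e = exp(-(i+1)/T), so hits from early partial sketches weigh more than late ones. A standalone check:

```python
# Time-decay weight applied to step i of T partial sketches in evaluate_NN.
import numpy as np

T = 21
exps = np.linspace(1, T, T) / T
factor = np.exp(1 - exps) / np.e
print(factor[0], factor[-1])   # exp(-1/21) ~ 0.953 for step 1, exp(-1) ~ 0.368 for the full sketch
```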