├── README.md
├── main.py
├── model.py
├── prompt.py
├── result
│   ├── arch.png
│   └── vis.png
└── utils.py

/README.md:
--------------------------------------------------------------------------------

# ClipPrompt

A PyTorch implementation of ClipPrompt based on the CVPR 2023 paper
[CLIP for All Things Zero-Shot Sketch-Based Image Retrieval, Fine-Grained or Not](https://openaccess.thecvf.com/content/CVPR2023/html/Sain_CLIP_for_All_Things_Zero-Shot_Sketch-Based_Image_Retrieval_Fine-Grained_or_CVPR_2023_paper.html).

![Network Architecture](result/arch.png)

## Requirements

- [Anaconda](https://www.anaconda.com/download/)
- [PyTorch](https://pytorch.org)

```
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
```

- [TorchMetrics](https://lightning.ai/docs/torchmetrics/stable/)

```
conda install -c conda-forge torchmetrics
```

- [CLIP](https://github.com/openai/CLIP)

```
pip install git+https://github.com/openai/CLIP.git
```

## Dataset

The [Sketchy Extended](http://sketchy.eye.gatech.edu) and
[TU-Berlin Extended](http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/) datasets are used in this repo.
You can download them from the official websites or from
[Google Drive](https://drive.google.com/drive/folders/1lce41k7cGNUOwzt-eswCeahDLWG6Cdk0?usp=sharing). The expected data
directory structure is as follows:

```
├── sketchy
    ├── train
        ├── sketch
            ├── airplane
                ├── n02691156_58-1.jpg
                └── ...
            ...
        ├── photo
            same structure as sketch
    ├── val
        same structure as train
    ...
├── tuberlin
    same structure as sketchy
...
```

## Usage

To train a model on the `Sketchy Extended` dataset, run:

```
python main.py --mode train --data_name sketchy
```

To test a model on the `Sketchy Extended` dataset, run:

```
python main.py --mode test --data_name sketchy --query_name <query_sketch_path>
```

common arguments:

```
--data_root       Datasets root path [default value is '/home/data']
--data_name       Dataset name [default value is 'sketchy'](choices=['sketchy', 'tuberlin'])
--prompt_num      Number of prompt embedding [default value is 3]
--save_root       Result saved root path [default value is 'result']
--mode            Mode of the script [default value is 'train'](choices=['train', 'test'])
```

train arguments:

```
--batch_size      Number of images in each mini-batch [default value is 64]
--epochs          Number of epochs over the model to train [default value is 60]
--triplet_margin  Margin of triplet loss [default value is 0.3]
--encoder_lr      Learning rate of encoder [default value is 1e-4]
--prompt_lr       Learning rate of prompt embedding [default value is 1e-3]
--cls_weight      Weight of classification loss [default value is 0.5]
--seed            Random seed (-1 for no manual seed) [default value is -1]
```

test arguments:

```
--query_name      Query image path [default value is '/home/data/sketchy/val/sketch/cow/n01887787_591-14.jpg']
--retrieval_num   Number of retrieved images [default value is 8]
```

## Benchmarks

The models are trained on one NVIDIA GeForce RTX 3090 (24G) GPU.
`seed` is `42`, `prompt_lr` is `1e-3`, and `distance function` is `1.0 - F.cosine_similarity(x, y)`; the other hyperparameters are left at their default values.

| Dataset            | Prompt Num | mAP@200 | mAP@all | P@100 | P@200 | Download |
|--------------------|------------|---------|---------|-------|-------|----------|
| Sketchy Extended   | 3          | 71.9    | 64.3    | 70.8  | 68.1  | MEGA     |
| TU-Berlin Extended | 3          | 75.3    | 66.0    | 73.9  | 69.7  | MEGA     |
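To reproduce one of the rows above, a command along the following lines should work; the cosine distance is already hard-coded as the triplet distance function in `main.py`, and `1e-3` is the default `--prompt_lr`, so only the seed needs to be set explicitly:

```
python main.py --mode train --data_name sketchy --seed 42
```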
136 | 137 | ## Results 138 | 139 | ![vis](result/vis.png) 140 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import clip 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torch.nn.functional as F 9 | from PIL import Image, ImageDraw 10 | from torch.nn import TripletMarginWithDistanceLoss, CrossEntropyLoss 11 | from torch.optim import Adam 12 | from torch.utils.data.dataloader import DataLoader 13 | from tqdm import tqdm 14 | 15 | from model import Model 16 | from utils import DomainDataset, compute_metric, parse_args 17 | 18 | 19 | # train for one epoch 20 | def train(net, data_loader, train_optimizer): 21 | net.train() 22 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader, dynamic_ncols=True) 23 | for img, pos, neg, label in train_bar: 24 | sketch_emb = net(img.cuda(), img_type='sketch') 25 | pos_emb = net(pos.cuda(), img_type='photo') 26 | neg_emb = net(neg.cuda(), img_type='photo') 27 | triplet_loss = triplet_criterion(sketch_emb, pos_emb, neg_emb) 28 | 29 | # normalized embeddings 30 | sketch_emb = F.normalize(sketch_emb, dim=-1) 31 | pos_emb = F.normalize(pos_emb, dim=-1) 32 | # cosine similarity as logits 33 | logit_scale = net.clip_model.logit_scale 34 | logits_sketch = logit_scale * sketch_emb @ text_emb.t() 35 | logits_pos = logit_scale * pos_emb @ text_emb.t() 36 | 37 | cls_sketch_loss = cls_criterion(logits_sketch, label.cuda()) 38 | cls_photo_loss = cls_criterion(logits_pos, label.cuda()) 39 | loss = triplet_loss + (cls_sketch_loss + cls_photo_loss) * args.cls_weight 40 | train_optimizer.zero_grad() 41 | loss.backward() 42 | train_optimizer.step() 43 | total_num += img.size(0) 44 | total_loss += loss.item() * img.size(0) 45 | train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}' 46 | .format(epoch, args.epochs, total_loss / total_num)) 47 | 48 | return total_loss / total_num 49 | 50 | 51 | # val for one epoch 52 | def val(net, data_loader): 53 | net.eval() 54 | vectors, domains, labels = [], [], [] 55 | with torch.no_grad(): 56 | for img, domain, label in tqdm(data_loader, desc='Feature extracting', dynamic_ncols=True): 57 | emb = net(img.cuda(), img_type='photo' if domain == 0 else 'sketch') 58 | vectors.append(emb) 59 | domains.append(domain.cuda()) 60 | labels.append(label.cuda()) 61 | vectors = torch.cat(vectors, dim=0) 62 | domains = torch.cat(domains, dim=0) 63 | labels = torch.cat(labels, dim=0) 64 | acc = compute_metric(vectors, domains, labels) 65 | results['P@100'].append(acc['P@100'] * 100) 66 | results['P@200'].append(acc['P@200'] * 100) 67 | results['mAP@200'].append(acc['mAP@200'] * 100) 68 | results['mAP@all'].append(acc['mAP@all'] * 100) 69 | print('Val Epoch: [{}/{}] | P@100:{:.1f}% | P@200:{:.1f}% | mAP@200:{:.1f}% | mAP@all:{:.1f}%' 70 | .format(epoch, args.epochs, acc['P@100'] * 100, acc['P@200'] * 100, acc['mAP@200'] * 100, 71 | acc['mAP@all'] * 100)) 72 | return acc['precise'], vectors 73 | 74 | 75 | if __name__ == '__main__': 76 | # args parse 77 | args = parse_args() 78 | save_name_pre = '{}_{}'.format(args.data_name, args.prompt_num) 79 | val_data = DomainDataset(args.data_root, args.data_name, split='val') 80 | 81 | if args.mode == 'train': 82 | # data prepare 83 | train_data = DomainDataset(args.data_root, args.data_name, split='train') 84 | train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=8) 85 | 
val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=8) 86 | # model and loss setup 87 | model = Model(args.prompt_num).cuda() 88 | triplet_criterion = TripletMarginWithDistanceLoss( 89 | distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y), 90 | margin=args.triplet_margin) 91 | text = torch.cat([clip.tokenize('a photo of a {}'.format(train_data.names[c].replace('_', ' '))) 92 | for c in sorted(train_data.names.keys())]) 93 | with torch.no_grad(): 94 | text_emb = F.normalize(model.clip_model.encode_text(text.cuda()), dim=-1) 95 | cls_criterion = CrossEntropyLoss() 96 | # optimizer config 97 | optimizer = Adam([{'params': model.sketch_encoder.parameters(), 'lr': args.encoder_lr}, 98 | {'params': model.photo_encoder.parameters(), 'lr': args.encoder_lr}, 99 | {'params': [model.sketch_prompt, model.photo_prompt], 'lr': args.prompt_lr}]) 100 | # training loop 101 | results = {'train_loss': [], 'val_precise': [], 'P@100': [], 'P@200': [], 'mAP@200': [], 'mAP@all': []} 102 | best_precise = 0.0 103 | for epoch in range(1, args.epochs + 1): 104 | train_loss = train(model, train_loader, optimizer) 105 | results['train_loss'].append(train_loss) 106 | val_precise, features = val(model, val_loader) 107 | results['val_precise'].append(val_precise * 100) 108 | # save statistics 109 | data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1)) 110 | data_frame.to_csv('{}/{}_results.csv'.format(args.save_root, save_name_pre), index_label='epoch') 111 | 112 | if val_precise > best_precise: 113 | best_precise = val_precise 114 | torch.save(model.state_dict(), '{}/{}_model.pth'.format(args.save_root, save_name_pre)) 115 | torch.save(features.cpu(), '{}/{}_vectors.pth'.format(args.save_root, save_name_pre)) 116 | else: 117 | data_base = '{}/{}_vectors.pth'.format(args.save_root, save_name_pre) 118 | if not os.path.exists(data_base): 119 | raise FileNotFoundError('{} not found'.format(data_base)) 120 | embeddings = torch.load(data_base) 121 | if args.query_name not in val_data.images: 122 | raise FileNotFoundError('{} not found'.format(args.query_name)) 123 | query_index = val_data.images.index(args.query_name) 124 | query_image = Image.open(args.query_name).resize((224, 224), resample=Image.BICUBIC) 125 | query_label = val_data.labels[query_index] 126 | query_class = val_data.names[query_label] 127 | query_emb = embeddings[query_index] 128 | 129 | gallery_indices = np.array(val_data.domains) == 0 130 | gallery_images = np.array(val_data.images)[gallery_indices] 131 | gallery_labels = np.array(val_data.labels)[gallery_indices] 132 | gallery_embs = embeddings[gallery_indices] 133 | 134 | sim_matrix = F.cosine_similarity(query_emb.unsqueeze(dim=0), gallery_embs).squeeze(dim=0) 135 | idx = sim_matrix.topk(k=args.retrieval_num, dim=-1)[1] 136 | 137 | result_path = '{}/{}/{}'.format(args.save_root, save_name_pre, args.query_name.split('/')[-1].split('.')[0]) 138 | if os.path.exists(result_path): 139 | shutil.rmtree(result_path) 140 | os.makedirs(result_path) 141 | query_image.save('{}/query ({}).jpg'.format(result_path, query_class)) 142 | correct = 0 143 | for num, index in enumerate(idx): 144 | retrieval_image = Image.open(gallery_images[index.item()]).resize((224, 224), resample=Image.BICUBIC) 145 | draw = ImageDraw.Draw(retrieval_image) 146 | retrieval_label = gallery_labels[index.item()] 147 | retrieval_class = val_data.names[retrieval_label] 148 | retrieval_status = retrieval_label == query_label 149 | retrieval_sim = sim_matrix[index.item()].item() 150 | if 
retrieval_status: 151 | draw.rectangle((0, 0, 223, 223), outline='green', width=8) 152 | correct += 1 153 | else: 154 | draw.rectangle((0, 0, 223, 223), outline='red', width=8) 155 | retrieval_image.save('{}/{}_{} ({}).jpg'.format(result_path, num + 1, 156 | '%.4f' % retrieval_sim, retrieval_class)) 157 | print('Query: {} | Class: {} | Retrieval: {}/{} | Saved: {}' 158 | .format(args.query_name, query_class, correct, args.retrieval_num, result_path)) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from prompt import load_clip 7 | 8 | 9 | def unfreeze_ln(m): 10 | if isinstance(m, nn.LayerNorm): 11 | if hasattr(m, 'weight') and m.weight is not None: 12 | m.weight.requires_grad_(True) 13 | if hasattr(m, 'bias') and m.bias is not None: 14 | m.bias.requires_grad_(True) 15 | 16 | 17 | class Model(nn.Module): 18 | def __init__(self, prompt_num): 19 | super(Model, self).__init__() 20 | # backbone 21 | clip_model = load_clip('ViT-B/32') 22 | for param in clip_model.parameters(): 23 | param.requires_grad_(False) 24 | visual = clip_model.visual 25 | visual.apply(unfreeze_ln) 26 | visual.proj.requires_grad_(True) 27 | 28 | self.sketch_encoder = visual 29 | self.photo_encoder = copy.deepcopy(visual) 30 | self.clip_model = clip_model 31 | 32 | # prompts 33 | self.sketch_prompt = nn.Parameter(torch.randn(prompt_num, self.sketch_encoder.class_embedding.shape[0])) 34 | self.photo_prompt = nn.Parameter(torch.randn(prompt_num, self.photo_encoder.class_embedding.shape[0])) 35 | 36 | def forward(self, img, img_type): 37 | if img_type == 'sketch': 38 | proj = self.sketch_encoder(img, self.sketch_prompt.expand(img.shape[0], -1, -1)) 39 | else: 40 | proj = self.photo_encoder(img, self.photo_prompt.expand(img.shape[0], -1, -1)) 41 | return proj -------------------------------------------------------------------------------- /prompt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple, Union 3 | 4 | import torch 5 | from clip.clip import _MODELS, _download, available_models 6 | from clip.model import convert_weights, CLIP, VisionTransformer 7 | 8 | 9 | def load_clip(name: str): 10 | """Load a CLIP model 11 | 12 | Parameters 13 | ---------- 14 | name : str 15 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 16 | Returns 17 | ------- 18 | model : torch.nn.Module 19 | The CLIP model 20 | """ 21 | if name in _MODELS: 22 | model_path = _download(_MODELS[name], os.path.expanduser("~/.cache/clip")) 23 | elif os.path.isfile(name): 24 | model_path = name 25 | else: 26 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 27 | 28 | with open(model_path, 'rb') as opened_file: 29 | try: 30 | # loading JIT archive 31 | model = torch.jit.load(opened_file, map_location="cpu").eval() 32 | state_dict = None 33 | except RuntimeError: 34 | # loading saved state dict 35 | state_dict = torch.load(opened_file, map_location="cpu") 36 | 37 | model = build_model(state_dict or model.state_dict()).to("cpu") 38 | model.float() 39 | return model 40 | 41 | 42 | def build_model(state_dict: dict): 43 | vit = "visual.proj" in state_dict 44 | 45 | if vit: 46 | vision_width = state_dict["visual.conv1.weight"].shape[0] 47 | vision_layers = len( 48 | [k for k in state_dict.keys() if 
k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 49 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 50 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 51 | image_resolution = vision_patch_size * grid_size 52 | else: 53 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in 54 | [1, 2, 3, 4]] 55 | vision_layers = tuple(counts) 56 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 57 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 58 | vision_patch_size = None 59 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 60 | image_resolution = output_width * 32 61 | 62 | embed_dim = state_dict["text_projection"].shape[1] 63 | context_length = state_dict["positional_embedding"].shape[0] 64 | vocab_size = state_dict["token_embedding.weight"].shape[0] 65 | transformer_width = state_dict["ln_final.weight"].shape[0] 66 | transformer_heads = transformer_width // 64 67 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 68 | 69 | model = PromptCLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, 70 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers) 71 | 72 | for key in ["input_resolution", "context_length", "vocab_size"]: 73 | if key in state_dict: 74 | del state_dict[key] 75 | 76 | convert_weights(model) 77 | model.load_state_dict(state_dict) 78 | return model.eval() 79 | 80 | 81 | class PromptCLIP(CLIP): 82 | def __init__(self, 83 | embed_dim: int, 84 | # vision 85 | image_resolution: int, 86 | vision_layers: Union[Tuple[int, int, int, int], int], 87 | vision_width: int, 88 | vision_patch_size: int, 89 | # text 90 | context_length: int, 91 | vocab_size: int, 92 | transformer_width: int, 93 | transformer_heads: int, 94 | transformer_layers: int 95 | ): 96 | super().__init__(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, context_length, 97 | vocab_size, transformer_width, transformer_heads, transformer_layers) 98 | if not isinstance(vision_layers, (tuple, list)): 99 | vision_heads = vision_width // 64 100 | self.visual = PromptVisionTransformer( 101 | input_resolution=image_resolution, 102 | patch_size=vision_patch_size, 103 | width=vision_width, 104 | layers=vision_layers, 105 | heads=vision_heads, 106 | output_dim=embed_dim) 107 | 108 | 109 | class PromptVisionTransformer(VisionTransformer): 110 | def forward(self, x: torch.Tensor, prompt: torch.Tensor = None): 111 | x = self.conv1(x) # shape = [*, width, grid, grid] 112 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 113 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 114 | x = torch.cat( 115 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 116 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 117 | x = x + self.positional_embedding.to(x.dtype) 118 | if prompt is not None: 119 | # prompt should be of shape [*, K, width] 120 | x = torch.cat([prompt, x], dim=1) # [*, grid ** 2 + 1 + K, width] 121 | x = self.ln_pre(x) 122 | 123 | x = x.permute(1, 0, 2) # NLD -> LND 124 | x = self.transformer(x) 125 | x = x.permute(1, 0, 2) # LND -> NLD 126 | 127 | x = self.ln_post(x[:, 0, :]) 128 | 129 | if self.proj is not None: 130 | x = x @ self.proj 131 | 132 | return x 
-------------------------------------------------------------------------------- /result/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/ClipPrompt/e946e6dc54c837e3e01c4897bd95d31d2adaa2b1/result/arch.png -------------------------------------------------------------------------------- /result/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/ClipPrompt/e946e6dc54c837e3e01c4897bd95d31d2adaa2b1/result/vis.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | from torch.backends import cudnn 11 | from torch.utils.data.dataset import Dataset 12 | from torchmetrics.functional.retrieval import retrieval_precision, retrieval_average_precision 13 | from torchvision import transforms 14 | 15 | 16 | def get_transform(): 17 | return transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), 18 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) 19 | 20 | 21 | class DomainDataset(Dataset): 22 | def __init__(self, data_root, data_name, split='train'): 23 | super(DomainDataset, self).__init__() 24 | 25 | self.split = split 26 | 27 | images, self.refs = [], {} 28 | for classes in os.listdir(os.path.join(data_root, data_name, split, 'sketch')): 29 | sketches = glob.glob(os.path.join(data_root, data_name, split, 'sketch', str(classes), '*.jpg')) 30 | photos = glob.glob(os.path.join(data_root, data_name, split, 'photo', str(classes), '*.jpg')) 31 | images += sketches 32 | if split == 'val': 33 | images += photos 34 | else: 35 | self.refs[str(classes)] = photos 36 | self.images = sorted(images) 37 | self.transform = get_transform() 38 | 39 | self.domains, self.labels, self.classes = [], [], {} 40 | i = 0 41 | for img in self.images: 42 | domain, label = os.path.dirname(img).split('/')[-2:] 43 | self.domains.append(0 if domain == 'photo' else 1) 44 | if label not in self.classes: 45 | self.classes[label] = i 46 | i += 1 47 | self.labels.append(self.classes[label]) 48 | 49 | self.names = {} 50 | for key, value in self.classes.items(): 51 | self.names[value] = key 52 | 53 | def __getitem__(self, index): 54 | img_name = self.images[index] 55 | domain = self.domains[index] 56 | label = self.labels[index] 57 | img = self.transform(Image.open(img_name)) 58 | if self.split == 'train': 59 | pos_name = np.random.choice(self.refs[self.names[label]]) 60 | remain_classes = sorted(set(self.classes.keys()) - {self.names[label]}) 61 | neg_name = np.random.choice(self.refs[np.random.choice(remain_classes)]) 62 | pos = self.transform(Image.open(pos_name)) 63 | neg = self.transform(Image.open(neg_name)) 64 | return img, pos, neg, label 65 | else: 66 | return img, domain, label 67 | 68 | def __len__(self): 69 | return len(self.images) 70 | 71 | 72 | def compute_metric(vectors, domains, labels): 73 | acc = {} 74 | sketch_vectors, photo_vectors = vectors[domains == 1], vectors[domains == 0] 75 | sketch_labels, photo_labels = labels[domains == 1], labels[domains == 0] 76 | precs_100, precs_200, maps_200, maps_all = 0, 0, 0, 0 77 | for sketch_vector, sketch_label in zip(sketch_vectors, sketch_labels): 78 | sim = 
F.cosine_similarity(sketch_vector.unsqueeze(dim=0), photo_vectors).squeeze(dim=0) 79 | target = torch.zeros_like(sim, dtype=torch.bool) 80 | target[sketch_label == photo_labels] = True 81 | precs_100 += retrieval_precision(sim, target, top_k=100).item() 82 | precs_200 += retrieval_precision(sim, target, top_k=200).item() 83 | maps_200 += retrieval_average_precision(sim, target, top_k=200).item() 84 | maps_all += retrieval_average_precision(sim, target).item() 85 | 86 | prec_100 = precs_100 / sketch_vectors.shape[0] 87 | prec_200 = precs_200 / sketch_vectors.shape[0] 88 | map_200 = maps_200 / sketch_vectors.shape[0] 89 | map_all = maps_all / sketch_vectors.shape[0] 90 | 91 | acc['P@100'], acc['P@200'], acc['mAP@200'], acc['mAP@all'] = prec_100, prec_200, map_200, map_all 92 | # the mean value is chosen as the representative of precise 93 | acc['precise'] = (acc['P@100'] + acc['P@200'] + acc['mAP@200'] + acc['mAP@all']) / 4 94 | return acc 95 | 96 | 97 | def parse_args(): 98 | parser = argparse.ArgumentParser(description='Train/Test Model') 99 | # common args 100 | parser.add_argument('--data_root', default='/home/data', type=str, help='Datasets root path') 101 | parser.add_argument('--data_name', default='sketchy', type=str, choices=['sketchy', 'tuberlin'], 102 | help='Dataset name') 103 | parser.add_argument('--prompt_num', default=3, type=int, help='Number of prompt embedding') 104 | parser.add_argument('--save_root', default='result', type=str, help='Result saved root path') 105 | parser.add_argument('--mode', default='train', choices=['train', 'test'], type=str, help='Mode of the script') 106 | 107 | # train args 108 | parser.add_argument('--batch_size', default=64, type=int, help='Number of images in each mini-batch') 109 | parser.add_argument('--epochs', default=60, type=int, help='Number of epochs over the model to train') 110 | parser.add_argument('--triplet_margin', default=0.3, type=float, help='Margin of triplet loss') 111 | parser.add_argument('--encoder_lr', default=1e-4, type=float, help='Learning rate of encoder') 112 | parser.add_argument('--prompt_lr', default=1e-3, type=float, help='Learning rate of prompt embedding') 113 | parser.add_argument('--cls_weight', default=0.5, type=float, help='Weight of classification loss') 114 | parser.add_argument('--seed', default=-1, type=int, help='random seed (-1 for no manual seed)') 115 | 116 | # test args 117 | parser.add_argument('--query_name', default='/home/data/sketchy/val/sketch/cow/n01887787_591-14.jpg', type=str, 118 | help='Query image path') 119 | parser.add_argument('--retrieval_num', default=8, type=int, help='Number of retrieved images') 120 | 121 | args = parser.parse_args() 122 | if args.seed >= 0: 123 | random.seed(args.seed) 124 | np.random.seed(args.seed) 125 | torch.manual_seed(args.seed) 126 | torch.cuda.manual_seed_all(args.seed) 127 | cudnn.deterministic = True 128 | cudnn.benchmark = False 129 | 130 | if not os.path.exists(args.save_root): 131 | os.makedirs(args.save_root) 132 | return args --------------------------------------------------------------------------------
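As a quick sanity check of `compute_metric`, a standalone toy sketch (not part of the repo) can be run on random tensors; it only assumes `utils.py` is importable and follows the `DomainDataset` convention of domain `0` = photo and `1` = sketch:

```python
import torch

from utils import compute_metric

torch.manual_seed(0)
vectors = torch.randn(305, 8)                # 305 embeddings of dimension 8
domains = torch.tensor([1] * 5 + [0] * 300)  # 5 sketch queries followed by 300 gallery photos
labels = torch.tensor([0, 1, 2, 0, 1] + [0] * 100 + [1] * 100 + [2] * 100)

acc = compute_metric(vectors, domains, labels)
# prints P@100, P@200, mAP@200, mAP@all and their mean ('precise'), all in [0, 1]
print({k: round(v, 3) for k, v in acc.items()})
```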