├── .gitignore
├── LICENSE
├── README.md
├── aux
│   ├── README.md
│   ├── Semantic
│   │   ├── PART-SHREC14
│   │   │   └── w2v.npz
│   │   ├── SHREC13
│   │   │   └── w2v.npz
│   │   ├── SHREC14
│   │   │   └── w2v.npz
│   │   ├── Sketchy
│   │   │   └── word2vec-google-news.npy
│   │   ├── TU-Berlin
│   │   │   └── word2vec-google-news.npy
│   │   └── domainnet
│   │       └── word2vec-google-news.npy
│   └── data
│       ├── PART-SHREC14
│       │   ├── cad.hdf5
│       │   └── sk.hdf5
│       ├── SHREC13
│       │   ├── cad.hdf5
│       │   └── sk.hdf5
│       ├── SHREC14
│       │   ├── cad.hdf5
│       │   └── sk.hdf5
│       ├── Sketchy
│       │   ├── im.hdf5
│       │   └── sk.hdf5
│       ├── TU-Berlin
│       │   ├── im.hdf5
│       │   └── sk.hdf5
│       └── domainnet
│           ├── im.hdf5
│           └── sk.hdf5
├── data.py
├── fewshot.py
├── imgs
│   ├── problem.png
│   └── setup.png
├── metrics.py
├── models.py
├── retrieve-any.py
├── retrieve-many.py
├── retrieve.py
├── runs
│   ├── closed-eval-sb3dr.sh
│   ├── closed-eval-sbic.sh
│   ├── closed-eval-sbir.sh
│   ├── closed-train-sb3dr.sh
│   ├── closed-train-sbic.sh
│   ├── closed-train-sbir.sh
│   ├── open-eval.sh
│   └── open-train.sh
├── train.py
├── utils.py
└── validate.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

1 | # Python
2 | *.pyc
3 | *.pkl
4 | *.pth
5 | 
6 | # Mac
7 | .DS_Store
8 | ._.DS_Store
9 | 
10 | # Other
11 | *.jpg
12 | *.eps
13 | *.mp4
14 | *.mat
15 | *.tar
16 | *.json
17 | *.sublime*
18 | 

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

1 | MIT License
2 | 
3 | Copyright (c) 2020 William Thong
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

1 | # open-search
2 | 
3 | This is the source code for reproducing the experiments from the paper:
4 | 
5 | [Open Cross-Domain Visual Search](https://arxiv.org/abs/1911.08621)
6 | **William Thong, Pascal Mettes, Cees G.M. Snoek**
7 | Computer Vision and Image Understanding (CVIU), vol.200, 2020
8 | 
9 | **TL;DR** *We search for seen and unseen categories from any source domain to any target domain.
10 | To achieve this, we train domain-specific prototype learners with a normalized temperature-scaled cross entropy loss to map inputs to a common semantic space.*
11 | 
12 | ![alt text](imgs/problem.png "Method: common semantic space")
13 | 
14 | We validate the proposed approach on three well-established sketch-based tasks in a closed setting (a).
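For reference, each prototype learner is trained with a softmax over temperature-scaled cosine similarities between embedded inputs and fixed word2vec class prototypes (see `ProxyLoss` in [models.py](models.py)). A minimal NumPy sketch of this objective, for illustration only:

```
import numpy as np

def proxy_loss(x, y, proxies, temperature=1.0):
    # x: (n, d) embedded inputs; y: (n,) integer class labels
    # proxies: (c, d) word2vec class prototypes, kept fixed during training
    x = x / np.linalg.norm(x, axis=1, keepdims=True)
    p = proxies / np.linalg.norm(proxies, axis=1, keepdims=True)
    sim = x.dot(p.T) / temperature           # temperature-scaled cosine similarities
    sim -= sim.max(axis=1, keepdims=True)    # numerical stability
    probs = np.exp(sim) / np.exp(sim).sum(axis=1, keepdims=True)
    return -np.log(probs[np.arange(len(y)), y] + 1e-8).mean()
```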
15 | We propose three novel open cross-domain tasks to search for categories from and within any number of domains (b-d).
16 | 
17 | ![alt text](imgs/setup.png "Closed vs. open cross-domain search")
18 | 
19 | ## Closed cross-domain experiments
20 | 
21 | ### 1. Zero-shot sketch-based image retrieval
22 | 
23 | **Dataset**
24 | To download the TU-Berlin and Sketchy datasets, check this [repository](https://github.com/qliu24/SAKE).
25 | 
26 | **Training**
27 | Run [runs/closed-train-sbir.sh](runs/closed-train-sbir.sh). This will train one model for sketches and one model for images, for both datasets and for both normal zero-shot and generalized zero-shot settings.
28 | 
29 | **Evaluation**
30 | Run [runs/closed-eval-sbir.sh](runs/closed-eval-sbir.sh). This will perform sketch-based image retrieval in both zero-shot and generalized zero-shot settings and measure the mAP@all and prec@100 metrics.
31 | 
32 | ### 2. Few-shot sketch-based image classification
33 | 
34 | **Dataset**
35 | Similar to the previous section, check this [repository](https://github.com/qliu24/SAKE) for downloading the Sketchy dataset.
36 | 
37 | **Training**
38 | Run [runs/closed-train-sbic.sh](runs/closed-train-sbic.sh). This will train one model for sketches and one model for images.
39 | 
40 | **Evaluation**
41 | Run [runs/closed-eval-sbic.sh](runs/closed-eval-sbic.sh). This will perform few-shot sketch-based image classification and measure the accuracy over 500 different runs.
42 | 
43 | ### 3. Many-shot sketch-based 3D shape retrieval
44 | 
45 | **Dataset**
46 | To prepare the [SHREC13](http://orca.st.usm.edu/~bli/sharp/sharp/contest/2013/SBR/),
47 | [SHREC14](http://orca.st.usm.edu/~bli/sharp/sharp/contest/2014/SBR/) and Part-SHREC14 datasets, check this [repository](https://github.com/twuilliam/shrec-sketches-helpers).
48 | 
49 | **Training**
50 | Run [runs/closed-train-sb3dr.sh](runs/closed-train-sb3dr.sh). This will train one model for sketches and one model for 2D projections of 3D shapes, for all three datasets.
51 | 
52 | **Evaluation**
53 | Run [runs/closed-eval-sb3dr.sh](runs/closed-eval-sb3dr.sh). This will perform sketch-based 3D shape retrieval and measure the NN, FT, ST, E, DCG and mAP metrics.
54 | 
55 | ## Open cross-domain experiments
56 | 
57 | **Dataset**
58 | To prepare the [DomainNet](http://ai.bu.edu/M3SDA/) dataset, check this [repository](https://github.com/twuilliam/domainnet-helpers).
59 | 
60 | **Training**
61 | Run [runs/open-train.sh](runs/open-train.sh). This will train one model for every domain in DomainNet, for zero-shot and many-shot settings (totalling 12 models).
62 | 
63 | **Feature extraction**
64 | Run [runs/open-eval.sh](runs/open-eval.sh). This will extract the features for every domain in both zero-shot and many-shot settings.
65 | 
66 | ### 1. From *any* source to *any* target domains
67 | Run [runs/open-eval.sh](runs/open-eval.sh). This will produce 36 cross-domain retrieval experiments and measure the mAP@all for each setting.
68 | 
69 | ### 2. From *many* source to *any* target domains
70 | Run [runs/open-eval.sh](runs/open-eval.sh). This will combine several source domains to improve their initial mAP@all.
71 | 
72 | ### 3. From *any* source to *many* target domains
73 | Run [runs/open-eval.sh](runs/open-eval.sh). This will perform a search within multiple domains and measure the intent-aware mAP@100.
74 | 
75 | ## Requirements
76 | 
77 | The code was initially implemented with python 2.7 and pytorch 0.3.1. I'll migrate the source code to python 3.6 and pytorch 1.+ later.
In the meantime, here is the configuration I used: 78 | 79 | ``` 80 | matplotlib=2.2.2 81 | numpy=1.15.4 82 | pandas=0.24.2 83 | python=2.7.16 84 | pytorch=0.3.1=py27_cuda8.0.61_cudnn7.1.2_3 85 | faiss-gpu=1.5.3 86 | torchvision=0.2.0=py27hfb27419_1 87 | opencv=3.3.1 88 | pretrainedmodels=0.7.4 89 | ``` 90 | 91 | ## Citation 92 | 93 | If you find these scripts useful, please consider citing our paper: 94 | 95 | ``` 96 | @article{ 97 | Thong2020OpenSearch, 98 | title={Open Cross-Domain Visual Search}, 99 | author={Thong, William and Mettes, Pascal and Snoek, Cees G.M.}, 100 | journal={CVIU}, 101 | year={2020}, 102 | url={https://doi.org/10.1016/j.cviu.2020.103045} 103 | } 104 | ``` 105 | -------------------------------------------------------------------------------- /aux/README.md: -------------------------------------------------------------------------------- 1 | ## Folder with auxiliary data for all datasets 2 | 3 | [data](data) contains compressed hdf5 files with image file paths and meta-data. 4 | 5 | [Semantic](Semantic) contains the word2vec representations of class names. 6 | -------------------------------------------------------------------------------- /aux/Semantic/PART-SHREC14/w2v.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/Semantic/PART-SHREC14/w2v.npz -------------------------------------------------------------------------------- /aux/Semantic/SHREC13/w2v.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/Semantic/SHREC13/w2v.npz -------------------------------------------------------------------------------- /aux/Semantic/SHREC14/w2v.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/Semantic/SHREC14/w2v.npz -------------------------------------------------------------------------------- /aux/Semantic/Sketchy/word2vec-google-news.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/Semantic/Sketchy/word2vec-google-news.npy -------------------------------------------------------------------------------- /aux/Semantic/TU-Berlin/word2vec-google-news.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/Semantic/TU-Berlin/word2vec-google-news.npy -------------------------------------------------------------------------------- /aux/Semantic/domainnet/word2vec-google-news.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/Semantic/domainnet/word2vec-google-news.npy -------------------------------------------------------------------------------- /aux/data/PART-SHREC14/cad.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/PART-SHREC14/cad.hdf5 -------------------------------------------------------------------------------- 
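These auxiliary files are consumed by [data.py](data.py). A short sketch of how to inspect them (the paths are examples; `pd.read_hdf` and `get_proxies` in data.py do the real work):

```
import numpy as np
import pandas as pd

# each aux/data/<dataset>/<mode>.hdf5 stores a DataFrame indexed by file path,
# with a 'cat' (class) column and, depending on the dataset, 'split'/'domain'
df = pd.read_hdf('aux/data/Sketchy/im.hdf5')
print(df.shape, list(df.columns))

# each aux/Semantic file maps class names to word2vec vectors
w2v = np.load('aux/Semantic/Sketchy/word2vec-google-news.npy',
              allow_pickle=True).item()
proxies = np.stack([w2v[c] for c in df['cat'].unique()])
print(proxies.shape)
```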
/aux/data/PART-SHREC14/sk.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/PART-SHREC14/sk.hdf5 -------------------------------------------------------------------------------- /aux/data/SHREC13/cad.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/SHREC13/cad.hdf5 -------------------------------------------------------------------------------- /aux/data/SHREC13/sk.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/SHREC13/sk.hdf5 -------------------------------------------------------------------------------- /aux/data/SHREC14/cad.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/SHREC14/cad.hdf5 -------------------------------------------------------------------------------- /aux/data/SHREC14/sk.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/SHREC14/sk.hdf5 -------------------------------------------------------------------------------- /aux/data/Sketchy/im.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/Sketchy/im.hdf5 -------------------------------------------------------------------------------- /aux/data/Sketchy/sk.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/Sketchy/sk.hdf5 -------------------------------------------------------------------------------- /aux/data/TU-Berlin/im.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/TU-Berlin/im.hdf5 -------------------------------------------------------------------------------- /aux/data/TU-Berlin/sk.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/TU-Berlin/sk.hdf5 -------------------------------------------------------------------------------- /aux/data/domainnet/im.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/domainnet/im.hdf5 -------------------------------------------------------------------------------- /aux/data/domainnet/sk.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/aux/data/domainnet/sk.hdf5 -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import cv2 3 | import numpy as np 4 | import pandas as pd 5 | from torch.utils.data import Dataset 6 | from utils import zero_cnames 7 | 8 | 9 | def image_loader(path): 10 | return cv2.imread(path)[:, :, ::-1] 11 | 12 | 13 | def sketch_loader(path): 14 | return cv2.imread(path)[:, :, ::-1] 15 | 16 | 17 | def create_splits(path, overwrite=False, gzsl=False): 18 | '''Create Train and Test splits''' 19 | splits = {} 20 | for modality in ['im', 'sk']: 21 | splits[modality] = {} 22 | 23 | fname = os.path.join(path, modality + '.hdf5') 24 | df = pd.read_hdf(fname) 25 | 26 | # get zero-shot class names 27 | if overwrite: 28 | if 'Sketchy' in path: 29 | dataset = 'Sketchy' 30 | elif 'TU-Berlin' in path: 31 | dataset = 'TU-Berlin' 32 | cnames = zero_cnames(dataset) 33 | cond = df['cat'].isin(cnames) 34 | 35 | df.loc[~cond, 'split'] = 'train' 36 | df.loc[cond, 'split'] = 'test' 37 | 38 | if gzsl: 39 | np.random.seed(1234) 40 | fnames = df.loc[df['split'] == 'train'].index 41 | to_select = np.random.choice(fnames, 42 | size=int(len(fnames)*0.2), 43 | replace=False) 44 | cond = df.index.isin(to_select) 45 | df.loc[cond, 'split'] = 'test' 46 | 47 | df_train = df.loc[df['split'] == 'train'] 48 | df_train = df_train.assign(cat=df_train['cat'].astype('category')) 49 | 50 | df_test = df.loc[df['split'] == 'test'] 51 | df_test = df_test.assign(cat=df_test['cat'].astype('category')) 52 | 53 | df_gal = df.loc[df['split'] == 'test'] 54 | df_gal = df_gal.assign(cat=df_gal['cat'].astype('category')) 55 | 56 | else: 57 | df_train = df.loc[df['split'] == 'train'] 58 | df_train = df_train.assign(cat=df_train['cat'].astype('category')) 59 | 60 | df_val = df.loc[df['split'] == 'val'] 61 | df_val = df_val.assign(cat=df_val['cat'].astype('category')) 62 | 63 | df_test = df.loc[df['split'] == 'test'] 64 | df_test = df_test.assign(cat=df_test['cat'].astype('category')) 65 | 66 | df_gal = pd.concat([df_val, df_test]) 67 | df_gal = df_gal.assign(cat=df_gal['cat'].astype('category')) 68 | 69 | splits[modality]['train'] = df_train 70 | splits[modality]['test'] = df_test 71 | splits[modality]['gal'] = df_gal 72 | return splits 73 | 74 | 75 | def is_ext(fnames): 76 | return [True if 'ext' in os.path.basename(f) else False for f in fnames] 77 | 78 | 79 | def create_fewshot_splits(path, subsample=True): 80 | '''Create Train and Test splits 81 | Following Hu et al, CVPR 2018 82 | ''' 83 | test_classes = ['car_(sedan)', 'pear', 'deer', 'couch', 'duck', 84 | 'airplane', 'cat', 'mouse', 'seagull', 'knife'] 85 | 86 | splits = {} 87 | for modality in ['im', 'sk']: 88 | splits[modality] = {} 89 | 90 | fname = os.path.join(path, modality + '.hdf5') 91 | df = pd.read_hdf(fname) 92 | 93 | if subsample: 94 | # subsampling extended images to match Hu et al CVPR18 95 | np.random.seed(1234) 96 | 97 | # get how many to discard 98 | cond = is_ext(df.index) 99 | df['ext'] = cond 100 | vv, cc = np.unique(df.loc[cond, 'cat'], return_counts=True) 101 | n_select = np.asarray(np.round(cc / float(np.sum(cc)) * 4336), dtype=int) 102 | 103 | # collect fnames to discard 104 | to_remove = [] 105 | for v, n in zip(vv, n_select): 106 | idx = df[(df['ext'] == True) & (df['cat'] == v)].index 107 | to_remove.extend(np.random.choice(idx, size=n, replace=False)) 108 | 109 | # subsampled df 110 | df = df[~df.index.isin(to_remove)] 111 | 112 | cond = df['cat'].isin(test_classes) 113 | 114 | df_train = df.loc[~cond] 115 | df_train = df_train.assign(cat=df_train['cat'].astype('category')) 116 | 117 | df_test = df.loc[cond] 118 | df_test = 
df_test.assign(cat=df_test['cat'].astype('category')) 119 | 120 | splits[modality]['train'] = df_train 121 | splits[modality]['test'] = df_test 122 | return splits 123 | 124 | 125 | def create_shape_splits(path): 126 | '''Create Train and Test splits for 3D shapes''' 127 | splits = {} 128 | for modality in ['cad', 'sk']: 129 | splits[modality] = {} 130 | 131 | fname = os.path.join(path, modality + '.hdf5') 132 | df = pd.read_hdf(fname) 133 | 134 | if 'split' in df.columns: 135 | df_train = df.loc[df['split'] == 'train'] 136 | df_train = df_train.assign(cat=df_train['cat'].astype('category')) 137 | 138 | df_test = df.loc[df['split'] == 'test'] 139 | df_test = df_test.assign(cat=df_test['cat'].astype('category')) 140 | else: 141 | df_train = df.copy() 142 | df_train = df_train.assign(cat=df_train['cat'].astype('category')) 143 | 144 | df_test = df.copy() 145 | df_test = df_test.assign(cat=df_test['cat'].astype('category')) 146 | 147 | splits[modality]['train'] = df_train 148 | splits[modality]['test'] = df_test 149 | splits[modality]['gal'] = df_test 150 | 151 | if modality == 'cad': 152 | splits['im'] = {} 153 | splits['im']['train'] = df_train 154 | splits['im']['gal'] = df_test 155 | splits['im']['test'] = df_test 156 | return splits 157 | 158 | 159 | def create_multi_splits(path, domain, overwrite=False): 160 | '''Create Train and Test splits for DomainNet''' 161 | splits = {} 162 | for modality in ['im', 'sk']: 163 | splits[modality] = {} 164 | 165 | fname = os.path.join(path, modality + '.hdf5') 166 | df = pd.read_hdf(fname) 167 | 168 | if modality == 'im': 169 | cond = df['domain'] == domain 170 | df = df.loc[cond] 171 | 172 | if overwrite: 173 | dataset = 'domainnet' 174 | cnames = zero_cnames(dataset) 175 | cond = df['cat'].isin(cnames) 176 | 177 | df.loc[~cond, 'split'] = 'train' 178 | df.loc[cond, 'split'] = 'test' 179 | 180 | cond = df['split'] == 'train' 181 | 182 | df_train = df.loc[cond] 183 | df_train = df_train.assign(cat=df_train['cat'].astype('category')) 184 | 185 | df_test = df.loc[~cond] 186 | df_test = df_test.assign(cat=df_test['cat'].astype('category')) 187 | 188 | splits[modality]['train'] = df_train 189 | splits[modality]['test'] = df_test 190 | splits[modality]['gal'] = df_test 191 | return splits 192 | 193 | 194 | class DataLoader(Dataset): 195 | def __init__(self, split, transform, root='', mode='im'): 196 | self.split = split 197 | self.transform = transform 198 | self.root = root 199 | if mode == 'im': 200 | self.loader = image_loader 201 | elif mode == 'sk': 202 | self.loader = sketch_loader 203 | 204 | def __getitem__(self, index): 205 | """ Read img, transform img and return class label in long int """ 206 | # read img and apply transformations 207 | fname = self.split.iloc[index].name 208 | img = self.loader(os.path.join(self.root, fname)) 209 | img = self.transform(img) 210 | 211 | # get class label 212 | item = self.split['cat'].cat.codes.iloc[index].astype('int64') 213 | return img, item 214 | 215 | def __len__(self): 216 | return self.split.shape[0] 217 | 218 | 219 | def get_proxies(path_semantic, class_names): 220 | try: 221 | semantic = np.load(path_semantic, allow_pickle=True).item() 222 | except: 223 | if os.path.splitext(path_semantic)[-1] == '.npz': 224 | semantic = np.load(path_semantic)['wv'].item() 225 | elif os.path.splitext(path_semantic)[-1] == '.pkl': 226 | import pickle 227 | with open(path_semantic, 'rb') as f: 228 | semantic = pickle.load(f) 229 | else: 230 | semantic = np.load(path_semantic).reshape(-1)[0] 231 | proxies = 
np.stack([semantic[c] for c in class_names]) 232 | return np.float32(proxies) 233 | -------------------------------------------------------------------------------- /fewshot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import transforms 8 | from data import create_fewshot_splits 9 | from data import DataLoader, get_proxies 10 | from models import LinearProjection, ConvNet 11 | from models import ProxyNet, ProxyLoss 12 | from utils import get_semantic_fname, get_backbone 13 | from validate import extract_predict, retrieve 14 | 15 | 16 | # Training settings 17 | parser = argparse.ArgumentParser(description='PyTorch SBIR') 18 | parser.add_argument('--im-path', type=str, default='exp', metavar='ED', 19 | help='im model path') 20 | parser.add_argument('--sk-path', type=str, default='exp', metavar='ED', 21 | help='sk model path') 22 | parser.add_argument('--rewrite', action='store_true', default=False, 23 | help='Do not consider existing saved features') 24 | parser.add_argument('--mixing', action='store_true', default=False, 25 | help='Mix w2v with sk representations') 26 | parser.add_argument('--seed', type=int, default=1234, 27 | help='Seed for selecting the sketches') 28 | 29 | 30 | args = parser.parse_args() 31 | im_path = os.path.dirname(args.im_path) 32 | SEED = args.seed 33 | GROUPS = 500 34 | 35 | with open(os.path.join(im_path, 'config.json')) as f: 36 | tmp = json.load(f) 37 | 38 | tmp['im_model_path'] = args.im_path 39 | tmp['sk_model_path'] = args.sk_path 40 | tmp['rewrite'] = args.rewrite 41 | tmp['mixing'] = args.mixing 42 | args = type('parser', (object,), tmp) 43 | 44 | # get data splits 45 | df_dir = os.path.join('aux', 'data', args.dataset) 46 | splits = create_fewshot_splits(df_dir) 47 | 48 | 49 | def main(): 50 | # data normalization 51 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 52 | std=[0.229, 0.224, 0.225]) 53 | 54 | # data loaders 55 | kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {} 56 | 57 | test_transforms = transforms.Compose([ 58 | transforms.ToPILImage(), 59 | transforms.Resize((224, 224)), 60 | transforms.ToTensor(), 61 | normalize]) 62 | 63 | feats = {} 64 | labels = {} 65 | 66 | for domain in ['im', 'sk']: 67 | key = '_'.join([domain, 'model_path']) 68 | dirname = os.path.dirname(args.__dict__[key]) 69 | fpath = os.path.join(dirname, 'features.npz') 70 | 71 | results_path = os.path.join(dirname, 'results.txt') 72 | 73 | if os.path.isfile(fpath) and args.rewrite is False: 74 | data = np.load(fpath) 75 | feats[domain] = data['features'] 76 | labels[domain] = data['labels'] 77 | 78 | txt = ('Domain (%s): Acc %.2f' % (domain, data['acc'] * 100.)) 79 | print(txt) 80 | write_logs(txt, results_path) 81 | 82 | df_gal = splits[domain]['test'] 83 | fsem = get_semantic_fname(args.word) 84 | path_semantic = os.path.join('aux', 'Semantic', args.dataset, fsem) 85 | test_proxies = get_proxies( 86 | path_semantic, df_gal['cat'].cat.categories) 87 | else: 88 | df_gal = splits[domain]['test'] 89 | 90 | test_loader = torch.utils.data.DataLoader( 91 | DataLoader(df_gal, test_transforms, 92 | root=args.data_dir, mode=domain), 93 | batch_size=args.batch_size * 1, shuffle=False, **kwargs) 94 | 95 | # instanciate the models 96 | output_shape, backbone = get_backbone(args) 97 | embed = LinearProjection(output_shape, args.dim_embed) 98 | model = ConvNet(backbone, embed) 99 | 
100 | # instanciate the proxies 101 | fsem = get_semantic_fname(args.word) 102 | path_semantic = os.path.join('aux', 'Semantic', args.dataset, fsem) 103 | test_proxies = get_proxies( 104 | path_semantic, df_gal['cat'].cat.categories) 105 | 106 | test_proxynet = ProxyNet(args.n_classes_gal, args.dim_embed, 107 | proxies=torch.from_numpy(test_proxies)) 108 | 109 | # criterion 110 | criterion = ProxyLoss(args.temperature) 111 | 112 | if args.multi_gpu: 113 | model = nn.DataParallel(model) 114 | 115 | # loading 116 | checkpoint = torch.load(args.__dict__[key]) 117 | model.load_state_dict(checkpoint['state_dict']) 118 | txt = ("\n=> loaded checkpoint '{}' (epoch {})" 119 | .format(args.__dict__[key], checkpoint['epoch'])) 120 | print(txt) 121 | 122 | if args.cuda: 123 | backbone.cuda() 124 | embed.cuda() 125 | model.cuda() 126 | test_proxynet.cuda() 127 | 128 | txt = 'Extracting testing set (%s)...' % (domain) 129 | print(txt) 130 | x, y, acc = extract_predict( 131 | test_loader, model, 132 | test_proxynet.proxies.weight, criterion) 133 | 134 | feats[domain] = x 135 | labels[domain] = y 136 | 137 | np.savez( 138 | fpath, 139 | features=feats[domain], labels=labels[domain], acc=acc) 140 | 141 | txt = ('Domain (%s): Acc %.2f' % (domain, acc * 100.)) 142 | print(txt) 143 | 144 | print('\nFew-Shot') 145 | fs(feats, labels, test_proxies) 146 | 147 | 148 | def fs(feats, labels, proxies): 149 | 150 | feats['sk'] = L2norm(feats['sk']) 151 | feats['im'] = L2norm(feats['im']) 152 | proxies = L2norm(proxies) 153 | 154 | # word vectors 155 | acc = classify(feats['im'], labels['im'], proxies) 156 | print('word vectors:\t\t%.2f' % (acc * 100)) 157 | 158 | # few-shot with sketches 159 | info = do_sk_fewshot(feats['sk'], labels['sk'], 160 | feats['im'], labels['im'], 161 | k=1, v=True) 162 | to_save(info, 'best_worst.npy') 163 | do_sk_fewshot(feats['sk'], labels['sk'], feats['im'], labels['im'], k=5) 164 | 165 | if args.mixing: 166 | do_sk_fewshot_mixer(feats['sk'], labels['sk'], 167 | feats['im'], labels['im'], 168 | proxies, k=1) 169 | 170 | do_sk_fewshot_mixer(feats['sk'], labels['sk'], 171 | feats['im'], labels['im'], 172 | proxies, k=5) 173 | 174 | # few-shot with images 175 | do_im_fewshot(feats['im'], labels['im']) 176 | do_im_fewshot(feats['im'], labels['im'], k=5) 177 | 178 | if args.mixing: 179 | info = do_im_fewshot_mixer( 180 | feats['im'], labels['im'], proxies, k=1, v=True) 181 | to_save(info, 'im_1_mixture.npy') 182 | info = do_im_fewshot_mixer( 183 | feats['im'], labels['im'], proxies, k=5, v=True) 184 | to_save(info, 'im_5_mixture.npy') 185 | 186 | 187 | def do_sk_fewshot(query_x, query_y, gallery_x, gallery_y, 188 | k=1, v=False): 189 | np.random.seed(SEED) 190 | 191 | best_acc = 0. 192 | worst_acc = 1. 
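    # best_acc/worst_acc track the strongest and weakest of the GROUPS sampled
    # support sets; their sketch indices are saved later for qualitative inspection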
193 | 194 | acc = [] 195 | for i in range(GROUPS): 196 | new_p, idx = shot_selector(query_x, query_y, k=k) 197 | acc.append(classify(gallery_x, gallery_y, new_p)) 198 | 199 | if acc[-1] > best_acc: 200 | best_acc = acc[-1] 201 | best_idx = list(idx) 202 | elif acc[-1] < worst_acc: 203 | worst_acc = acc[-1] 204 | worst_idx = list(idx) 205 | 206 | print('%d-shot (sketches):\t%.2f (+/- %.2f)' % 207 | (k, np.mean(acc) * 100, np.std(acc) * 100)) 208 | 209 | if v: 210 | return parse_samples((best_acc, best_idx), (worst_acc, worst_idx)) 211 | 212 | 213 | def parse_samples(best, worst): 214 | info = {} 215 | 216 | info['best'] = {} 217 | info['best']['acc'] = best[0] 218 | info['best']['idx'] = best[1] 219 | 220 | info['worst'] = {} 221 | info['worst']['acc'] = worst[0] 222 | info['worst']['idx'] = worst[1] 223 | 224 | return info 225 | 226 | 227 | def to_save(info, fname): 228 | dirname = os.path.dirname(args.__dict__['sk_model_path']) 229 | fpath = os.path.join(dirname, fname) 230 | np.save(fpath, info) 231 | 232 | 233 | def do_sk_fewshot_mixer(query_x, query_y, gallery_x, gallery_y, proxies, k=1): 234 | np.random.seed(SEED) 235 | acc = [] 236 | alpha = 0.7 237 | for i in range(GROUPS): 238 | new_p, _ = shot_selector_mixer( 239 | query_x, query_y, proxies, k=k, alpha=alpha) 240 | acc.append(classify(gallery_x, gallery_y, new_p)) 241 | print('(w. refinement) %d-shot (sketches):\t%.2f (+/- %.2f)' % 242 | (k, np.mean(acc) * 100, np.std(acc) * 100)) 243 | 244 | 245 | def do_im_fewshot(gallery_x, gallery_y, k=1): 246 | np.random.seed(SEED) 247 | acc = [] 248 | for i in range(GROUPS): 249 | new_p, idx = shot_selector(gallery_x, gallery_y, k=k) 250 | cond = np.isin(np.arange(len(gallery_y)), idx) 251 | acc.append(classify(gallery_x[~cond], gallery_y[~cond], new_p)) 252 | print('%d-shot (images):\t%.2f (+/- %.2f)' % 253 | (k, np.mean(acc) * 100, np.std(acc) * 100)) 254 | 255 | 256 | def do_im_fewshot_mixer(gallery_x, gallery_y, proxies, k=1, v=False): 257 | np.random.seed(SEED) 258 | acc = [] 259 | alpha = 0.7 260 | for i in range(GROUPS): 261 | new_p, idx = shot_selector_mixer( 262 | gallery_x, gallery_y, proxies, k=k, alpha=alpha) 263 | cond = np.isin(np.arange(len(gallery_y)), idx) 264 | acc.append(classify(gallery_x[~cond], gallery_y[~cond], new_p)) 265 | print('(w. 
refinement) %d-shot (images):\t%.2f (+/- %.2f)' % 266 | (k, np.mean(acc) * 100, np.std(acc) * 100)) 267 | 268 | 269 | def classify(feats, labels, proxies): 270 | idx = retrieve(feats, proxies) 271 | acc = np.mean(idx[:, 0] == labels) 272 | return acc 273 | 274 | 275 | def shot_selector(feats, labels, k=1): 276 | vv = np.unique(labels) 277 | proxies = [] 278 | all_idx = [] 279 | for v in vv: 280 | to_select = np.argwhere(labels == v).squeeze() 281 | idx = np.random.choice(to_select, size=k, replace=False) 282 | proxies.append(np.mean(feats[idx, :], axis=0)) 283 | all_idx.extend(idx) 284 | return np.asarray(proxies), np.asarray(all_idx) 285 | 286 | 287 | def shot_selector_mixer(feats, labels, proxies, k=1, alpha=0.5): 288 | vv = np.unique(labels) 289 | new_proxies = [] 290 | all_idx = [] 291 | for v in vv: 292 | to_select = np.argwhere(labels == v).squeeze() 293 | idx = np.random.choice(to_select, size=k, replace=False) 294 | if alpha == 0: 295 | new_proxies.append(np.mean(feats[idx, :], axis=0)) 296 | elif alpha == 1: 297 | new_proxies.append(np.mean(proxies[v, :][None, :], axis=0)) 298 | else: 299 | if k == 1: 300 | second = feats[idx, :] 301 | else: 302 | second = L2norm(np.mean(feats[idx, :], axis=0)[None, :]) 303 | tmp = np.concatenate((alpha * proxies[v, :][None, :], 304 | (1 - alpha) * second)) 305 | new_proxies.append(np.mean(tmp, axis=0)) 306 | 307 | all_idx.extend(idx) 308 | return np.asarray(new_proxies), np.asarray(all_idx) 309 | 310 | 311 | def L2norm(x): 312 | return x / np.linalg.norm(x, axis=1)[:, None] 313 | 314 | 315 | def write_logs(txt, logpath): 316 | with open(logpath, 'a') as f: 317 | f.write('\n') 318 | f.write(txt) 319 | 320 | 321 | if __name__ == '__main__': 322 | main() 323 | -------------------------------------------------------------------------------- /imgs/problem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/imgs/problem.png -------------------------------------------------------------------------------- /imgs/setup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twuilliam/open-search/5f74e3de5552a185e5d13d706bb3a9322606e704/imgs/setup.png -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | """Information Retrieval metrics 2 | (from https://gist.github.com/bwhite/3726239) 3 | 4 | Useful Resources: 5 | http://www.cs.utexas.edu/~mooney/ir-course/slides/Evaluation.ppt 6 | http://www.nii.ac.jp/TechReports/05-014E.pdf 7 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 8 | http://hal.archives-ouvertes.fr/docs/00/72/67/60/PDF/07-busa-fekete.pdf 9 | Learning to Rank for Information Retrieval (Tie-Yan Liu) 10 | """ 11 | import numpy as np 12 | 13 | 14 | def mean_reciprocal_rank(rs): 15 | """Score is reciprocal of the rank of the first relevant item 16 | 17 | First element is 'rank 1'. Relevance is binary (nonzero is relevant). 
18 | 19 | Example from http://en.wikipedia.org/wiki/Mean_reciprocal_rank 20 | >>> rs = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] 21 | >>> mean_reciprocal_rank(rs) 22 | 0.61111111111111105 23 | >>> rs = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]]) 24 | >>> mean_reciprocal_rank(rs) 25 | 0.5 26 | >>> rs = [[0, 0, 0, 1], [1, 0, 0], [1, 0, 0]] 27 | >>> mean_reciprocal_rank(rs) 28 | 0.75 29 | 30 | Args: 31 | rs: Iterator of relevance scores (list or numpy) in rank order 32 | (first element is the first item) 33 | 34 | Returns: 35 | Mean reciprocal rank 36 | """ 37 | rs = (np.asarray(r).nonzero()[0] for r in rs) 38 | return np.mean([1. / (r[0] + 1) if r.size else 0. for r in rs]) 39 | 40 | 41 | def r_precision(r): 42 | """Score is precision after all relevant documents have been retrieved 43 | 44 | Relevance is binary (nonzero is relevant). 45 | 46 | >>> r = [0, 0, 1] 47 | >>> r_precision(r) 48 | 0.33333333333333331 49 | >>> r = [0, 1, 0] 50 | >>> r_precision(r) 51 | 0.5 52 | >>> r = [1, 0, 0] 53 | >>> r_precision(r) 54 | 1.0 55 | 56 | Args: 57 | r: Relevance scores (list or numpy) in rank order 58 | (first element is the first item) 59 | 60 | Returns: 61 | R Precision 62 | """ 63 | r = np.asarray(r) != 0 64 | z = r.nonzero()[0] 65 | if not z.size: 66 | return 0. 67 | return np.mean(r[:z[-1] + 1]) 68 | 69 | 70 | def precision_at_k(r, k): 71 | """Score is precision @ k 72 | 73 | Relevance is binary (nonzero is relevant). 74 | 75 | >>> r = [0, 0, 1] 76 | >>> precision_at_k(r, 1) 77 | 0.0 78 | >>> precision_at_k(r, 2) 79 | 0.0 80 | >>> precision_at_k(r, 3) 81 | 0.33333333333333331 82 | >>> precision_at_k(r, 4) 83 | Traceback (most recent call last): 84 | File "", line 1, in ? 85 | ValueError: Relevance score length < k 86 | 87 | 88 | Args: 89 | r: Relevance scores (list or numpy) in rank order 90 | (first element is the first item) 91 | 92 | Returns: 93 | Precision @ k 94 | 95 | Raises: 96 | ValueError: len(r) must be >= k 97 | """ 98 | assert k >= 1 99 | r = np.asarray(r)[:k] != 0 100 | if r.size != k: 101 | raise ValueError('Relevance score length < k') 102 | return np.mean(r) 103 | 104 | 105 | def average_precision(r): 106 | """Score is average precision (area under PR curve) 107 | 108 | Relevance is binary (nonzero is relevant). 109 | 110 | >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1] 111 | >>> delta_r = 1. / sum(r) 112 | >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y]) 113 | 0.7833333333333333 114 | >>> average_precision(r) 115 | 0.78333333333333333 116 | 117 | Args: 118 | r: Relevance scores (list or numpy) in rank order 119 | (first element is the first item) 120 | 121 | Returns: 122 | Average precision 123 | """ 124 | r = np.asarray(r) != 0 125 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]] 126 | if not out: 127 | return 0. 128 | return np.mean(out) 129 | 130 | 131 | def mean_average_precision(rs): 132 | """Score is mean average precision 133 | 134 | Relevance is binary (nonzero is relevant). 
135 | 136 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1]] 137 | >>> mean_average_precision(rs) 138 | 0.78333333333333333 139 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1], [0]] 140 | >>> mean_average_precision(rs) 141 | 0.39166666666666666 142 | 143 | Args: 144 | rs: Iterator of relevance scores (list or numpy) in rank order 145 | (first element is the first item) 146 | 147 | Returns: 148 | Mean average precision 149 | """ 150 | return np.mean([average_precision(r) for r in rs]) 151 | 152 | 153 | def dcg_at_k(r, k, method=0): 154 | """Score is discounted cumulative gain (dcg) 155 | 156 | Relevance is positive real values. Can use binary 157 | as the previous methods. 158 | 159 | Example from 160 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 161 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0] 162 | >>> dcg_at_k(r, 1) 163 | 3.0 164 | >>> dcg_at_k(r, 1, method=1) 165 | 3.0 166 | >>> dcg_at_k(r, 2) 167 | 5.0 168 | >>> dcg_at_k(r, 2, method=1) 169 | 4.2618595071429155 170 | >>> dcg_at_k(r, 10) 171 | 9.6051177391888114 172 | >>> dcg_at_k(r, 11) 173 | 9.6051177391888114 174 | 175 | Args: 176 | r: Relevance scores (list or numpy) in rank order 177 | (first element is the first item) 178 | k: Number of results to consider 179 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...] 180 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...] 181 | 182 | Returns: 183 | Discounted cumulative gain 184 | """ 185 | r = np.asfarray(r)[:k] 186 | if r.size: 187 | if method == 0: 188 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 189 | elif method == 1: 190 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 191 | else: 192 | raise ValueError('method must be 0 or 1.') 193 | return 0. 194 | 195 | 196 | def ndcg_at_k(r, k, method=0): 197 | """Score is normalized discounted cumulative gain (ndcg) 198 | 199 | Relevance is positive real values. Can use binary 200 | as the previous methods. 201 | 202 | Example from 203 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 204 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0] 205 | >>> ndcg_at_k(r, 1) 206 | 1.0 207 | >>> r = [2, 1, 2, 0] 208 | >>> ndcg_at_k(r, 4) 209 | 0.9203032077642922 210 | >>> ndcg_at_k(r, 4, method=1) 211 | 0.96519546960144276 212 | >>> ndcg_at_k([0], 1) 213 | 0.0 214 | >>> ndcg_at_k([1], 2) 215 | 1.0 216 | 217 | Args: 218 | r: Relevance scores (list or numpy) in rank order 219 | (first element is the first item) 220 | k: Number of results to consider 221 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...] 222 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...] 223 | 224 | Returns: 225 | Normalized discounted cumulative gain 226 | """ 227 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method) 228 | if not dcg_max: 229 | return 0. 
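    # normalize by the ideal DCG, i.e. the DCG of the relevance scores sorted
    # in decreasing order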
230 | return dcg_at_k(r, k, method) / dcg_max 231 | 232 | 233 | if __name__ == "__main__": 234 | import doctest 235 | doctest.testmod() 236 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | from torchvision import models 6 | from utils import cosine_similarity 7 | 8 | 9 | class VGG16(nn.Module): 10 | def __init__(self, pretrained=True): 11 | super(VGG16, self).__init__() 12 | model = models.vgg16(pretrained=pretrained) 13 | self.features = model.features 14 | layers = list(model.classifier.children())[:-1] 15 | self.classifier = nn.Sequential(*layers) 16 | 17 | def forward(self, x): 18 | # from 224x224 to 4096 19 | x = self.features(x) 20 | x = self.classifier(x.view(x.size(0), -1)) 21 | return x 22 | 23 | 24 | class VGG19(nn.Module): 25 | def __init__(self, pretrained=True): 26 | super(VGG19, self).__init__() 27 | model = models.vgg19(pretrained=pretrained) 28 | self.features = model.features 29 | layers = list(model.classifier.children())[:-1] 30 | self.classifier = nn.Sequential(*layers) 31 | 32 | def forward(self, x): 33 | # from 224x224 to 4096 34 | x = self.features(x) 35 | x = self.classifier(x.view(x.size(0), -1)) 36 | return x 37 | 38 | 39 | class ResNet50(nn.Module): 40 | def __init__(self, pretrained=True): 41 | super(ResNet50, self).__init__() 42 | model = models.resnet50(pretrained=pretrained) 43 | layers = list(model.children())[:-1] 44 | self.model = nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | # from 224x224 to 2048 48 | x = self.model(x) 49 | return x.view(x.size(0), -1) 50 | 51 | def logits(self, x): 52 | return self.last_layer(x) 53 | 54 | 55 | class SEResNet50(nn.Module): 56 | def __init__(self, pretrained=True): 57 | super(SEResNet50, self).__init__() 58 | import pretrainedmodels 59 | if pretrained: 60 | model = pretrainedmodels.se_resnet50() 61 | else: 62 | model = pretrainedmodels.se_resnet50(pretrained=None) 63 | layers = list(model.children())[:-1] 64 | self.model = nn.Sequential(*layers) 65 | 66 | def forward(self, x): 67 | # from 224x224 to 2048 68 | x = self.model(x) 69 | return x.view(x.size(0), -1) 70 | 71 | 72 | class LinearProjection(nn.Module): 73 | '''Linear projection''' 74 | def __init__(self, n_in, n_out): 75 | super(LinearProjection, self).__init__() 76 | self.fc_embed = nn.Linear(n_in, n_out, bias=True) 77 | self.bn1d = nn.BatchNorm1d(n_out) 78 | self._init_params() 79 | 80 | def forward(self, x): 81 | x = self.fc_embed(x) 82 | x = self.bn1d(x) 83 | return x 84 | 85 | def _init_params(self): 86 | nn.init.xavier_normal(self.fc_embed.weight) 87 | nn.init.constant(self.fc_embed.bias, 0) 88 | nn.init.constant(self.bn1d.weight, 1) 89 | nn.init.constant(self.bn1d.bias, 0) 90 | 91 | 92 | class ConvNet(nn.Module): 93 | def __init__(self, backbone, embedding): 94 | super(ConvNet, self).__init__() 95 | self.backbone = backbone 96 | self.embedding = embedding 97 | 98 | def forward(self, x): 99 | x = self.backbone(x) 100 | x = self.embedding(x) 101 | return x 102 | 103 | 104 | class ProxyNet(nn.Module): 105 | """ProxyNet""" 106 | def __init__(self, n_classes, dim, 107 | proxies=None, L2=False): 108 | super(ProxyNet, self).__init__() 109 | self.n_classes = n_classes 110 | self.dim = dim 111 | 112 | self.proxies = nn.Embedding(n_classes, dim, 113 | scale_grad_by_freq=False) 114 | 115 | if proxies is None: 116 | 
self.proxies.weight = nn.Parameter( 117 | torch.randn(self.n_classes, self.dim), 118 | requires_grad=True) 119 | else: 120 | self.proxies.weight = nn.Parameter(proxies, requires_grad=False) 121 | 122 | if L2: 123 | self.normalize_proxies() 124 | 125 | def normalize_proxies(self): 126 | norm = self.proxies.weight.data.norm(p=2, dim=1)[:, None] 127 | self.proxies.weight.data = self.proxies.weight.data / norm 128 | 129 | def forward(self, y_true): 130 | proxies_y_true = self.proxies(Variable(y_true)) 131 | return proxies_y_true 132 | 133 | 134 | class ProxyLoss(nn.Module): 135 | def __init__(self, temperature=1.): 136 | super(ProxyLoss, self).__init__() 137 | 138 | self.temperature = temperature 139 | 140 | def forward(self, x, y, proxies): 141 | """Proxy loss 142 | 143 | Arguments: 144 | x (Tensor): batch of features 145 | y (LongTensor): corresponding instance 146 | """ 147 | loss = self.softmax_embedding_loss(x, y, proxies) 148 | 149 | preds = self.predict(x, proxies) 150 | 151 | acc = (y == preds).type(torch.FloatTensor).mean() 152 | 153 | return loss.mean(), acc 154 | 155 | def softmax_embedding_loss(self, x, y, proxies): 156 | idx = torch.from_numpy(np.arange(len(x), dtype=np.int)).cuda() 157 | diff_iZ = cosine_similarity(x, proxies) 158 | 159 | numerator_ip = torch.exp(diff_iZ[idx, y] / self.temperature) 160 | denominator_ip = torch.exp(diff_iZ / self.temperature).sum(1) + 1e-8 161 | return - torch.log(numerator_ip / denominator_ip) 162 | 163 | def classify(self, x, proxies): 164 | idx = torch.from_numpy(np.arange(len(x), dtype=np.int)).cuda() 165 | diff_iZ = cosine_similarity(x, proxies) 166 | 167 | numerator_ip = torch.exp(diff_iZ[idx, :] / self.temperature) 168 | denominator_ip = torch.exp(diff_iZ / self.temperature).sum(1) + 1e-8 169 | 170 | probs = numerator_ip / denominator_ip[:, None] 171 | return probs 172 | 173 | def predict(self, x, proxies): 174 | probs = self.classify(x, proxies) 175 | return probs.max(1)[1].data 176 | -------------------------------------------------------------------------------- /retrieve-any.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import itertools 4 | import json 5 | import numpy as np 6 | import pandas as pd 7 | import matplotlib 8 | matplotlib.use('agg') 9 | import matplotlib.pylab as plt 10 | from validate import L2norm 11 | from validate import retrieve, KNN, score 12 | from utils import heatmap, annotate_heatmap 13 | 14 | 15 | # Training settings 16 | parser = argparse.ArgumentParser(description='PyTorch SBIR') 17 | parser.add_argument('--dir-path', type=str, default='exp', metavar='ED', 18 | help='directory with domainnet models') 19 | 20 | args = parser.parse_args() 21 | 22 | GROUPS = 500 23 | SEED = 1234 24 | 25 | 26 | def get_config(): 27 | # iterate over folders 28 | configs = {} 29 | for path in os.listdir(args.dir_path): 30 | fname = os.path.join(args.dir_path, path, 'config.json') 31 | if os.path.isfile(fname): 32 | 33 | with open(fname) as f: 34 | tmp = json.load(f) 35 | 36 | if tmp['mode'] == 'im': 37 | configs[tmp['domain']] = tmp 38 | configs[tmp['domain']]['working_path'] = os.path.join( 39 | args.dir_path, path) 40 | else: 41 | configs['quickdraw'] = tmp 42 | configs['quickdraw']['working_path'] = os.path.join( 43 | args.dir_path, path) 44 | return configs 45 | 46 | 47 | def any2any_retrieval(configs): 48 | keys = configs.keys() 49 | keys.sort() 50 | 51 | res = {} 52 | for k in keys: 53 | res[k] = {} 54 | for j in keys: 55 | res[k][j] = 0. 
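    # res[source][target] will hold the mAP@all of querying the target domain
    # with source-domain features; 0. marks a pair that has not been evaluated yet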
56 | 57 | for (source, target) in itertools.combinations(keys, 2): 58 | feats = {} 59 | labels = {} 60 | 61 | for domain in [source, target]: 62 | dirname = configs[domain]['working_path'] 63 | fpath = os.path.join(dirname, 'features.npz') 64 | 65 | data = np.load(fpath) 66 | 67 | feats[domain] = {} 68 | labels[domain] = {} 69 | 70 | feats[domain] = data['features'] 71 | labels[domain] = data['labels'] 72 | 73 | if res[source][source] == 0: 74 | print('\nRetrieval from %s to %s' % (source, source)) 75 | tmp = cross_domain_retrieval( 76 | feats[source], labels[source], 77 | feats[source], labels[source], 78 | zeroshot=configs[source]['overwrite']) 79 | res[source][source] = tmp 80 | 81 | if res[target][target] == 0: 82 | print('\nRetrieval from %s to %s' % (target, target)) 83 | tmp = cross_domain_retrieval( 84 | feats[target], labels[target], 85 | feats[target], labels[target], 86 | zeroshot=configs[source]['overwrite']) 87 | res[target][target] = tmp 88 | 89 | print('\nRetrieval from %s to %s' % (source, target)) 90 | tmp = cross_domain_retrieval( 91 | feats[source], labels[source], 92 | feats[target], labels[target], 93 | zeroshot=configs[source]['overwrite']) 94 | res[source][target] = tmp 95 | 96 | print('\nRetrieval from %s to %s' % (target, source)) 97 | tmp = cross_domain_retrieval( 98 | feats[target], labels[target], 99 | feats[source], labels[source], 100 | zeroshot=configs[source]['overwrite']) 101 | res[target][source] = tmp 102 | 103 | # col: source, row: target 104 | df = pd.DataFrame(res) 105 | df.to_csv(os.path.join(args.dir_path, 'res.csv')) 106 | 107 | plot_heatmap(df, os.path.join(args.dir_path, 'res.pdf')) 108 | 109 | 110 | def cross_domain_retrieval(x_src, y_src, x_tgt, y_tgt, zeroshot=False): 111 | mAP, prec = evaluate(x_tgt, y_tgt, x_src, y_src) 112 | txt = ('mAP@all: %.04f Prec@100: %.04f\t' % (mAP, prec)) 113 | print(txt) 114 | 115 | # perform refinement 116 | g_src_x = KNN(x_src, x_tgt, K=1, mode='ones') 117 | 118 | if zeroshot: 119 | alpha = 0.7 120 | else: 121 | alpha = 0.4 122 | new_src_x = slerp(alpha, L2norm(x_src), L2norm(g_src_x)) 123 | mAP, prec = evaluate(x_tgt, y_tgt, new_src_x, y_src) 124 | txt = ('(w. refinement) mAP@all: %.04f Prec@100: %.04f\t' % (mAP, prec)) 125 | print(txt) 126 | 127 | return mAP 128 | 129 | 130 | def evaluate(im_x, im_y, sk_x, sk_y, return_idx=False): 131 | idx = retrieve(sk_x, im_x) 132 | if np.array_equal(sk_x, L2norm(im_x)) or np.array_equal(sk_x, im_x): 133 | idx = idx[:, 1:] 134 | prec, mAP = score(sk_y, im_y, idx) 135 | if return_idx: 136 | return mAP, prec, idx 137 | else: 138 | return mAP, prec 139 | 140 | 141 | def slerp(val, low, high): 142 | """Spherical interpolation. 
val has a range of 0 to 1.""" 143 | if val <= 0: 144 | return low 145 | elif val >= 1: 146 | return high 147 | elif np.allclose(low, high): 148 | return low 149 | omega = np.arccos(np.einsum('ij, ij->i', low, high)) 150 | so = np.sin(omega) 151 | return (np.sin((1.0-val)*omega) / so)[:, None] * low + (np.sin(val*omega)/so)[:, None] * high 152 | 153 | 154 | def plot_heatmap(df, path, 155 | vmin=0, vmax=1, nticks=11, 156 | digits="{x:.3f}", cmap="viridis", 157 | metric="mAP@all"): 158 | fig, ax = plt.subplots() 159 | 160 | arr = [0, 1, 2, 5, 4, 3] 161 | val = df.values[arr, :][:, arr] 162 | leg = ['clipart', 'infograph', 'painting', 'pencil', 'photo', 'sketch'] 163 | 164 | im, cbar = heatmap(val, leg, leg, ax=ax, 165 | cmap=cmap, cbarlabel=metric, 166 | vmax=0.7, 167 | cbar_kw={'boundaries': np.linspace(vmin, vmax, nticks)}) 168 | texts = annotate_heatmap(im, valfmt=digits, textcolors=["white", "black"]) 169 | 170 | fig.savefig(path, bbox_inches='tight', pad_inches=0, dpi=300) 171 | 172 | 173 | if __name__ == '__main__': 174 | configs = get_config() 175 | fname = os.path.join(args.dir_path, 'res.csv') 176 | if os.path.isfile(fname): 177 | df = pd.read_csv(fname, index_col=0) 178 | plot_heatmap(df, os.path.join(args.dir_path, 'res.pdf'), 179 | vmax=0.7, nticks=8) 180 | else: 181 | any2any_retrieval(configs) 182 | -------------------------------------------------------------------------------- /retrieve-many.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import faiss 5 | import numpy as np 6 | import pandas as pd 7 | from data import create_multi_splits 8 | from validate import L2norm 9 | from validate import retrieve, KNN, score 10 | 11 | 12 | # Training settings 13 | parser = argparse.ArgumentParser(description='PyTorch SBIR') 14 | parser.add_argument('--dir-path', type=str, default='exp', metavar='ED', 15 | help='directory with domainnet models') 16 | parser.add_argument('--new-data-path', type=str, default='', metavar='ED', 17 | help='overwrite data path') 18 | parser.add_argument('--eval', type=str, required=True, metavar='ED', 19 | help='many2any|any2many') 20 | args = parser.parse_args() 21 | 22 | GROUPS = 500 23 | SEED = 1234 24 | 25 | 26 | def get_config(): 27 | # iterate over folders in the directory 28 | configs = {} 29 | for path in os.listdir(args.dir_path): 30 | fname = os.path.join(args.dir_path, path, 'config.json') 31 | if os.path.isfile(fname): 32 | 33 | with open(fname) as f: 34 | tmp = json.load(f) 35 | 36 | if tmp['mode'] == 'im': 37 | configs[tmp['domain']] = tmp 38 | configs[tmp['domain']]['working_path'] = os.path.join( 39 | args.dir_path, path) 40 | 41 | if not args.new_data_path == '': 42 | configs[tmp['domain']]['data_dir'] = args.new_data_path 43 | else: 44 | configs['quickdraw'] = tmp 45 | configs['quickdraw']['working_path'] = os.path.join( 46 | args.dir_path, path) 47 | if not args.new_data_path == '': 48 | configs['quickdraw']['data_dir'] = args.new_data_path 49 | 50 | return configs 51 | 52 | 53 | def get_splits(configs): 54 | keys = configs.keys() 55 | keys.sort() 56 | 57 | fpaths = [] 58 | domains = [] 59 | y = [] 60 | for key in keys: 61 | # get data splits 62 | df_dir = os.path.join('aux', 'data', configs[key]['dataset']) 63 | splits = create_multi_splits(df_dir, configs[key]['domain']) 64 | if key == 'quickdraw': 65 | fpaths.extend(splits['sk']['test'].index.values) 66 | domains.extend(splits['sk']['test']['domain'].values) 67 | 
y.extend(splits['sk']['test']['cat'].values) 68 | else: 69 | fpaths.extend(splits['im']['test'].index.values) 70 | domains.extend(splits['im']['test']['domain'].values) 71 | y.extend(splits['im']['test']['cat'].values) 72 | df = pd.DataFrame({'domain': domains, 'cat': y}, index=fpaths) 73 | return df 74 | 75 | 76 | def read_data(fpath): 77 | data = np.load(fpath) 78 | return data['features'], data['labels'] 79 | 80 | 81 | def mix_queries(base, complement, alpha=0.5): 82 | idx = sample_complement(base['y'], complement['y']) 83 | mixture = alpha * base['x'] + (1-alpha) * complement['x'][idx, :] 84 | return mixture, idx 85 | 86 | 87 | def sample_complement(y_base, y_complement): 88 | np.random.seed(SEED) 89 | idx = [] 90 | for y in y_base: 91 | cond_idx = np.argwhere(y_complement == y).squeeze() 92 | idx.append(np.random.choice(cond_idx)) 93 | return idx 94 | 95 | 96 | def many2any_retrieval(configs, sources=['quickdraw', 'real']): 97 | keys = configs.keys() 98 | keys.sort() 99 | 100 | source_data = {} 101 | 102 | for domain in sources: 103 | dirname = configs[domain]['working_path'] 104 | fpath = os.path.join(dirname, 'features.npz') 105 | 106 | x, y = read_data(fpath) 107 | 108 | source_data[domain] = {} 109 | source_data[domain]['x'] = x 110 | source_data[domain]['y'] = y 111 | 112 | # save images that have been mixed, such that they don't get retrived 113 | x_src, idx = mix_queries(source_data[sources[0]], source_data[sources[1]]) 114 | y_src = source_data[sources[0]]['y'] 115 | np.save('plop.npy', idx) 116 | 117 | res = {} 118 | for domain in keys: 119 | dirname = configs[domain]['working_path'] 120 | fpath = os.path.join(dirname, 'features.npz') 121 | 122 | x_tgt, y_tgt = read_data(fpath) 123 | 124 | if sources[0] == domain and sources[1] == domain: 125 | pass 126 | else: 127 | print('\nRetrieval from %s+%s to %s' % 128 | (sources[0], sources[1], domain)) 129 | 130 | if domain == sources[1]: 131 | do_mixture = True 132 | else: 133 | do_mixture = False 134 | 135 | tmp = cross_domain_retrieval( 136 | x_src, y_src, x_tgt, y_tgt, 137 | zeroshot=configs[domain]['overwrite'], 138 | mixture=do_mixture) 139 | res[domain] = tmp 140 | 141 | os.remove('plop.npy') 142 | 143 | 144 | def get_data(configs): 145 | keys = configs.keys() 146 | keys.sort() 147 | 148 | feats = [] 149 | labels = [] 150 | domains = [] 151 | for i, key in enumerate(keys): 152 | dirname = configs[key]['working_path'] 153 | fpath = os.path.join(dirname, 'features.npz') 154 | 155 | data = np.load(fpath) 156 | nsamples = len(data['labels']) 157 | 158 | feats.extend(data['features']) 159 | labels.extend(data['labels']) 160 | domains.extend([key] * nsamples) 161 | 162 | return feats, labels, domains 163 | 164 | 165 | def one2many_retrieve_intent_aware(feats, labels, domains, splits, 166 | source='quickdraw', 167 | zeroshot=False): 168 | cond = np.asarray(domains) == source 169 | 170 | x_src = np.asarray(feats)[cond, :] 171 | y_src = np.asarray(labels)[cond] 172 | x_tgt = np.asarray(feats)[~cond, :] 173 | y_tgt = np.asarray(labels)[~cond] 174 | 175 | d_tgt = np.asarray(domains)[~cond] 176 | 177 | # KNN 178 | g_src_x = KNN(x_src, x_tgt, K=1, mode='ones') 179 | 180 | if zeroshot: 181 | alpha = 0.7 182 | else: 183 | alpha = 0.4 184 | x_src = slerp(alpha, L2norm(x_src), L2norm(g_src_x)) 185 | 186 | idx = myretrieve(x_src, x_tgt, topK=100) 187 | 188 | yd_tgt = np.char.add(y_tgt.astype(d_tgt.dtype), d_tgt) 189 | 190 | domains = np.unique(d_tgt) 191 | categories = np.unique(y_tgt) 192 | 193 | # compute occurrences of every category per 
domain
194 |     occ = []
195 |     for d in domains:
196 |         occ_inner = []
197 |         for c in categories:
198 |             cond = np.logical_and(d_tgt == d, y_tgt == c)
199 |             occ_inner.append(np.sum(cond))
200 |         occ.append(occ_inner)
201 |     occ = np.asarray(occ, dtype=np.float)
202 | 
203 |     # normalize occurrences
204 |     occ /= np.sum(occ, axis=0)
205 | 
206 |     import multiprocessing as mp
207 |     from metrics import average_precision
208 | 
209 |     # compute intent-aware mAP per domain
210 |     mAP_ia = []
211 |     for d in domains:
212 |         yd_src = np.char.add(y_src.astype(d_tgt.dtype), d)
213 |         res = np.char.equal(yd_tgt[idx], yd_src[:, None])
214 |         pool = mp.Pool(processes=10)
215 |         results = [pool.apply_async(average_precision, args=(r,)) for r in res]
216 |         mAP = np.asarray([p.get() for p in results])
217 |         pool.close()
218 | 
219 |         mAP_ia.append(mAP)
220 | 
221 |         print('%s: %.3f' % (d, np.mean(mAP)))
222 |     mAP_ia = np.asarray(mAP_ia)
223 | 
224 |     mAP_ia_final = (occ[:, y_src] * mAP_ia).sum(0).mean()
225 |     print('mAP-IA: %.3f' % mAP_ia_final)
226 | 
227 |     return idx
228 | 
229 | 
230 | def cross_domain_retrieval(x_src, y_src, x_tgt, y_tgt,
231 |                            zeroshot=False, mixture=False):
232 |     mAP, prec = evaluate(x_tgt, y_tgt, x_src, y_src, mixture=mixture)
233 | 
234 |     txt = ('mAP@all: %.04f Prec@100: %.04f\t' % (mAP, prec))
235 |     print(txt)
236 | 
237 |     g_src_x = KNN(x_src, x_tgt, K=1, mode='ones')
238 | 
239 |     if zeroshot:
240 |         alpha = 0.7
241 |     else:
242 |         alpha = 0.4
243 |     new_src_x = slerp(alpha, L2norm(x_src), L2norm(g_src_x))
244 |     mAP, prec = evaluate(x_tgt, y_tgt, new_src_x, y_src, mixture=mixture)
245 |     txt = ('mAP@all: %.04f Prec@100: %.04f\t' % (mAP, prec))
246 |     tmp = '(w. refinement, alpha=%.1f)' % alpha
247 |     txt = tmp + ' ' + txt
248 |     print(txt)
249 | 
250 |     return mAP
251 | 
252 | 
253 | def evaluate(im_x, im_y, sk_x, sk_y, K=False, return_idx=False, mixture=False):
254 |     if not K:
255 |         idx = retrieve(sk_x, im_x)
256 |     else:
257 |         idx = myretrieve(sk_x, im_x, topK=K)
258 | 
259 |     if mixture:
260 |         selection = np.load('plop.npy')
261 |         rows, cols = idx.shape
262 |         idx = idx[idx != selection[:, None]].reshape(rows, -1)
263 | 
264 |     prec, mAP = score(sk_y, im_y, idx)
265 |     if return_idx:
266 |         return mAP, prec, idx
267 |     else:
268 |         return mAP, prec
269 | 
270 | 
271 | def myretrieve(query, gallery, dist='euc', L2=True, topK=101):
272 |     d = query.shape[1]
273 |     if dist == 'euc':
274 |         index_flat = faiss.IndexFlatL2(d)
275 |     elif dist == 'cos':
276 |         index_flat = faiss.IndexFlatIP(d)
277 | 
278 |     if L2:
279 |         query = L2norm(query)
280 |         gallery = L2norm(gallery)
281 | 
282 |     index_flat.add(gallery)
283 |     D, I = index_flat.search(query, topK)
284 |     return I
285 | 
286 | 
287 | def get_stats(splits):
288 |     domains = splits['domain'].unique()
289 |     categories = splits['cat'].unique()
290 | 
291 |     stats = {}
292 |     for c in categories:
293 |         stats[c] = {}
294 |         total = 0.
295 |         for d in domains:
296 |             cond = np.logical_and(splits['domain'] == d, splits['cat'] == c)
297 |             stats[c][d] = np.sum(cond)
298 |             total += np.sum(cond)
299 |         for d in domains:
300 |             stats[c][d] /= total
301 |     return stats
302 | 
303 | 
304 | def slerp(val, low, high):
305 |     """Spherical interpolation. 
val has a range of 0 to 1.""" 306 | if val <= 0: 307 | return low 308 | elif val >= 1: 309 | return high 310 | elif np.allclose(low, high): 311 | return low 312 | omega = np.arccos(np.einsum('ij, ij->i', low, high)) 313 | so = np.sin(omega) 314 | return (np.sin((1.0-val)*omega) / so)[:, None] * low + (np.sin(val*omega)/so)[:, None] * high 315 | 316 | 317 | if __name__ == '__main__': 318 | configs = get_config() 319 | if args.eval == 'many2any': 320 | many2any_retrieval(configs, sources=['quickdraw', 'quickdraw']) 321 | many2any_retrieval(configs, sources=['quickdraw', 'infograph']) 322 | many2any_retrieval(configs) 323 | 324 | many2any_retrieval(configs, sources=['clipart', 'clipart']) 325 | many2any_retrieval(configs, sources=['clipart', 'quickdraw']) 326 | many2any_retrieval(configs, sources=['clipart', 'infograph']) 327 | 328 | many2any_retrieval(configs, sources=['real', 'real']) 329 | many2any_retrieval(configs, sources=['real', 'quickdraw']) 330 | many2any_retrieval(configs, sources=['real', 'infograph']) 331 | 332 | elif args.eval == 'any2many': 333 | feats, labels, domains = get_data(configs) 334 | splits = get_splits(configs) 335 | 336 | one2many_retrieve_intent_aware(feats, labels, domains, splits) 337 | -------------------------------------------------------------------------------- /retrieve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import transforms 8 | from data import create_splits, create_shape_splits, create_multi_splits 9 | from data import DataLoader, get_proxies 10 | from models import LinearProjection, ConvNet 11 | from models import ProxyNet, ProxyLoss 12 | from utils import get_semantic_fname, get_backbone 13 | from validate import extract_predict 14 | from validate import L2norm 15 | from validate import retrieve, KNN, score, score_shape 16 | 17 | 18 | # Training settings 19 | parser = argparse.ArgumentParser(description='PyTorch SBIR') 20 | parser.add_argument('--im-path', type=str, default='exp', metavar='ED', 21 | help='im model path') 22 | parser.add_argument('--sk-path', type=str, default='exp', metavar='ED', 23 | help='sk model path') 24 | parser.add_argument('--new-data-path', type=str, default='', metavar='ED', 25 | help='overwrite the original data path') 26 | parser.add_argument('--rewrite', action='store_true', default=False, 27 | help='Do not consider existing saved features') 28 | parser.add_argument('--train', action='store_true', default=False, 29 | help='Also extract the training set') 30 | 31 | 32 | args = parser.parse_args() 33 | im_path = os.path.dirname(args.im_path) 34 | 35 | with open(os.path.join(im_path, 'config.json')) as f: 36 | tmp = json.load(f) 37 | 38 | tmp['im_model_path'] = args.im_path 39 | tmp['sk_model_path'] = args.sk_path 40 | tmp['rewrite'] = args.rewrite 41 | tmp['train'] = args.train 42 | tmp['new_data_path'] = args.new_data_path 43 | args = type('parser', (object,), tmp) 44 | 45 | if not args.new_data_path == '': 46 | args.data_dir = args.new_data_path 47 | 48 | # get data splits 49 | df_dir = os.path.join('aux', 'data', args.dataset) 50 | if args.shape: 51 | splits = create_shape_splits(df_dir) 52 | elif args.dataset == 'domainnet': 53 | splits = create_multi_splits( 54 | df_dir, domain=args.domain, overwrite=args.overwrite) 55 | else: 56 | splits = create_splits(df_dir, args.overwrite, args.gzsl) 57 | 58 | 59 | def main(): 60 | # data 
normalization 61 | input_size = 224 62 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 63 | std=[0.229, 0.224, 0.225]) 64 | 65 | # data loaders 66 | kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {} 67 | 68 | test_transforms = transforms.Compose([ 69 | transforms.ToPILImage(), 70 | transforms.Resize((input_size, input_size)), 71 | transforms.ToTensor(), 72 | normalize]) 73 | 74 | feats = {} 75 | labels = {} 76 | 77 | for domain in ['im', 'sk']: 78 | key = '_'.join([domain, 'model_path']) 79 | dirname = os.path.dirname(args.__dict__[key]) 80 | fpath = os.path.join(dirname, 'features.npz') 81 | 82 | results_path = os.path.join(dirname, 'results.txt') 83 | 84 | if os.path.isfile(fpath) and args.rewrite is False: 85 | data = np.load(fpath) 86 | feats[domain] = data['features'] 87 | labels[domain] = data['labels'] 88 | acc = data['acc'] 89 | txt = ('Domain (%s): Acc %.2f' % (domain, acc * 100.)) 90 | print(txt) 91 | write_logs(txt, results_path) 92 | 93 | df_gal = splits[domain]['gal'] 94 | fsem = get_semantic_fname(args.word) 95 | path_semantic = os.path.join('aux', 'Semantic', args.dataset, fsem) 96 | test_proxies = get_proxies( 97 | path_semantic, df_gal['cat'].cat.categories) 98 | else: 99 | df_gal = splits[domain]['gal'] 100 | 101 | test_loader = torch.utils.data.DataLoader( 102 | DataLoader(df_gal, test_transforms, 103 | root=args.data_dir, mode=domain), 104 | batch_size=args.batch_size * 10, shuffle=False, **kwargs) 105 | 106 | # instantiate the models 107 | output_shape, backbone = get_backbone(args) 108 | embed = LinearProjection(output_shape, args.dim_embed) 109 | model = ConvNet(backbone, embed) 110 | 111 | # instantiate the proxies 112 | fsem = get_semantic_fname(args.word) 113 | path_semantic = os.path.join('aux', 'Semantic', args.dataset, fsem) 114 | test_proxies = get_proxies( 115 | path_semantic, df_gal['cat'].cat.categories) 116 | 117 | test_proxynet = ProxyNet(args.n_classes_gal, args.dim_embed, 118 | proxies=torch.from_numpy(test_proxies)) 119 | 120 | # criterion 121 | criterion = ProxyLoss(args.temperature) 122 | 123 | if args.multi_gpu: 124 | model = nn.DataParallel(model) 125 | 126 | # loading 127 | checkpoint = torch.load(args.__dict__[key]) 128 | model.load_state_dict(checkpoint['state_dict']) 129 | txt = ("\n=> loaded checkpoint '{}' (epoch {})" 130 | .format(args.__dict__[key], checkpoint['epoch'])) 131 | print(txt) 132 | write_logs(txt, results_path) 133 | 134 | if args.cuda: 135 | backbone.cuda() 136 | embed.cuda() 137 | model.cuda() 138 | test_proxynet.cuda() 139 | 140 | txt = 'Extracting testing set (%s)...' % (domain) 141 | print(txt) 142 | x, y, acc = extract_predict( 143 | test_loader, model, 144 | test_proxynet.proxies.weight, criterion) 145 | 146 | feats[domain] = x 147 | labels[domain] = y 148 | 149 | np.savez(fpath, features=feats[domain], labels=labels[domain], acc=acc) 150 | 151 | fpath_train = os.path.join(dirname, 'features_train.npz') 152 | if args.train and not os.path.isfile(fpath_train): 153 | df_train = splits[domain]['train'] 154 | 155 | train_loader = torch.utils.data.DataLoader( 156 | DataLoader(df_train, test_transforms, 157 | root=args.data_dir, mode=domain), 158 | batch_size=args.batch_size * 10, shuffle=False, **kwargs) 159 | 160 | train_proxies = get_proxies( 161 | path_semantic, df_train['cat'].cat.categories) 162 | 163 | train_proxynet = ProxyNet(args.n_classes_gal, args.dim_embed, 164 | proxies=torch.from_numpy(train_proxies)) 165 | train_proxynet.cuda() 166 | txt = 'Extracting training set (%s)...' 
% (domain) 167 | print(txt) 168 | 169 | x, y, _ = extract_predict( 170 | train_loader, model, 171 | train_proxynet.proxies.weight, criterion) 172 | 173 | fpath = os.path.join(dirname, 'features_train.npz') 174 | 175 | np.savez( 176 | fpath, 177 | features=feats[domain], features_train=x, 178 | labels=labels[domain], labels_train=y, 179 | acc=acc) 180 | 181 | txt = ('Domain (%s): Acc %.2f' % (domain, acc * 100.)) 182 | print(txt) 183 | write_logs(txt, results_path) 184 | 185 | if args.shape: 186 | print('\nRetrieval per model') 187 | new_feat_im, new_labels_im = average_views( 188 | splits['im']['test'], feats['im'], labels['im']) 189 | 190 | idx = retrieve(feats['sk'], new_feat_im) 191 | 192 | metrics = score_shape(labels['sk'], new_labels_im, idx) 193 | names = ['NN', 'FT', 'ST', 'E', 'nDCG', 'mAP'] 194 | txt = [('%s %.3f' % (name, value)) for name, value in zip(names, metrics)] 195 | txt = '\t'.join(txt) 196 | print(txt) 197 | write_logs(txt, results_path) 198 | 199 | print('\nRetrieval per model with refinement') 200 | 201 | alpha = 0.4 202 | 203 | g_sk_x = KNN(feats['sk'], new_feat_im, K=1, mode='ones') 204 | new_sk_x = slerp(alpha, L2norm(feats['sk']), L2norm(g_sk_x)) 205 | idx = retrieve(new_sk_x, new_feat_im) 206 | metrics = score_shape(labels['sk'], new_labels_im, idx) 207 | names = ['NN', 'FT', 'ST', 'E', 'nDCG', 'mAP'] 208 | txt = [('%s %.3f' % (name, value)) for name, value in zip(names, metrics)] 209 | txt = '\t'.join(txt) 210 | print(txt) 211 | write_logs(txt, results_path) 212 | 213 | else: 214 | print('\nRetrieval') 215 | txt = evaluate(feats['im'], labels['im'], 216 | feats['sk'], labels['sk']) 217 | print(txt) 218 | write_logs(txt, results_path) 219 | 220 | print('\nRetrieval with refinement') 221 | if args.overwrite: 222 | alpha = 0.7 223 | else: 224 | alpha = 0.4 225 | 226 | g_sk_x = KNN(feats['sk'], feats['im'], K=1, mode='ones') 227 | 228 | new_sk_x = slerp(alpha, L2norm(feats['sk']), L2norm(g_sk_x)) 229 | txt = evaluate( 230 | feats['im'], labels['im'], 231 | new_sk_x, labels['sk']) 232 | print(txt) 233 | write_logs(txt, results_path) 234 | 235 | 236 | def evaluate(im_x, im_y, sk_x, sk_y, classnames=None): 237 | idx = retrieve(sk_x, im_x) 238 | prec, mAP = score(sk_y, im_y, idx) 239 | txt = ('mAP@all: %.04f Prec@100: %.04f\t' % (mAP, prec)) 240 | return txt 241 | 242 | 243 | def average_views(splits, x, y): 244 | ids = splits['id'].unique() 245 | feat = [] 246 | labels = [] 247 | for cad_id in ids: 248 | cond = splits['id'] == cad_id 249 | feat.append(np.mean(x[cond, :], axis=0)) 250 | labels.append(y[cond][0]) 251 | return np.asarray(feat), np.asarray(labels) 252 | 253 | 254 | def slerp(val, low, high): 255 | """Spherical interpolation. val has a range of 0 to 1.""" 256 | if val <= 0: 257 | return low 258 | elif val >= 1: 259 | return high 260 | elif np.allclose(low, high): 261 | return low 262 | omega = np.arccos(np.einsum('ij, ij->i', low, high)) 263 | so = np.sin(omega) 264 | return (np.sin((1.0-val)*omega) / so)[:, None] * low + (np.sin(val*omega)/so)[:, None] * high 265 | 266 | 267 | def write_logs(txt, logpath): 268 | with open(logpath, 'a') as f: 269 | f.write('\n') 270 | f.write(txt) 271 | 272 | 273 | if __name__ == '__main__': 274 | main() 275 | -------------------------------------------------------------------------------- /runs/closed-eval-sb3dr.sh: -------------------------------------------------------------------------------- 1 | cd .. 
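# Note: retrieve.py reloads the training configuration from the config.json
# saved next to the --im-path checkpoint, caches the extracted features in a
# features.npz beside each checkpoint (pass --rewrite to ignore the cache),
# and appends the retrieval metrics to results.txt.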
2 | 3 | DATA_DIR=data/shapes 4 | EXP_DIR=exp/shapes 5 | 6 | python retrieve.py \ 7 | --im-path=$EXP_DIR/SHREC13_im/checkpoint.pth.tar \ 8 | --sk-path=$EXP_DIR/SHREC13_sk/checkpoint.pth.tar 9 | 10 | python retrieve.py \ 11 | --im-path=$EXP_DIR/SHREC14_im/checkpoint.pth.tar \ 12 | --sk-path=$EXP_DIR/SHREC14_sk/checkpoint.pth.tar 13 | 14 | python retrieve.py \ 15 | --im-path=$EXP_DIR/PART-SHREC14_im/checkpoint.pth.tar \ 16 | --sk-path=$EXP_DIR/PART-SHREC14_sk/checkpoint.pth.tar 17 | -------------------------------------------------------------------------------- /runs/closed-eval-sbic.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | DATA_DIR=data/sketches 4 | EXP_DIR=exp/sketches 5 | 6 | python fewshot.py \ 7 | --sk-path=$EXP_DIR/fewshot/Sketchy_sk/checkpoint.pth.tar \ 8 | --im-path=$EXP_DIR/fewshot/Sketchy_im/checkpoint.pth.tar \ 9 | --mixing 10 | -------------------------------------------------------------------------------- /runs/closed-eval-sbir.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | DATA_DIR=data/sketches 4 | EXP_DIR=exp/sketches 5 | 6 | python retrieve.py \ 7 | --sk-path=$EXP_DIR/zeroshot/Sketchy_sk/checkpoint.pth.tar \ 8 | --im-path=$EXP_DIR/zeroshot/Sketchy_im/checkpoint.pth.tar 9 | 10 | python retrieve.py \ 11 | --sk-path=$EXP_DIR/zeroshot/TU-Berlin_sk/checkpoint.pth.tar \ 12 | --im-path=$EXP_DIR/zeroshot/TU-Berlin_im/checkpoint.pth.tar 13 | 14 | python retrieve.py \ 15 | --sk-path=$EXP_DIR/gzsl/Sketchy_sk/checkpoint.pth.tar \ 16 | --im-path=$EXP_DIR/gzsl/Sketchy_im/checkpoint.pth.tar 17 | 18 | python retrieve.py \ 19 | --sk-path=$EXP_DIR/gzsl/TU-Berlin_sk/checkpoint.pth.tar \ 20 | --im-path=$EXP_DIR/gzsl/TU-Berlin_im/checkpoint.pth.tar 21 | -------------------------------------------------------------------------------- /runs/closed-train-sb3dr.sh: -------------------------------------------------------------------------------- 1 | cd .. 
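# For every SHREC dataset, two prototype learners are trained: one on sketches
# (--mode=sk, 100 epochs, with --da augmentation) and one on the 2D projections
# of the 3D shapes (--mode=im, 30 epochs), both with the SHREC word vectors
# (--word=shrec), an SE-ResNet50 backbone and a temperature of 0.1.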
2 | 3 | DATA_DIR=data/shapes 4 | EXP_DIR=exp/shapes 5 | 6 | ## 7 | # SHREC13 8 | ## 9 | python train.py \ 10 | --epochs=100 \ 11 | --multi-gpu \ 12 | --data_dir=$DATA_DIR \ 13 | --exp_dir=$EXP_DIR \ 14 | --dataset=SHREC13 \ 15 | --mode=sk \ 16 | --batch-size=128 \ 17 | --lr=0.01 \ 18 | --da \ 19 | --shape \ 20 | --word=shrec \ 21 | --backbone=seresnet \ 22 | --temperature=0.1 23 | 24 | python train.py \ 25 | --epochs=30 \ 26 | --multi-gpu \ 27 | --data_dir=$DATA_DIR \ 28 | --exp_dir=$EXP_DIR \ 29 | --dataset=SHREC13 \ 30 | --mode=im \ 31 | --batch-size=128 \ 32 | --lr=0.01 \ 33 | --shape \ 34 | --word=shrec \ 35 | --backbone=seresnet \ 36 | --temperature=0.1 37 | 38 | ## 39 | # SHREC14 40 | ## 41 | python train.py \ 42 | --epochs=100 \ 43 | --multi-gpu \ 44 | --data_dir=$DATA_DIR \ 45 | --exp_dir=$EXP_DIR \ 46 | --dataset=SHREC14 \ 47 | --mode=sk \ 48 | --batch-size=128 \ 49 | --lr=0.01 \ 50 | --da \ 51 | --shape \ 52 | --word=shrec \ 53 | --backbone=seresnet \ 54 | --temperature=0.1 55 | 56 | python train.py \ 57 | --epochs=30 \ 58 | --multi-gpu \ 59 | --data_dir=$DATA_DIR \ 60 | --exp_dir=$EXP_DIR \ 61 | --dataset=SHREC14 \ 62 | --mode=im \ 63 | --batch-size=128 \ 64 | --lr=0.01 \ 65 | --shape \ 66 | --word=shrec \ 67 | --backbone=seresnet \ 68 | --temperature=0.1 69 | 70 | ## 71 | # PART-SHREC14 72 | ## 73 | python train.py \ 74 | --epochs=100 \ 75 | --multi-gpu \ 76 | --data_dir=$DATA_DIR \ 77 | --exp_dir=$EXP_DIR \ 78 | --dataset=PART-SHREC14 \ 79 | --mode=sk \ 80 | --batch-size=128 \ 81 | --lr=0.01 \ 82 | --shape \ 83 | --da \ 84 | --word=shrec \ 85 | --backbone=seresnet \ 86 | --temperature=0.1 87 | 88 | python train.py \ 89 | --epochs=30 \ 90 | --multi-gpu \ 91 | --data_dir=$DATA_DIR \ 92 | --exp_dir=$EXP_DIR \ 93 | --dataset=PART-SHREC14 \ 94 | --mode=im \ 95 | --batch-size=128 \ 96 | --lr=0.01 \ 97 | --shape \ 98 | --word=shrec \ 99 | --backbone=seresnet \ 100 | --temperature=0.1 101 | -------------------------------------------------------------------------------- /runs/closed-train-sbic.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | DATA_DIR=data/sketches 4 | EXP_DIR=exp/sketches 5 | 6 | python train.py \ 7 | --epochs=20 \ 8 | --multi-gpu \ 9 | --data_dir=$DATA_DIR \ 10 | --exp_dir=$EXP_DIR/fewshot \ 11 | --dataset=Sketchy \ 12 | --mode=sk \ 13 | --lr=0.001 \ 14 | --da \ 15 | --temperature=0.05 \ 16 | --overwrite \ 17 | --backbone=vgg19 \ 18 | --fewshot 19 | 20 | python train.py \ 21 | --epochs=20 \ 22 | --multi-gpu \ 23 | --data_dir=$DATA_DIR \ 24 | --exp_dir=$EXP_DIR/fewshot \ 25 | --dataset=Sketchy \ 26 | --lr=0.001 \ 27 | --da \ 28 | --mode=im \ 29 | --temperature=0.05 \ 30 | --overwrite \ 31 | --backbone=vgg19 \ 32 | --fewshot 33 | -------------------------------------------------------------------------------- /runs/closed-train-sbir.sh: -------------------------------------------------------------------------------- 1 | cd .. 
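# Trains a sketch model and an image model for each dataset, first with the
# zero-shot splits (--overwrite, SE-ResNet50 backbone) and then in the
# generalized zero-shot setting (--overwrite together with --gzsl, VGG16
# backbone).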
2 | 3 | DATA_DIR=data/sketches 4 | EXP_DIR=exp/sketches 5 | 6 | ## 7 | # zero-shot 8 | ## 9 | 10 | # Sketchy dataset 11 | python train.py \ 12 | --epochs=30 \ 13 | --multi-gpu \ 14 | --data_dir=$DATA_DIR \ 15 | --exp_dir=$EXP_DIR/zeroshot \ 16 | --dataset=Sketchy \ 17 | --mode=sk \ 18 | --lr=0.001 \ 19 | --da \ 20 | --backbone=seresnet \ 21 | --temperature=0.05 \ 22 | --overwrite 23 | 24 | python train.py \ 25 | --epochs=10 \ 26 | --multi-gpu \ 27 | --data_dir=$DATA_DIR \ 28 | --exp_dir=$EXP_DIR/zeroshot \ 29 | --dataset=Sketchy \ 30 | --backbone=seresnet \ 31 | --lr=0.001 \ 32 | --da \ 33 | --mode=im \ 34 | --temperature=0.05 \ 35 | --overwrite 36 | 37 | # TU-Berlin dataset 38 | python train.py \ 39 | --epochs=50 \ 40 | --multi-gpu \ 41 | --data_dir=$DATA_DIR \ 42 | --exp_dir=$EXP_DIR/zeroshot \ 43 | --dataset=TU-Berlin \ 44 | --mode=sk \ 45 | --lr=0.001 \ 46 | --da \ 47 | --backbone=seresnet \ 48 | --temperature=0.05 \ 49 | --overwrite 50 | 51 | python train.py \ 52 | --epochs=5 \ 53 | --multi-gpu \ 54 | --data_dir=$DATA_DIR \ 55 | --exp_dir=$EXP_DIR/zeroshot \ 56 | --dataset=TU-Berlin \ 57 | --backbone=seresnet \ 58 | --lr=0.001 \ 59 | --da \ 60 | --mode=im \ 61 | --temperature=0.05 \ 62 | --overwrite 63 | 64 | ## 65 | # generalized zero-shot 66 | ## 67 | 68 | # Sketchy dataset 69 | python train.py \ 70 | --epochs=30 \ 71 | --multi-gpu \ 72 | --data_dir=$DATA_DIR \ 73 | --exp_dir=$EXP_DIR/gzsl \ 74 | --dataset=Sketchy \ 75 | --mode=sk \ 76 | --lr=0.001 \ 77 | --da \ 78 | --temperature=0.05 \ 79 | --overwrite \ 80 | --backbone=vgg16 \ 81 | --gzsl 82 | 83 | python train.py \ 84 | --epochs=10 \ 85 | --multi-gpu \ 86 | --data_dir=$DATA_DIR \ 87 | --exp_dir=$EXP_DIR/gzsl \ 88 | --dataset=Sketchy \ 89 | --lr=0.001 \ 90 | --da \ 91 | --mode=im \ 92 | --temperature=0.05 \ 93 | --overwrite \ 94 | --backbone=vgg16 \ 95 | --gzsl 96 | 97 | # TU-Berlin dataset 98 | python train.py \ 99 | --epochs=100 \ 100 | --multi-gpu \ 101 | --data_dir=$DATA_DIR \ 102 | --exp_dir=$EXP_DIR/gzsl \ 103 | --dataset=TU-Berlin \ 104 | --mode=sk \ 105 | --lr=0.001 \ 106 | --da \ 107 | --temperature=0.05 \ 108 | --overwrite \ 109 | --backbone=vgg16 \ 110 | --gzsl 111 | 112 | python train.py \ 113 | --epochs=20 \ 114 | --multi-gpu \ 115 | --data_dir=$DATA_DIR \ 116 | --exp_dir=$EXP_DIR/gzsl \ 117 | --dataset=TU-Berlin \ 118 | --lr=0.001 \ 119 | --da \ 120 | --mode=im \ 121 | --temperature=0.05 \ 122 | --overwrite \ 123 | --backbone=vgg16 \ 124 | --gzsl 125 | -------------------------------------------------------------------------------- /runs/open-eval.sh: -------------------------------------------------------------------------------- 1 | cd .. 
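# Features are extracted once per image domain (plus the shared sketch model)
# in both many-shot and zero-shot settings; the three open experiments then
# reuse them: retrieve-any.py for any2any, retrieve-many.py for many2any and
# any2many.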
2 | 3 | DATA_DIR=data/domainnet 4 | EXP_DIR=exp/domainnet 5 | 6 | ## 7 | # Feature extraction 8 | ## 9 | 10 | # extract many shot 11 | for DOMAIN in clipart infograph painting real sketch 12 | do 13 | python retrieve.py \ 14 | --im-path=$EXP_DIR/manyshot/domainnet_im_$DOMAIN/checkpoint.pth.tar \ 15 | --sk-path=$EXP_DIR/manyshot/domainnet_sk/checkpoint.pth.tar 16 | done 17 | 18 | # extract zero shot 19 | for DOMAIN in clipart infograph painting real sketch 20 | do 21 | python retrieve.py \ 22 | --im-path=$EXP_DIR/zeroshot/domainnet_im_$DOMAIN/checkpoint.pth.tar \ 23 | --sk-path=$EXP_DIR/zeroshot/domainnet_sk/checkpoint.pth.tar 24 | done 25 | 26 | ## 27 | # any2any experiment 28 | ## 29 | 30 | python retrieve-any.py --dir-path=$EXP_DIR/manyshot 31 | python retrieve-any.py --dir-path=$EXP_DIR/zeroshot 32 | 33 | ## 34 | # many2any experiment 35 | ## 36 | 37 | python retrieve-many.py --dir-path=$EXP_DIR/zeroshot --eval=many2any 38 | python retrieve-many.py --dir-path=$EXP_DIR/manyshot --eval=many2any 39 | 40 | ## 41 | # any2many experiment 42 | ## 43 | 44 | python retrieve-many.py --dir-path=$EXP_DIR/manyshot --eval=any2many 45 | -------------------------------------------------------------------------------- /runs/open-train.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | DATA_DIR=data/domainnet 4 | EXP_DIR=exp/domainnet 5 | 6 | ## 7 | # domainnet many-shot 8 | ## 9 | for DOMAIN in clipart infograph painting real sketch 10 | do 11 | python train.py \ 12 | --epochs=60 \ 13 | --multi-gpu \ 14 | --data_dir=$DATA_DIR \ 15 | --exp_dir=$EXP_DIR/manyshot \ 16 | --dataset=domainnet_$DOMAIN \ 17 | --mode=im \ 18 | --lr=0.001 \ 19 | --da \ 20 | --backbone=seresnet \ 21 | --temperature=0.05 22 | done 23 | 24 | python train.py \ 25 | --epochs=60 \ 26 | --multi-gpu \ 27 | --data_dir=$DATA_DIR \ 28 | --exp_dir=$EXP_DIR/manyshot \ 29 | --dataset=domainnet \ 30 | --mode=sk \ 31 | --lr=0.001 \ 32 | --da \ 33 | --backbone=seresnet \ 34 | --temperature=0.05 35 | 36 | ## 37 | # domainnet zero-shot 38 | ## 39 | for DOMAIN in clipart infograph painting real sketch 40 | do 41 | python train.py \ 42 | --epochs=30 \ 43 | --multi-gpu \ 44 | --data_dir=$DATA_DIR \ 45 | --exp_dir=$EXP_DIR/zeroshot \ 46 | --dataset=domainnet_$DOMAIN \ 47 | --mode=im \ 48 | --lr=0.001 \ 49 | --da \ 50 | --backbone=seresnet \ 51 | --temperature=0.05 \ 52 | --overwrite 53 | done 54 | 55 | python train.py \ 56 | --epochs=30 \ 57 | --multi-gpu \ 58 | --data_dir=$DATA_DIR \ 59 | --exp_dir=$EXP_DIR/zeroshot \ 60 | --dataset=domainnet \ 61 | --mode=sk \ 62 | --lr=0.001 \ 63 | --da \ 64 | --backbone=seresnet \ 65 | --temperature=0.05 \ 66 | --overwrite 67 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import numpy as np 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | from torch.optim.lr_scheduler import CosineAnnealingLR 11 | from torchvision import transforms 12 | from data import create_splits, create_fewshot_splits, create_shape_splits 13 | from data import create_multi_splits 14 | from data import DataLoader, get_proxies 15 | from models import LinearProjection, ConvNet 16 | from models import ProxyNet, ProxyLoss 17 | from utils import AverageMeter, to_numpy 18 | from utils import get_semantic_fname, get_backbone, 
random_transform 19 | from validate import extract_predict 20 | 21 | 22 | # Training settings 23 | parser = argparse.ArgumentParser(description='PyTorch SBIR') 24 | # hyper-parameters 25 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 26 | help='input batch size for training (default: 128)') 27 | parser.add_argument('--epochs', type=int, default=30, metavar='N', 28 | help='number of epochs to train (default: 30)') 29 | parser.add_argument('--start_epoch', type=int, default=1, metavar='N', 30 | help='number of start epoch (default: 1)') 31 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 32 | help='learning rate (default: 1e-3)') 33 | parser.add_argument('--factor_lower', type=float, default=.1, 34 | help='multiplicative factor of the LR for lower layers') 35 | parser.add_argument('--seed', type=int, default=456, metavar='S', 36 | help='random seed (default: 456)') 37 | parser.add_argument('--temperature', type=float, default=1., metavar='M', 38 | help='temperature (default: 1.)') 39 | parser.add_argument('--wd', type=float, default=5e-5, metavar='M', 40 | help='weight decay (default: 5e-5)') 41 | # flags 42 | parser.add_argument('--no-cuda', action='store_true', default=False, 43 | help='disables CUDA training') 44 | parser.add_argument('--multi-gpu', action='store_true', default=False, 45 | help='enables multi gpu training') 46 | # model 47 | parser.add_argument('--dim_embed', type=int, default=300, metavar='N', 48 | help='how many dimensions in embedding (default: 300)') 49 | parser.add_argument('--da', action='store_true', default=False, 50 | help='data augmentation') 51 | parser.add_argument('--backbone', type=str, default='resnet', 52 | help='vgg16|vgg19|resnet|seresnet') 53 | parser.add_argument('--word', type=str, default='word2vec', 54 | help='Semantic space') 55 | # setup 56 | parser.add_argument('--fewshot', action='store_true', default=False, 57 | help='few-shot experiment') 58 | parser.add_argument('--gzsl', action='store_true', default=False, 59 | help='Generalized setting, only works for Sketchy and TUB') 60 | parser.add_argument('--shape', action='store_true', default=False, 61 | help='3D shape recognition') 62 | parser.add_argument('--mode', type=str, required=True, 63 | help='im|sk') 64 | # plumbing 65 | parser.add_argument('--dataset', type=str, required=True, 66 | help='Sketchy|TU-Berlin|SHREC13|SHREC14|PART-SHREC14|domainnet') 67 | parser.add_argument('--overwrite', action='store_true', default=False, 68 | help='zero-shot experiment') 69 | parser.add_argument('--data_dir', type=str, metavar='DD', 70 | default='data', 71 | help='data folder path') 72 | parser.add_argument('--exp_dir', type=str, default='exp', metavar='ED', 73 | help='folder for saving exp') 74 | parser.add_argument('--m', type=str, default='SBIR', metavar='M', 75 | help='message') 76 | 77 | 78 | args = parser.parse_args() 79 | args.cuda = not args.no_cuda and torch.cuda.is_available() 80 | 81 | np.random.seed(args.seed) 82 | torch.manual_seed(args.seed) 83 | if args.cuda: 84 | torch.cuda.manual_seed(args.seed) 85 | torch.backends.cudnn.deterministic = True 86 | 87 | # get data splits 88 | if 'domainnet' in args.dataset: 89 | args.domain = '_'.join(args.dataset.split('_')[1:]) 90 | args.dataset = 'domainnet' 91 | df_dir = os.path.join('aux', 'data', args.dataset) 92 | 93 | if args.fewshot: 94 | splits = create_fewshot_splits(df_dir) 95 | df_train = splits[args.mode]['train'] 96 | df_gal = splits[args.mode]['test'] 97 | elif args.shape: 98 | splits = 
create_shape_splits(df_dir) 99 | df_train = splits[args.mode]['train'] 100 | df_gal = splits[args.mode]['test'] 101 | elif args.dataset in ['domainnet']: 102 | splits = create_multi_splits( 103 | df_dir, domain=args.domain, overwrite=args.overwrite) 104 | df_train = splits[args.mode]['train'] 105 | df_gal = splits[args.mode]['test'] 106 | else: 107 | splits = create_splits(df_dir, args.overwrite, args.gzsl) 108 | df_train = splits[args.mode]['train'] 109 | df_gal = splits[args.mode]['gal'] 110 | 111 | args.n_classes = len(df_train['cat'].cat.categories) 112 | args.n_classes_gal = len(df_gal['cat'].cat.categories) 113 | 114 | # create experiment folder 115 | dname = '%s_%s' % (args.dataset, args.mode) 116 | if args.dataset == 'domainnet' and args.mode == 'im': 117 | dname = dname + '_' + args.domain 118 | path = os.path.join(args.exp_dir, dname) 119 | if not os.path.exists(path): 120 | os.makedirs(path) 121 | 122 | # saving logs 123 | with open(os.path.join(path, 'config.json'), 'w') as f: 124 | json.dump(args.__dict__, f, indent=4, sort_keys=True) 125 | 126 | with open(os.path.join(path, 'logs.txt'), 'w') as f: 127 | f.write('Experiment with SBIR\n') 128 | 129 | 130 | def write_logs(txt, logpath=os.path.join(path, 'logs.txt')): 131 | with open(logpath, 'a') as f: 132 | f.write('\n') 133 | f.write(txt) 134 | 135 | 136 | def main(): 137 | # data normalization 138 | input_size = 224 139 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 140 | std=[0.229, 0.224, 0.225]) 141 | 142 | # data loaders 143 | kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {} 144 | 145 | if args.da: 146 | train_transforms = transforms.Compose([ 147 | random_transform, 148 | transforms.ToPILImage(), 149 | transforms.Resize((input_size, input_size)), 150 | transforms.ToTensor(), 151 | normalize]) 152 | else: 153 | train_transforms = transforms.Compose([ 154 | transforms.ToPILImage(), 155 | transforms.Resize((input_size, input_size)), 156 | transforms.RandomHorizontalFlip(), 157 | transforms.ToTensor(), 158 | normalize]) 159 | 160 | test_transforms = transforms.Compose([ 161 | transforms.ToPILImage(), 162 | transforms.Resize((input_size, input_size)), 163 | transforms.ToTensor(), 164 | normalize]) 165 | 166 | train_loader = torch.utils.data.DataLoader( 167 | DataLoader(df_train, train_transforms, 168 | root=args.data_dir, mode=args.mode), 169 | batch_size=args.batch_size, shuffle=True, **kwargs) 170 | 171 | test_loader = torch.utils.data.DataLoader( 172 | DataLoader(df_gal, test_transforms, 173 | root=args.data_dir, mode=args.mode), 174 | batch_size=args.batch_size, shuffle=False, **kwargs) 175 | 176 | # instantiate the models 177 | output_shape, backbone = get_backbone(args) 178 | embed = LinearProjection(output_shape, args.dim_embed) 179 | model = ConvNet(backbone, embed) 180 | 181 | # instantiate the proxies 182 | fsem = get_semantic_fname(args.word) 183 | path_semantic = os.path.join('aux', 'Semantic', args.dataset, fsem) 184 | train_proxies = get_proxies( 185 | path_semantic, df_train['cat'].cat.categories) 186 | test_proxies = get_proxies( 187 | path_semantic, df_gal['cat'].cat.categories) 188 | 189 | train_proxynet = ProxyNet(args.n_classes, args.dim_embed, 190 | proxies=torch.from_numpy(train_proxies)) 191 | test_proxynet = ProxyNet(args.n_classes_gal, args.dim_embed, 192 | proxies=torch.from_numpy(test_proxies)) 193 | 194 | # criterion 195 | criterion = ProxyLoss(args.temperature) 196 | 197 | if args.multi_gpu: 198 | model = nn.DataParallel(model) 199 | 200 | if args.cuda: 201 | 
backbone.cuda() 202 | embed.cuda() 203 | model.cuda() 204 | train_proxynet.cuda() 205 | test_proxynet.cuda() 206 | 207 | parameters_set = [] 208 | 209 | low_layers = [] 210 | upper_layers = [] 211 | 212 | for c in backbone.children(): 213 | low_layers.extend(list(c.parameters())) 214 | for c in embed.children(): 215 | upper_layers.extend(list(c.parameters())) 216 | 217 | parameters_set.append({'params': low_layers, 218 | 'lr': args.lr * args.factor_lower}) 219 | parameters_set.append({'params': upper_layers, 220 | 'lr': args.lr * 1.}) 221 | 222 | optimizer = optim.SGD( 223 | parameters_set, lr=args.lr, 224 | momentum=0.9, nesterov=True, 225 | weight_decay=args.wd) 226 | 227 | n_parameters = sum([p.data.nelement() 228 | for p in model.parameters()]) 229 | print(' + Number of params: {}'.format(n_parameters)) 230 | 231 | scheduler = CosineAnnealingLR( 232 | optimizer, args.epochs * len(train_loader), eta_min=3e-6) 233 | 234 | print('Starting training...') 235 | for epoch in range(args.start_epoch, args.epochs + 1): 236 | # the cosine LR schedule (T_max = epochs * len(train_loader)) is 237 | # stepped per batch inside train(), so no per-epoch step is needed 238 | 239 | # train for one epoch 240 | train(train_loader, model, 241 | train_proxynet.proxies.weight, criterion, 242 | optimizer, epoch, scheduler) 243 | 244 | val_acc = evaluate( 245 | test_loader, model, 246 | test_proxynet.proxies.weight, criterion) 247 | 248 | # saving 249 | if epoch == args.epochs: 250 | save_checkpoint({ 251 | 'epoch': epoch, 252 | 'state_dict': model.state_dict()}) 253 | 254 | print('\nResults on test set (end of training)') 255 | write_logs('\nResults on test set (end of training)') 256 | test_acc = evaluate( 257 | test_loader, model, 258 | test_proxynet.proxies.weight, criterion) 259 | 260 | 261 | def train(train_loader, model, 262 | proxies, criterion, optimizer, epoch, scheduler): 263 | """Training loop for one epoch""" 264 | batch_time = AverageMeter() 265 | data_time = AverageMeter() 266 | 267 | val_loss = AverageMeter() 268 | val_acc = AverageMeter() 269 | 270 | # switch to train mode 271 | model.train() 272 | 273 | end = time.time() 274 | 275 | for i, (x, y) in enumerate(train_loader): 276 | # measure data loading time 277 | data_time.update(time.time() - end) 278 | 279 | if len(x.shape) == 5: 280 | batch_size, nviews = x.shape[0], x.shape[1] 281 | x = x.view(batch_size * nviews, 3, 224, 224) 282 | 283 | if len(y) == args.batch_size: 284 | if args.cuda: 285 | x = x.cuda() 286 | y = y.cuda() 287 | 288 | x = Variable(x) 289 | 290 | # embed 291 | x_emb = model(x) 292 | 293 | loss, acc = criterion(x_emb, y, proxies) 294 | 295 | val_loss.update(to_numpy(loss), x.size(0)) 296 | val_acc.update(acc, x.size(0)) 297 | 298 | # compute gradient and do SGD step 299 | optimizer.zero_grad() 300 | loss.backward() 301 | optimizer.step() 302 | scheduler.step() 303 | 304 | # measure elapsed time 305 | batch_time.update(time.time() - end) 306 | end = time.time() 307 | 308 | txt = ('Epoch [%d] (Time %.2f Data %.2f):\t' 309 | 'Loss %.4f\t Acc %.4f' % 310 | (epoch, batch_time.avg * i, data_time.avg * i, 311 | val_loss.avg, val_acc.avg * 100.)) 312 | print(txt) 313 | write_logs(txt) 314 | 315 | 316 | def evaluate(loader, model, proxies, criterion): 317 | x, y, acc = extract_predict(loader, model, proxies, criterion) 318 | txt = ('.. 
Acc: %.02f\t' % (acc * 100)) 319 | print(txt) 320 | write_logs(txt) 321 | return acc 322 | 323 | 324 | def save_checkpoint(state, folder=path, filename='checkpoint.pth.tar'): 325 | torch.save(state, os.path.join(folder, filename)) 326 | 327 | 328 | if __name__ == '__main__': 329 | main() 330 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import os 4 | import matplotlib 5 | matplotlib.use('agg') 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from skimage.transform import warp, AffineTransform 9 | 10 | 11 | def L2norm(x): 12 | return x / x.norm(p=2, dim=1)[:, None] 13 | 14 | 15 | def cosine_similarity(x, y=None, eps=1e-8): 16 | if y is None: 17 | w = x.norm(p=2, dim=1, keepdim=True) 18 | return torch.mm(x, x.t()) / (w * w.t()).clamp(min=eps) 19 | else: 20 | xx = L2norm(x) 21 | yy = L2norm(y) 22 | return xx.matmul(yy.t()) 23 | 24 | 25 | class AverageMeter(object): 26 | """Computes and stores the average and current value""" 27 | def __init__(self): 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = 0 32 | self.avg = 0 33 | self.sum = 0 34 | self.count = 0 35 | 36 | def update(self, val, n=1): 37 | self.val = val 38 | self.sum += val * n 39 | self.count += n 40 | self.avg = self.sum / self.count 41 | 42 | 43 | def to_numpy(x): 44 | return x.cpu().data.numpy()[0] 45 | 46 | 47 | def get_backbone(args, pretrained=True): 48 | from models import VGG16, VGG19, ResNet50, SEResNet50 49 | 50 | if args.backbone == 'resnet': 51 | output_shape = 2048 52 | backbone = ResNet50(pretrained=pretrained, kp=args.kp) 53 | elif args.backbone == 'vgg16': 54 | output_shape = 4096 55 | backbone = VGG16(pretrained=pretrained) 56 | elif args.backbone == 'vgg19': 57 | output_shape = 4096 58 | backbone = VGG19(pretrained=pretrained) 59 | elif args.backbone == 'seresnet': 60 | output_shape = 2048 61 | backbone = SEResNet50(pretrained=pretrained) 62 | return output_shape, backbone 63 | 64 | 65 | def random_transform(img): 66 | '''Same augmentation as Qiu et al ICCV 2019 67 | https://github.com/qliu24/SAKE 68 | ''' 69 | if img.shape[0] != 224: 70 | img = cv2.resize(img, (224, 224)) 71 | 72 | if np.random.random() < 0.5: 73 | img = img[:,::-1,:] 74 | 75 | if np.random.random() < 0.5: 76 | sx = np.random.uniform(0.7, 1.3) 77 | sy = np.random.uniform(0.7, 1.3) 78 | else: 79 | sx = 1.0 80 | sy = 1.0 81 | 82 | if np.random.random() < 0.5: 83 | rx = np.random.uniform(-30.0*2.0*np.pi/360.0,+30.0*2.0*np.pi/360.0) 84 | else: 85 | rx = 0.0 86 | 87 | if np.random.random() < 0.5: 88 | tx = np.random.uniform(-10,10) 89 | ty = np.random.uniform(-10,10) 90 | else: 91 | tx = 0.0 92 | ty = 0.0 93 | 94 | if np.random.random()<0.7: 95 | aftrans = AffineTransform(scale=(sx, sy), rotation=rx, translation=(tx,ty)) 96 | img_aug = warp(img,aftrans.inverse, preserve_range=True).astype('uint8') 97 | return img_aug 98 | else: 99 | return img 100 | 101 | 102 | def get_semantic_fname(space='word2vec'): 103 | if space == 'word2vec': 104 | return 'word2vec-google-news.npy' 105 | elif space == 'shrec': 106 | return 'w2v.npz' 107 | 108 | 109 | def zero_cnames(dataset): 110 | if dataset == 'Sketchy': 111 | # same split as Qiu et al ICCV 2019 112 | cnames = ['cup', 'chicken', 'camel', 113 | 'swan', 'squirrel', 'snail', 'scissors', 114 | 'harp', 'horse', 115 | 'ray', 'rifle', 116 | 'pineapple', 'parrot', 117 | 'volcano', 118 | 'windmill', 'wine_bottle', 119 | 'teddy_bear', 
'tree', 'tank', 120 | 'deer', 121 | 'airplane', 122 | 'wheelchair', 123 | 'umbrella', 124 | 'butterfly', 'bell'] 125 | elif dataset == 'TU-Berlin': 126 | # same split as Qiu et al ICCV 2019 127 | cnames = ['banana', 'bottle_opener', 'bus', 'brain', 'bridge', 'bread', 128 | 'suitcase', 'streetlight', 'shoe', 'snowboard', 'space_shuttle', 129 | 'tractor', 'telephone', 'teacup', 't_shirt', 'trombone', 'table', 130 | 'canoe', 131 | 'fan', 'frying_pan', 132 | 'penguin', 'pizza', 'parachute', 133 | 'laptop', 'lighter', 134 | 'hot_air_balloon', 'horse', 135 | 'ant', 136 | 'windmill', 137 | 'rollerblades'] 138 | elif dataset == 'domainnet': 139 | # zero-shot split with at least 40 samples per category 140 | cnames = ['The_Mona_Lisa', 'animal_migration', 141 | 'bandage', 'beach', 'beard', 'bread', 142 | 'calendar', 'campfire', 'circle', 143 | 'door', 'ear', 'eyeglasses', 144 | 'feather', 'flashlight', 'fork', 145 | 'garden', 'grass', 146 | 'hat', 'hockey_stick', 'hot_air_balloon', 'hurricane', 147 | 'key', 'knee', 'ladder', 'lantern', 'mouth', 148 | 'octopus', 'onion', 149 | 'palm_tree', 'picture_frame', 'pond', 'potato', 150 | 'rake', 'roller_coaster', 151 | 'sailboat', 'sandwich', 'scissors', 'snowflake', 'steak', 152 | 'stop_sign', 'string_bean', 'suitcase', 'sun', 153 | 'tree', 'windmill'] 154 | return cnames 155 | 156 | 157 | def heatmap(data, row_labels, col_labels, ax=None, 158 | cbar_kw={}, cbarlabel="", **kwargs): 159 | """ 160 | Create a heatmap from a numpy array and two lists of labels. 161 | (adapted from the matplotlib example) 162 | 163 | Parameters 164 | ---------- 165 | data 166 | A 2D numpy array of shape (N, M). 167 | row_labels 168 | A list or array of length N with the labels for the rows. 169 | col_labels 170 | A list or array of length M with the labels for the columns. 171 | ax 172 | A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If 173 | not provided, use current axes or create a new one. Optional. 174 | cbar_kw 175 | A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional. 176 | cbarlabel 177 | The label for the colorbar. Optional. 178 | **kwargs 179 | All other arguments are forwarded to `imshow`. 180 | """ 181 | 182 | if not ax: 183 | ax = plt.gca() 184 | 185 | # Plot the heatmap 186 | im = ax.imshow(data, **kwargs) 187 | 188 | # Create colorbar 189 | cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw) 190 | cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom") 191 | 192 | # We want to show all ticks... 193 | ax.set_xticks(np.arange(data.shape[1])) 194 | ax.set_yticks(np.arange(data.shape[0])) 195 | # ... and label them with the respective list entries. 196 | ax.set_xticklabels(col_labels) 197 | ax.set_yticklabels(row_labels) 198 | 199 | # Let the horizontal axes labeling appear on top. 200 | ax.tick_params(top=True, bottom=False, 201 | labeltop=True, labelbottom=False) 202 | 203 | # Rotate the tick labels and set their alignment. 204 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", 205 | rotation_mode="anchor") 206 | 207 | # Turn spines off and create white grid. 
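    # (the minor ticks set below sit at the half-integer cell boundaries, so
    # the white minor grid draws thin separators between the heatmap cells)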
208 | for edge, spine in ax.spines.items(): 209 | spine.set_visible(False) 210 | 211 | ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True) 212 | ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True) 213 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 214 | ax.tick_params(which="minor", bottom=False, left=False) 215 | 216 | return im, cbar 217 | 218 | 219 | def annotate_heatmap(im, data=None, valfmt="{x:.2f}", 220 | textcolors=["black", "white"], 221 | threshold=None, **textkw): 222 | """ 223 | A function to annotate a heatmap. 224 | (adapted from the matplotlib example) 225 | 226 | Parameters 227 | ---------- 228 | im 229 | The AxesImage to be labeled. 230 | data 231 | Data used to annotate. If None, the image's data is used. Optional. 232 | valfmt 233 | The format of the annotations inside the heatmap. This should either 234 | use the string format method, e.g. "$ {x:.2f}", or be a 235 | `matplotlib.ticker.Formatter`. Optional. 236 | textcolors 237 | A list or array of two color specifications. The first is used for 238 | values below a threshold, the second for those above. Optional. 239 | threshold 240 | Value in data units according to which the colors from textcolors are 241 | applied. If None (the default) uses the middle of the colormap as 242 | separation. Optional. 243 | **kwargs 244 | All other arguments are forwarded to each call to `text` used to create 245 | the text labels. 246 | """ 247 | 248 | if not isinstance(data, (list, np.ndarray)): 249 | data = im.get_array() 250 | 251 | # Normalize the threshold to the images color range. 252 | if threshold is not None: 253 | threshold = im.norm(threshold) 254 | else: 255 | threshold = im.norm(data.max())/2. 256 | 257 | # Set default alignment to center, but allow it to be 258 | # overwritten by textkw. 259 | kw = dict(horizontalalignment="center", 260 | verticalalignment="center") 261 | kw.update(textkw) 262 | 263 | # Get the formatter in case a string is supplied 264 | if isinstance(valfmt, str): 265 | valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) 266 | 267 | # Loop over the data and create a `Text` for each "pixel". 268 | # Change the text's color depending on the data. 
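    # (im.norm maps values into [0, 1]; cells above the normalized threshold
    # use textcolors[1], the others textcolors[0])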
269 | texts = [] 270 | for i in range(data.shape[0]): 271 | for j in range(data.shape[1]): 272 | kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)]) 273 | text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) 274 | texts.append(text) 275 | 276 | return texts 277 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import faiss 3 | import numpy as np 4 | import multiprocessing as mp 5 | from torch.autograd import Variable 6 | from metrics import precision_at_k, mean_average_precision, average_precision 7 | from metrics import dcg_at_k, ndcg_at_k 8 | from utils import AverageMeter, to_numpy 9 | 10 | 11 | def L2norm(x): 12 | return x / np.linalg.norm(x, axis=1)[:, None] 13 | 14 | 15 | def extract(loader, model): 16 | model.eval() 17 | outputs = [] 18 | labels = [] 19 | for i, (imgs, l) in enumerate(loader): 20 | if torch.cuda.is_available(): 21 | imgs = imgs.cuda(async=True) 22 | 23 | imgs = Variable(imgs, volatile=True) 24 | outputs.append(model(imgs).data.cpu().numpy()) 25 | 26 | labels.extend(l) 27 | return np.vstack(outputs), np.asarray(labels) 28 | 29 | 30 | def extract_predict(loader, model, proxies, criterion): 31 | model.eval() 32 | outputs = [] 33 | labels = [] 34 | 35 | val_acc = AverageMeter() 36 | 37 | copy = False 38 | 39 | for i, (imgs, l) in enumerate(loader): 40 | if len(imgs.shape) == 5: 41 | batch_size, nviews = imgs.shape[0], imgs.shape[1] 42 | imgs = imgs.view(batch_size * nviews, 3, 224, 224) 43 | 44 | # hack to handle multiple gpus: duplicate very small batches, since 45 | # nn.DataParallel can crash when a gpu receives 0 images 46 | if batch_size < 4: 47 | imgs = torch.cat([imgs, imgs], dim=0) 48 | l = torch.cat([l, l]) 49 | copy = True 50 | 51 | labels.extend(l) 52 | 53 | if torch.cuda.is_available(): 54 | imgs = imgs.cuda(async=True) 55 | l = l.cuda(async=True) 56 | 57 | imgs = Variable(imgs, volatile=True) 58 | # a single forward pass covers both single- and multi-gpu models 59 | embs = model(imgs) 60 | 61 | 62 | 63 | if copy: 64 | embs = embs[:batch_size] 65 | l = l[:batch_size] 66 | 67 | outputs.append(embs.data.cpu().numpy()) 68 | 69 | loss, acc = criterion(embs, l, proxies) 70 | val_acc.update(acc, imgs.size(0)) 71 | 72 | return np.vstack(outputs), np.asarray(labels), val_acc.avg 73 | 74 | 75 | def retrieve(query, gallery, dist='euc', L2=True): 76 | d = query.shape[1] 77 | if dist == 'euc': 78 | index_flat = faiss.IndexFlatL2(d) 79 | elif dist == 'cos': 80 | index_flat = faiss.IndexFlatIP(d) 81 | 82 | if L2: 83 | query = L2norm(query) 84 | gallery = L2norm(gallery) 85 | 86 | index_flat.add(gallery) 87 | K = gallery.shape[0] 88 | D, I = index_flat.search(query, K) 89 | return I 90 | 91 | 92 | def KNN(query, gallery, K=10, mode='ones'): 93 | '''retrieves the K-Nearest Neighbors in the gallery''' 94 | d = query.shape[1] 95 | query = L2norm(query) 96 | gallery = L2norm(gallery) 97 | 98 | res = faiss.StandardGpuResources() 99 | 100 | index_flat = faiss.IndexFlatL2(d) 101 | 102 | if torch.cuda.is_available(): 103 | gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index_flat) 104 | gpu_index_flat.add(gallery) 105 | D, I = gpu_index_flat.search(query, K) 106 | else: 107 | index_flat.add(gallery) 108 | D, I = index_flat.search(query, K) 109 | 110 | if mode == 'lin': 111 | weights = (float(K) - np.arange(0, K)) / float(K) 112 | elif mode == 'exp': 113 | weights = np.exp(-np.arange(0, K)) 114 | elif mode == 'ones': 115 | weights = np.ones(K) 116 | weights_sum = weights.sum() 117 | 
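    # each new query is the weighted average of its K nearest gallery
    # neighbors; with the default mode='ones' this is a plain mean, which the
    # retrieval scripts then slerp back toward the original query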
118 | new_queries = [] 119 | for i in range(len(query)): 120 | idx = I[i, :K] 121 | to_consider = gallery[idx, :] 122 | new_queries.append(np.dot(weights, to_consider) / weights_sum) 123 | new_queries = np.asarray(new_queries, dtype=np.float32) 124 | return new_queries 125 | 126 | 127 | def score(sk_labels, im_labels, index): 128 | res = np.equal(im_labels[index], sk_labels[:, None]) 129 | 130 | prec = np.mean([precision_at_k(r, 100) for r in res]) 131 | 132 | pool = mp.Pool(processes=10) 133 | results = [pool.apply_async(average_precision, args=(r,)) for r in res] 134 | mAP = np.mean([p.get() for p in results]) 135 | pool.close() 136 | return prec, mAP 137 | 138 | 139 | def score_shape(sk_labels, im_labels, index): 140 | vv, cc = np.unique(im_labels, return_counts=True) 141 | lut = {} 142 | for v, c in zip(vv, cc): 143 | lut[v] = c 144 | 145 | res = np.equal(im_labels[index], sk_labels[:, None]) 146 | 147 | # 1-NN 148 | nn = np.mean(res[:, 0]) 149 | 150 | # first and second tier 151 | ft = np.mean([np.sum(r[:lut[l]]) / float(lut[l]) 152 | for r, l in zip(res, sk_labels)]) 153 | st = np.mean([np.sum(r[:2 * lut[l]]) / float(lut[l]) 154 | for r, l in zip(res, sk_labels)]) 155 | 156 | # e-measure 157 | prec = np.mean([precision_at_k(r, 32) for r in res]) 158 | rec = np.mean([np.sum(r[:32]) / float(lut[l]) 159 | for r, l in zip(res, sk_labels)]) 160 | e_measure = 2 * prec * rec / (prec + rec) 161 | 162 | # dcg 163 | pool = mp.Pool(processes=10) 164 | results = [pool.apply_async(dcg_at_k, args=(r, len(im_labels), 1)) for r in res] 165 | mDCG = np.mean([p.get() for p in results]) 166 | pool.close() 167 | 168 | # ndcg 169 | pool = mp.Pool(processes=10) 170 | results = [pool.apply_async(ndcg_at_k, args=(r, len(im_labels), 1)) for r in res] 171 | mnDCG = np.mean([p.get() for p in results]) 172 | pool.close() 173 | 174 | # map 175 | pool = mp.Pool(processes=10) 176 | results = [pool.apply_async(average_precision, args=(r,)) for r in res] 177 | mAP = np.mean([p.get() for p in results]) 178 | pool.close() 179 | 180 | return nn, ft, st, e_measure, mnDCG, mAP 181 | 182 | 183 | def score_single(sk_labels, im_labels, index): 184 | res = np.equal(im_labels[index], sk_labels[:, None]) 185 | 186 | prec = np.mean([precision_at_k(r, 100) for r in res]) 187 | mAP = mean_average_precision(res) 188 | return prec, mAP 189 | --------------------------------------------------------------------------------
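For reference, a minimal numpy-only sketch of the K-NN + slerp query refinement used in retrieve.py, retrieve-many.py and validate.py above. Random features stand in for real embeddings, a brute-force search replaces the faiss index, alpha=0.4 matches the many-shot setting, and the helper name knn_mean is illustrative, not part of the code base.

import numpy as np


def L2norm(x):
    # row-wise L2 normalization, as in validate.py
    return x / np.linalg.norm(x, axis=1)[:, None]


def slerp(val, low, high):
    # spherical interpolation between row-aligned unit vectors; the clip
    # guards arccos against floating-point round-off
    omega = np.arccos(np.clip(np.einsum('ij, ij->i', low, high), -1., 1.))
    so = np.sin(omega)
    return ((np.sin((1. - val) * omega) / so)[:, None] * low
            + (np.sin(val * omega) / so)[:, None] * high)


def knn_mean(query, gallery, K=1):
    # brute-force stand-in for the faiss search in KNN(): for each query,
    # average the features of its K nearest gallery rows
    d2 = ((query[:, None, :] - gallery[None, :, :]) ** 2).sum(-1)
    idx = np.argsort(d2, axis=1)[:, :K]
    return gallery[idx].mean(axis=1)


rng = np.random.RandomState(0)
queries = L2norm(rng.randn(5, 300).astype(np.float32))
gallery = L2norm(rng.randn(100, 300).astype(np.float32))

g = knn_mean(queries, gallery, K=1)
refined = slerp(0.4, L2norm(queries), L2norm(g))  # alpha=0.4: many-shot setting
print(refined.shape)  # (5, 300), ready for the usual retrieval step

Slerp keeps the refined query on the unit hypersphere, which a linear blend of the two unit vectors would not; this matches the L2-normalized feature space the prototype learners are trained in.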