├── __init__.py ├── data ├── __init__.py ├── dataset.py ├── stanford_products.py ├── symo.py ├── inshop.py ├── cub200.py └── cars196.py ├── tests ├── __init__.py └── evaluation │ ├── __init__.py │ └── test_retrieval.py ├── DOCKER_TAG ├── data1 └── data │ ├── inshop │ └── .keep │ └── symo │ └── create_list_eval_partition.py ├── evaluation ├── __init__.py ├── nmi.py └── retrieval.py ├── metric_learning ├── __init__.py ├── modules │ ├── __init__.py │ ├── losses.py │ └── featurizer.py ├── util.py ├── extract_features.py ├── sampler.py └── train_classification.py ├── .gitmodules ├── requirements.txt ├── scripts ├── get_cub200_dataset.sh ├── get_stanford_products_dataset.sh ├── get_cars196_dataset.sh ├── get_inshop_dataset.sh └── bin │ └── pdoc ├── docker ├── docker-build.sh └── Dockerfile ├── .gitignore ├── README.md ├── detect.py ├── demo.ipynb └── LICENSE /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DOCKER_TAG: -------------------------------------------------------------------------------- 1 | c4e5c1e 2 | -------------------------------------------------------------------------------- /data1/data/inshop/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metric_learning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metric_learning/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/vision"] 2 | path = third_party/vision 3 | url = https://github.com/pytorch/vision.git 4 | [submodule "third_party/pytorch"] 5 | path = third_party/pytorch 6 | url = https://github.com/pytorch/pytorch.git 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow-SIMD==5.3.0.post0 2 | scipy>=1.0.0 3 | simplejson==3.10.0 4 | protobuf==3.5.2 5 | scikit-image==0.12.3 6 | matplotlib==2.1.0 7 | pyfakefs==3.2 8 | scikit-learn==0.17.1 9 | ipdb==0.8.1 10 | urllib3==1.24.2 11 | Bottleneck==1.2.1 12 | pretrainedmodels==0.7.4 13 | hickle==3.2.2 14 | -------------------------------------------------------------------------------- /scripts/get_cub200_dataset.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | # 3 | # Get the CUB200_2011 dataset 4 | # 5 | 6 | DATA_DIR=/data1/data/cub200 7 | 8 | if [[ ! -d "${DATA_DIR}" ]]; then 9 | echo "${DATA_DIR} doesn't exist, will create one."; 10 | mkdir -p ${DATA_DIR} 11 | fi 12 | 13 | wget -P ${DATA_DIR} http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz 14 | cd ${DATA_DIR}; tar -xf CUB_200_2011.tgz 15 | -------------------------------------------------------------------------------- /scripts/get_stanford_products_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | # 3 | # Get the Stanford Online Product dataset 4 | # 5 | 6 | DATA_DIR=/data1/data/stanford_products 7 | 8 | if [[ ! -d "${DATA_DIR}" ]]; then 9 | echo "${DATA_DIR} doesn't exist, will create one."; 10 | mkdir -p ${DATA_DIR} 11 | fi 12 | 13 | wget -P ${DATA_DIR} ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip 14 | cd ${DATA_DIR}; unzip Stanford_Online_Products.zip -------------------------------------------------------------------------------- /scripts/get_cars196_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | # 3 | # Get the CARS196 dataset 4 | # 5 | 6 | DATA_DIR=/data1/data/cars196 7 | 8 | if [[ ! -d "${DATA_DIR}" ]]; then 9 | echo "${DATA_DIR} doesn't exist, will create one."; 10 | mkdir -p ${DATA_DIR} 11 | fi 12 | 13 | wget -P ${DATA_DIR} http://imagenet.stanford.edu/internal/car196/car_ims.tgz 14 | wget -P ${DATA_DIR} http://imagenet.stanford.edu/internal/car196/cars_annos.mat 15 | 16 | cd ${DATA_DIR}; 17 | tar -xzf car_ims.tgz 18 | -------------------------------------------------------------------------------- /scripts/get_inshop_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | # 3 | # Get the In-Shop dataset 4 | # 5 | 6 | DATA_DIR=../data1/data/inshop 7 | 8 | if [[ ! -d "${DATA_DIR}" ]]; then 9 | echo "${DATA_DIR} doesn't exist, will create one."; 10 | mkdir -p ${DATA_DIR} 11 | fi 12 | 13 | # Pretty annoying but you have to download the datasets manually and put them into /data1/data/inshop 14 | # There's no direct download link to the dataset from what we could find. 
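# A guard like the following could be added to fail fast when the manual download is
# incomplete (illustrative sketch only; the expected files are listed just below):
#
#   for f in "${DATA_DIR}/img.zip" "${DATA_DIR}/list_eval_partition.txt"; do
#       [[ -f "${f}" ]] || { echo "Missing ${f}"; exit 1; }
#   done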
15 | 16 | # Expected files to exist 17 | # /data1/data/inshop/img.zip 18 | # /data1/data/inshop/list_eval_partition.txt 19 | 20 | cd ${DATA_DIR}; 21 | unzip img.zip 22 | -------------------------------------------------------------------------------- /metric_learning/util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import sys 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | class SimpleLogger(object): 14 | def __init__(self, logfile, terminal): 15 | ZERO_BUFFER_SIZE = 0 # immediately flush logs 16 | 17 | self.log = open(logfile, 'a') 18 | self.terminal = terminal 19 | 20 | def write(self, message): 21 | self.terminal.write(message) 22 | self.log.write(message) 23 | 24 | def flush(self): 25 | self.terminal.flush() 26 | self.log.flush() -------------------------------------------------------------------------------- /metric_learning/extract_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | 6 | def extract_feature(model, loader, gpu_device): 7 | """ 8 | Extract embeddings from given `model` for given `loader` dataset on `gpu_device`. 9 | """ 10 | model.eval() 11 | 12 | all_embeddings = [] 13 | all_labels = [] 14 | log_every_n_step = 10 15 | 16 | with torch.no_grad(): 17 | for i, (im, class_label, instance_label, index) in enumerate(loader): 18 | im = im.to(device=gpu_device) 19 | embedding = model(im) 20 | 21 | all_embeddings.append(embedding.cpu().numpy()) 22 | all_labels.extend(instance_label.tolist()) 23 | 24 | if (i + 1) % log_every_n_step == 0: 25 | print('Process Iteration {} / {}:'.format(i, len(loader))) 26 | 27 | all_embeddings = np.vstack(all_embeddings) 28 | 29 | print("Generated {} embedding matrix".format(all_embeddings.shape)) 30 | print("Generate {} labels".format(len(all_labels))) 31 | 32 | model.train() 33 | return all_embeddings, all_labels 34 | -------------------------------------------------------------------------------- /metric_learning/modules/losses.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import Parameter 5 | 6 | 7 | class NormSoftmaxLoss(nn.Module): 8 | """ 9 | L2 normalize weights and apply temperature scaling on logits. 10 | """ 11 | def __init__(self, 12 | dim, 13 | num_instances, 14 | temperature=0.05): 15 | super(NormSoftmaxLoss, self).__init__() 16 | 17 | self.weight = Parameter(torch.Tensor(num_instances, dim)) 18 | # Initialization from nn.Linear (https://github.com/pytorch/pytorch/blob/v1.0.0/torch/nn/modules/linear.py#L129) 19 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 20 | self.weight.data.uniform_(-stdv, stdv) 21 | 22 | self.temperature = temperature 23 | self.loss_fn = nn.CrossEntropyLoss() 24 | 25 | def forward(self, embeddings, instance_targets): 26 | norm_weight = nn.functional.normalize(self.weight, dim=1) 27 | 28 | prediction_logits = nn.functional.linear(embeddings, norm_weight) 29 | 30 | loss = self.loss_fn(prediction_logits / self.temperature, instance_targets) 31 | return loss 32 | -------------------------------------------------------------------------------- /docker/docker-build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # 3 | # Build Docker image, and optionally publish. 4 | # 5 | # Usage: 6 | # 7 | # $ ./docker-build.sh DOCKERFILE [publish] 8 | # 9 | # Examples: 10 | # 11 | # Build docker image locally. Useful to test Dockerfile changes before publishing. 12 | # $ ./docker/docker-build.sh docker/Dockerfile 13 | # Build + publish 14 | # $ ./docker/docker-build.sh docker/Dockerfile publish 15 | # 16 | # Note: only tested on linux 17 | # 18 | 19 | export VERSION=`git rev-parse --short HEAD` 20 | 21 | DOCKERFILE=$1 22 | 23 | docker build -f $DOCKERFILE -t pinterestdocker/visualembedding:$VERSION . 24 | 25 | # If the publish command is given AND we have no code changes not checked in. 26 | # The second condition is to prevent folks from overwriting a production 27 | # commit hash. If you publish, you only publish committed changes. Unfortunately 28 | # this can be circumvented as currently implemented but the current solution 29 | # is better than none 30 | if [[ "$2" == "publish" ]]; then 31 | if [[ -z $(git status -s) ]]; then 32 | docker push pinterestdocker/visualembedding:$VERSION 33 | else 34 | echo "[PUSH ERROR] Cannot push with outstanding changes:" 35 | git status -s 36 | fi 37 | fi 38 | 39 | cat < DOCKER_TAG 40 | $VERSION 41 | EOL 42 | -------------------------------------------------------------------------------- /metric_learning/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class ClassBalancedBatchSampler(object): 5 | """ 6 | BatchSampler that ensures a fixed amount of images per class are sampled in the minibatch 7 | """ 8 | def __init__(self, targets, batch_size, images_per_class, ignore_index=None): 9 | self.targets = targets 10 | self.batch_size = batch_size 11 | self.images_per_class = images_per_class 12 | self.ignore_index = ignore_index 13 | self.reverse_index, self.ignored = self._build_reverse_index() 14 | 15 | def __iter__(self): 16 | for _ in range(len(self)): 17 | yield self.sample_batch() 18 | 19 | def _build_reverse_index(self): 20 | reverse_index = {} 21 | ignored = [] 22 | for i, target in enumerate(self.targets): 23 | if target == self.ignore_index: 24 | ignored.append(i) 25 | continue 26 | if target not in reverse_index: 27 | reverse_index[target] = [] 28 | reverse_index[target].append(i) 29 | return reverse_index, ignored 30 | 31 | def sample_batch(self): 32 | # Real batch size is self.images_per_class * (self.batch_size // self.images_per_class) 33 | num_classes = self.batch_size // self.images_per_class 34 | sampled_classes = np.random.choice(list(self.reverse_index.keys()), 35 | num_classes, 36 | replace=False) 37 | sampled_indices = [] 38 | for cls in sampled_classes: 39 | # Need replace = True for datasets with non-uniform distribution of images per class 40 | sampled_indices.extend(np.random.choice(self.reverse_index[cls], 41 | 
self.images_per_class, 42 | replace=True)) 43 | return sampled_indices 44 | 45 | def __len__(self): 46 | return len(self.targets) // self.batch_size 47 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod, abstractproperty 2 | from PIL import Image 3 | 4 | 5 | class Dataset(object): 6 | """ 7 | This abstract class defines the interface needed for dataset loading 8 | All concrete subclasses need to implement the following method/property: 9 | 10 | def name: returns the name of the dataset 11 | 12 | def image_root_dir: the root directory of the images 13 | 14 | def _load: the actual logic to load the dataset, 15 | It needs to populate these three lists 16 | self.image_paths = [] 17 | self.class_labels = [] 18 | self.instance_labels = [] 19 | """ 20 | 21 | __metaclass__ = ABCMeta 22 | 23 | def __init__(self, root, train=True, transform=None): 24 | self.root = root 25 | self.train = train 26 | self.transform = transform 27 | self.image_paths = [] 28 | self.class_labels = [] 29 | self.instance_labels = [] 30 | 31 | self._load() 32 | 33 | # Check dataset is loaded properly 34 | assert (len(self.image_paths) != 0) 35 | assert (len(self.instance_map) != 0) 36 | assert (len(self.image_paths) == len(self.instance_labels)) 37 | assert (len(self.image_paths) == len(self.class_labels)) 38 | 39 | @abstractproperty 40 | def name(self): 41 | raise NotImplementedError() 42 | 43 | @abstractproperty 44 | def image_root_dir(self): 45 | raise NotImplementedError() 46 | 47 | @abstractmethod 48 | def _load(self): 49 | raise NotImplementedError() 50 | 51 | def __getitem__(self, index): 52 | im_path = self.image_paths[index] 53 | im = Image.open(im_path).convert('RGB') 54 | if self.transform is not None: 55 | im = self.transform(im) 56 | class_target = self.class_labels[index] 57 | instance_target = self.instance_labels[index] 58 | return im, class_target, instance_target, index 59 | 60 | def __len__(self): 61 | return len(self.image_paths) 62 | -------------------------------------------------------------------------------- /metric_learning/modules/featurizer.py: -------------------------------------------------------------------------------- 1 | import pretrainedmodels 2 | import torch.nn as nn 3 | 4 | 5 | class EmbeddedFeatureWrapper(nn.Module): 6 | """ 7 | Wraps a base model with embedding layer modifications. 8 | """ 9 | def __init__(self, 10 | feature, 11 | input_dim, 12 | output_dim): 13 | super(EmbeddedFeatureWrapper, self).__init__() 14 | 15 | self.feature = feature 16 | self.pool = nn.AdaptiveAvgPool2d(1) 17 | self.standardize = nn.LayerNorm(input_dim, elementwise_affine=False) 18 | 19 | self.remap = None 20 | if input_dim != output_dim: 21 | self.remap = nn.Linear(input_dim, output_dim, bias=False) 22 | 23 | def forward(self, images): 24 | x = self.feature(images) 25 | x = self.pool(x) 26 | x = x.view(x.size(0), -1) 27 | x = self.standardize(x) 28 | 29 | if self.remap: 30 | x = self.remap(x) 31 | 32 | x = nn.functional.normalize(x, dim=1) 33 | 34 | return x 35 | 36 | def __str__(self): 37 | return "{}_{}".format(self.feature.name, str(self.embed)) 38 | 39 | 40 | def resnet50(output_dim): 41 | """ 42 | resnet50 variant with `output_dim` embedding output size. 
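    Illustrative usage (assuming `images` is a batch of preprocessed image tensors):

        model = resnet50(output_dim=2048)
        embeddings = model(images)  # L2-normalized embeddings, shape (batch_size, 2048)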
43 | """ 44 | basemodel = pretrainedmodels.__dict__["resnet50"](num_classes=1000) 45 | 46 | model = nn.Sequential( 47 | basemodel.conv1, 48 | basemodel.bn1, 49 | basemodel.relu, 50 | basemodel.maxpool, 51 | 52 | basemodel.layer1, 53 | basemodel.layer2, 54 | basemodel.layer3, 55 | basemodel.layer4 56 | ) 57 | model.name = "resnet50" 58 | featurizer = EmbeddedFeatureWrapper(feature=model, input_dim=2048, output_dim=output_dim) 59 | featurizer.input_space = basemodel.input_space 60 | featurizer.input_range = basemodel.input_range 61 | featurizer.input_size = basemodel.input_size 62 | featurizer.std = basemodel.std 63 | featurizer.mean = basemodel.mean 64 | 65 | return featurizer 66 | -------------------------------------------------------------------------------- /data/stanford_products.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os.path 3 | 4 | from dataset import Dataset 5 | 6 | 7 | class StanfordOnlineProducts(Dataset): 8 | 9 | def __init__(self, root, train=True, transform=None): 10 | self.info_file = 'Ebay_{}.txt'.format('train' if train else 'test') 11 | super(StanfordOnlineProducts, self).__init__(root, train, transform) 12 | print ("Loaded {} samples for dataset {}, {} classes, {} instances".format(len(self), 13 | self.name, 14 | self.num_cls, 15 | self.num_instance)) 16 | 17 | @property 18 | def name(self): 19 | return 'stanford_online_products_{}'.format('train' if self.train else 'test') 20 | 21 | @property 22 | def image_root_dir(self): 23 | return self.root 24 | 25 | @property 26 | def num_cls(self): 27 | return len(self.class_map) 28 | 29 | @property 30 | def num_instance(self): 31 | return len(self.instance_map) 32 | 33 | def _load(self): 34 | self.instance_map = {} 35 | self.class_map = {} 36 | with open(os.path.join(self.root, self.info_file), 'r') as f: 37 | reader = csv.DictReader(f, delimiter=' ') 38 | for entry in reader: 39 | self.image_paths.append(os.path.join(self.image_root_dir, entry['path'])) 40 | class_id = int(entry['super_class_id']) - 1 41 | if class_id not in self.class_map: 42 | self.class_map[class_id] = len(self.class_map) 43 | self.class_labels.append(self.class_map[class_id]) 44 | 45 | instance_id = entry['class_id'] 46 | if instance_id not in self.instance_map: 47 | self.instance_map[instance_id] = len(self.instance_map) 48 | self.instance_labels.append(self.instance_map[instance_id]) 49 | 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | # IDE 127 | .idea 128 | -------------------------------------------------------------------------------- /evaluation/nmi.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import numpy as np 3 | from sklearn.cluster import KMeans 4 | from sklearn.metrics.cluster import normalized_mutual_info_score 5 | 6 | from argparse import ArgumentParser 7 | 8 | 9 | def parse_args(): 10 | """ 11 | Helper function parsing the command line options 12 | @retval ArgumentParser 13 | """ 14 | parser = ArgumentParser(description="PyTorch metric learning nmi script") 15 | # Optional arguments for the launch helper 16 | parser.add_argument("--num_workers", type=int, default=4, 17 | help="The number of workers for eval") 18 | parser.add_argument("--snap", type=str, 19 | help="The snapshot to compute nmi") 20 | parser.add_argument("--output", type=str, default="/data1/output/", 21 | help="The output file") 22 | parser.add_argument("--dataset", type=str, default="StanfordOnlineProducts", 23 | help="The dataset for training") 24 | parser.add_argument('--binarize', action='store_true') 25 | 26 | return parser.parse_args() 27 | 28 | 29 | def test_nmi(embeddings, labels, output_file): 30 | unique_labels = np.unique(labels) 31 | kmeans = KMeans(n_clusters=unique_labels.size, random_state=0, n_jobs=-1).fit(embeddings) 32 | 33 | nmi = normalized_mutual_info_score(kmeans.labels_, labels) 34 | 35 | print("NMI: {}".format(nmi)) 36 | return nmi 37 | 38 | 39 | def test_nmi_faiss(embeddings, labels): 40 | res = faiss.StandardGpuResources() 41 | flat_config = faiss.GpuIndexFlatConfig() 42 | flat_config.device = 0 43 | 44 | unique_labels = np.unique(labels) 45 | d = embeddings.shape[1] 46 | kmeans = faiss.Clustering(d, unique_labels.size) 47 | kmeans.verbose = True 48 | kmeans.niter = 300 49 | 
kmeans.nredo = 10 50 | kmeans.seed = 0 51 | 52 | index = faiss.GpuIndexFlatL2(res, d, flat_config) 53 | 54 | kmeans.train(embeddings, index) 55 | 56 | dists, pred_labels = index.search(embeddings, 1) 57 | 58 | pred_labels = pred_labels.squeeze() 59 | 60 | nmi = normalized_mutual_info_score(labels, pred_labels) 61 | 62 | print("NMI: {}".format(nmi)) 63 | return nmi 64 | 65 | 66 | if __name__ == '__main__': 67 | args = parse_args() 68 | embedding_file = args.snap.replace('.pth', '_embed.npy') 69 | all_embeddings = np.load(embedding_file) 70 | lable_file = args.snap.replace('.pth', '_label.npy') 71 | all_labels = np.load(lable_file) 72 | nmi = test_nmi_faiss(all_embeddings, all_labels) 73 | -------------------------------------------------------------------------------- /data1/data/symo/create_list_eval_partition.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from os.path import join 4 | import random 5 | from PIL import Image 6 | 7 | train_prob = 0.8 8 | query_prob = 0.5 9 | 10 | main_dir = 'img' 11 | txt_path = 'list_eval_partition.txt' 12 | txt_file = open(txt_path, mode='w') 13 | for main_category in os.listdir(main_dir): 14 | print(f'[INFO] processing main_category:{main_category}') 15 | category_path = join(main_dir, main_category) 16 | for sub_category in os.listdir(category_path): 17 | print(f'[INFO] processing sub_category"{sub_category}') 18 | sub_category_path = join(category_path, sub_category) 19 | if random.random() < train_prob: 20 | for img_name in os.listdir(sub_category_path): 21 | img_path = join(sub_category_path, img_name) 22 | if " " in img_name: 23 | new_img_name = img_name.replace(" ", "") 24 | new_img_path = join(sub_category_path, new_img_name) 25 | shutil.move(img_path, new_img_path) 26 | img_path = new_img_path 27 | print(f'[INFO] rename {img_path} to {new_img_path}') 28 | try: 29 | Image.open(img_path).convert('RGB') 30 | except OSError: 31 | print(f'[INFO] truncation error for {img_path}, skipping') 32 | continue 33 | 34 | record = f'{img_path} {sub_category_path} train\n' 35 | txt_file.write(record) 36 | else: 37 | for img_name in os.listdir(sub_category_path): 38 | img_path = join(sub_category_path, img_name) 39 | if " " in img_name: 40 | new_img_name = img_name.replace(" ", "") 41 | new_img_path = join(sub_category_path, new_img_name) 42 | shutil.move(img_path, new_img_path) 43 | img_path = new_img_path 44 | print(f'[INFO] rename {img_path} to {new_img_path}') 45 | if random.random() < query_prob: 46 | record = f'{img_path} {sub_category_path} query\n' 47 | else: 48 | record = f'{img_path} {sub_category_path} gallery\n' 49 | txt_file.write(record) 50 | try: 51 | Image.open(img_path).convert('RGB') 52 | except OSError: 53 | print(f'[INFO] truncation error for image {img_path}, skipping') 54 | continue 55 | print('[INFO] It is done!') 56 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.1-cudnn7-devel-ubuntu16.04 2 | 3 | MAINTAINER Andrew Zhai 4 | 5 | ARG BUILD_JOBS=10 6 | ENV BUILD_JOBS ${BUILD_JOBS} 7 | 8 | 9 | RUN apt-get -q -y update && apt-get install --force-yes -q -y \ 10 | build-essential \ 11 | curl \ 12 | vim \ 13 | git \ 14 | wget \ 15 | libatlas-base-dev \ 16 | libgtest-dev \ 17 | libmysqlclient-dev \ 18 | libssl-dev \ 19 | libxml2-dev \ 20 | libxslt1-dev \ 21 | libjpeg8-dev \ 22 | libexiv2-dev \ 23 | zlib1g-dev \ 24 | 
libffi-dev \ 25 | libboost-all-dev \ 26 | libgflags-dev \ 27 | libgoogle-glog-dev \ 28 | libhdf5-serial-dev \ 29 | libleveldb-dev \ 30 | liblmdb-dev \ 31 | liblz4-dev \ 32 | gfortran \ 33 | libopencv-dev \ 34 | libsnappy-dev \ 35 | unzip \ 36 | python-dev \ 37 | python-tk \ 38 | python-setuptools \ 39 | libiomp-dev \ 40 | libopenmpi-dev \ 41 | openmpi-bin \ 42 | openmpi-doc \ 43 | libmagickwand-dev \ 44 | unzip \ 45 | dh-autoreconf \ 46 | cmake \ 47 | libnccl2=2.2.12-1+cuda9.1 \ 48 | libnccl-dev=2.2.12-1+cuda9.1 49 | 50 | RUN easy_install pip 51 | RUN pip install \ 52 | pip==9.0.1 \ 53 | h5py==2.7.0 54 | 55 | # libprotobuf. We do this instead of apt-get to ensure python bindings and binaries are same version 56 | RUN cd /opt && \ 57 | wget https://github.com/google/protobuf/archive/v3.5.2.tar.gz && \ 58 | tar -xzf v3.5.2.tar.gz && \ 59 | (cd protobuf-3.5.2 && \ 60 | ./autogen.sh && \ 61 | CPPFLAGS="-fPIC" ./configure && \ 62 | make -j ${BUILD_JOBS} && \ 63 | make install && \ 64 | ldconfig) && \ 65 | (cd protobuf-3.5.2/python/ && \ 66 | python setup.py install) && \ 67 | rm -rf protobuf-3.5.2 68 | 69 | # Caffe2 + Pytorch 70 | ENV PYTORCH_ROOT=/opt/pytorch 71 | COPY third_party/pytorch $PYTORCH_ROOT 72 | ENV LD_PRELOAD=/usr/lib/libmpi_cxx.so 73 | 74 | RUN cd $PYTORCH_ROOT \ 75 | && pip install -r requirements.txt \ 76 | && TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1 7.0+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \ 77 | FULL_CAFFE2=1 python setup.py install 78 | 79 | # torchvision 80 | ENV VISION_ROOT=/opt/vision 81 | COPY third_party/vision $VISION_ROOT 82 | 83 | RUN cd $VISION_ROOT && \ 84 | python setup.py install 85 | 86 | # faiss 87 | RUN apt-get install --force-yes -q -y libopenblas-dev 88 | 89 | ENV BLASLDFLAGS=/usr/lib/libopenblas.so.0 90 | 91 | RUN cd /opt && \ 92 | git clone https://github.com/facebookresearch/faiss.git && \ 93 | (cd faiss && \ 94 | git reset --hard 87721af1294c0dc2008d0537d9082198a477ac3a && \ 95 | ./configure && \ 96 | make -j ${BUILD_JOBS} && \ 97 | make -j ${BUILD_JOBS} install && \ 98 | make -C gpu -j ${BUILD_JOBS} && \ 99 | make -C gpu/test -j ${BUILD_JOBS} && \ 100 | make -C python gpu -j ${BUILD_JOBS} && \ 101 | make -C python build -j ${BUILD_JOBS} && \ 102 | make -C python install -j ${BUILD_JOBS}) 103 | 104 | ENV PYTHONPATH=$PYTHONPATH:/opt/faiss/python 105 | 106 | # requirements 107 | COPY requirements.txt . 
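# A smoke test along these lines could be added here to confirm the stack built above
# imports cleanly (illustrative, left commented out):
# RUN python -c "import torch, torchvision, faiss; print(torch.__version__)"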
108 | RUN pip install --upgrade pip 109 | RUN pip install -r requirements.txt && rm requirements.txt 110 | -------------------------------------------------------------------------------- /tests/evaluation/test_retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | import numpy as np 4 | 5 | from mock import mock 6 | from mock import call 7 | from pyfakefs import fake_filesystem_unittest 8 | 9 | from evaluation.retrieval import evaluate_recall_at_k 10 | from evaluation.retrieval import evaluate_float_binary_embedding_faiss 11 | 12 | 13 | class TestRetrieval(fake_filesystem_unittest.TestCase): 14 | 15 | def setUp(self): 16 | self.setUpPyfakefs() 17 | 18 | self.query_embeddings = np.array([ 19 | [1.0, 3.0], 20 | [-3.0, 1.0] 21 | ]).astype('float32') 22 | 23 | self.db_embeddings = np.array([ 24 | [2.0, 3.0], 25 | [1.0, 5.0], 26 | [-3.0, 3.0] 27 | ]).astype('float32') 28 | 29 | self.query_labels = [1, 2] 30 | 31 | self.db_labels = [2, 1, 2] 32 | 33 | def test_regular_evaluate_recall_at_k(self): 34 | """ 35 | test recall_at_k 36 | query_db retrieval: 37 | q[0] -> 0, 1, 2 38 | q[1] -> 2, 1, 0 39 | """ 40 | 41 | dists = np.array([ 42 | [1.0, 2.0], 43 | [2.0, 5.39] 44 | ]).astype('float32') 45 | 46 | results = np.array([ 47 | [0, 1], 48 | [2, 1] 49 | ]) 50 | 51 | r_at_k = evaluate_recall_at_k(dists, results, self.query_labels, self.db_labels, 2) 52 | 53 | self.assertTrue(r_at_k.shape[0] == 2) 54 | np.testing.assert_almost_equal(r_at_k, [50.0, 100.0], 2) 55 | 56 | def test_self_evaluate_recall_at_k(self): 57 | """ 58 | test recall_at_k 59 | db_db retrieval: 60 | db[0] -> 0, 1, 2 61 | db[1] -> 1, 0, 2 62 | db[2] -> 2, 1, 0 63 | """ 64 | 65 | dists = np.array([ 66 | [0.0, 2.24, 5.0], 67 | [0.0, 2.24, 4.47], 68 | [0.0, 4.47, 5.0] 69 | ]).astype('float32') 70 | 71 | results = np.array([ 72 | [0, 1, 2], 73 | [1, 0, 2], 74 | [2, 1, 0] 75 | ]) 76 | 77 | r_at_k = evaluate_recall_at_k(dists, results, self.db_labels, self.db_labels, 2) 78 | 79 | self.assertTrue(r_at_k.shape[0] == 2) 80 | np.testing.assert_almost_equal(r_at_k, [0.0, 66.67], 2) 81 | 82 | @mock.patch("evaluation.retrieval._retrieve_knn_faiss_gpu") 83 | @mock.patch("evaluation.retrieval.evaluate_recall_at_k") 84 | def test_evaluate_float_binary_embedding_faiss(self, recall_at_k_mock, nn_mock): 85 | output_file = "epoch_test" 86 | 87 | nn_mock.side_effect = [('dist1', 'result1'), ('dist2', 'result2')] 88 | 89 | recall_at_k_mock.side_effect = [np.zeros((1000,)), np.ones((1000,))] 90 | 91 | evaluate_float_binary_embedding_faiss(self.query_embeddings, self.db_embeddings, 92 | self.query_labels, self.db_labels, 93 | output_file, k=1000, gpu_id=0) 94 | 95 | self.assertEqual(nn_mock.call_count, 2) 96 | 97 | np.testing.assert_array_equal(self.query_embeddings, nn_mock.call_args_list[0][0][0]) 98 | np.testing.assert_array_equal(self.db_embeddings, nn_mock.call_args_list[0][0][1]) 99 | 100 | np.testing.assert_array_equal(np.require(self.query_embeddings > 0, dtype='float32'), 101 | nn_mock.call_args_list[1][0][0]) 102 | np.testing.assert_array_equal(np.require(self.db_embeddings > 0, dtype='float32'), 103 | nn_mock.call_args_list[1][0][1]) 104 | 105 | self.assertEqual(recall_at_k_mock.call_count, 2) 106 | r_at_k_calls = [ 107 | call('dist1', 'result1', self.query_labels, self.db_labels, 1000), 108 | call('dist2', 'result2', self.query_labels, self.db_labels, 1000) 109 | ] 110 | recall_at_k_mock.assert_has_calls(r_at_k_calls) 111 | 112 | self.assertTrue(os.path.exists(output_file 
+ '_identity.eval')) 113 | self.assertTrue(os.path.exists(output_file + '_binary.eval')) 114 | 115 | 116 | 117 | if __name__ == '__main__': 118 | unittest.main() 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Classification is a Strong Baseline for Deep Metric Learning (BMVC '19) 3 | Andrew Zhai, Hao-Yu Wu 4 | 5 | ## Paper ([https://arxiv.org/abs/1811.12649](https://arxiv.org/abs/1811.12649)) 6 | This repo contains the source code for our paper (WIP) 7 | 8 | ## Setup Repo 9 | ``` 10 | git clone https://github.com/azgo14/classification_metric_learning.git 11 | ``` 12 | 13 | The repo assumes that all data files exist under the `/data1` directory locally, as that local directory will be mounted to `/data1` in the container. 14 | 15 | ## Running Commands 16 | We provide a simple utility to make running commands in a docker container easier. This tool will automatically download the expected Docker image, mount the expected directories, and make running commands simple. To use the command, add the `scripts/bin` directory to your PATH variable like so: 17 | ``` 18 | export PATH=$PATH:/scripts/bin 19 | ``` 20 | 21 | ### Example 22 | You can then see that commands such as `pdoc nvidia-smi` will run `nvidia-smi` inside the docker container. 23 | 24 | ## Build Docker 25 | To rebuild the docker image: 26 | 27 | 1) Initialize all submodules recursively via: 28 | ``` 29 | git submodule update --init --recursive 30 | ``` 31 | 32 | 2) Build the image with: 33 | ``` 34 | ./docker/docker-build.sh ./docker/Dockerfile 35 | ``` 36 | 37 | ## Datasets 38 | Download the datasets with the following scripts. We assume data will live in the /data1 directory throughout our code. 39 | ### CUB 40 | ``` 41 | ./scripts/get_cub200_dataset.sh 42 | ``` 43 | 44 | ### CARS 45 | ``` 46 | ./scripts/get_cars196_dataset.sh 47 | ``` 48 | 49 | ### Stanford Online Products 50 | ``` 51 | ./scripts/get_stanford_products_dataset.sh 52 | ``` 53 | 54 | ### In-Shop 55 | Manually download the raw data from http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html 56 | 57 | The following raw data files are expected to exist:\ 58 | /data1/data/inshop/img.zip\ 59 | /data1/data/inshop/list_eval_partition.txt 60 | 61 | ``` 62 | ./scripts/get_inshop_dataset.sh 63 | ``` 64 | 65 | ## Reproduction 66 | ### CUB 67 | ``` 68 | ./scripts/bin/pdoc CUDA_VISIBLE_DEVICES=0 CUDA_DEVICE_ORDER=PCI_BUS_ID python metric_learning/train_classification.py --dataset Cub200 --dim 2048 --model_name resnet50 --epochs_per_step 15 --num_steps 2 --test_every_n_epochs 5 --lr 0.001 --lr_mult 1 --class_balancing --images_per_class 25 --batch_size 75 69 | ``` 70 | (May differ slightly because of random seed)\ 71 | Raw Features: R@1, R@2, R@4, R@8: 65.36 & 76.76 & 85.42 & 91.51\ 72 | Binary Features: R@1, R@2, R@4, R@8: 63.67 & 75.37 & 84.54 & 90.99 73 | 74 | 75 | ### CARS 76 | ``` 77 | ./scripts/bin/pdoc CUDA_VISIBLE_DEVICES=1 CUDA_DEVICE_ORDER=PCI_BUS_ID python metric_learning/train_classification.py --dataset Cars196 --dim 2048 --model_name resnet50 --epochs_per_step 15 --num_steps 2 --test_every_n_epochs 5 --lr 0.01 --lr_mult 1 --class_balancing --images_per_class 25 --batch_size 75 78 | ``` 79 | (May differ slightly because of random seed)\ 80 | Raw Features: R@1, R@2, R@4, R@8: 89.50 & 94.18 & 96.84 & 98.41\ 81 | Binary Features: R@1, R@2, R@4, R@8: 89.29 & 93.95 & 96.61 & 98.14 82 | 83 | 84 | ### Stanford Online Products 85 | ``` 86 |
./scripts/bin/pdoc CUDA_VISIBLE_DEVICES=1 CUDA_DEVICE_ORDER=PCI_BUS_ID python metric_learning/train_classification.py --dataset StanfordOnlineProducts --dim 2048 --model_name resnet50 --epochs_per_step 15 --num_steps 2 --test_every_n_epochs 5 --lr 0.01 --lr_mult 1 --class_balancing --images_per_class 5 --batch_size 75 87 | ``` 88 | (May differ slightly because of random seed)\ 89 | Raw Features: R@1, R@10, R@100, R@1000: 79.55 & 91.54 & 96.66 & 98.95\ 90 | Binary Features: R@1, R@10, R@100, R@1000: 78.03 & 90.71 & 96.24 & 98.72 91 | 92 | 93 | ### In-Shop 94 | ``` 95 | ./scripts/bin/pdoc CUDA_VISIBLE_DEVICES=5 CUDA_DEVICE_ORDER=PCI_BUS_ID python metric_learning/train_classification.py --dataset InShop --dim 2048 --model_name resnet50 --epochs_per_step 15 --num_steps 2 --test_every_n_epochs 5 --lr 0.01 --lr_mult 1 --class_balancing --images_per_class 5 --batch_size 75 96 | ``` 97 | (May differ slightly because of random seed)\ 98 | Raw Features: R@1, R@10, R@20, R@30, R@40, R@50: 89.35 & 97.81 & 98.61 & 98.87 & 99.05 & 99.13\ 99 | Binary Features: R@1, R@10, R@20, R@30, R@40, R@50: 88.76 & 97.65 & 98.47 & 98.73 & 98.94 & 99.05 100 | 101 | ## References / re-implementations 102 | - The Computer Vision Best Practices repository: [02_state_of_the_art.ipynb](https://github.com/microsoft/computervision-recipes/blob/master/scenarios/similarity/02_state_of_the_art.ipynb) 103 | -------------------------------------------------------------------------------- /data/symo.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os.path 3 | import random 4 | 5 | from data.dataset import Dataset 6 | 7 | 8 | class Symo(Dataset): 9 | def __init__(self, root, train=True, query=True, transform=None, return_path=False): 10 | self.return_path = return_path 11 | self.query = query 12 | self.split_file = "list_eval_partition.txt" 13 | super().__init__(root, train, transform) 14 | print("Loaded {} samples for dataset {}, {} classes, {} instances".format(len(self), 15 | self.name, 16 | self.num_cls, 17 | self.num_instance)) 18 | 19 | @property 20 | def name(self): 21 | return 'inshop_{}_{}'.format('train' if self.train else 'test', 22 | 'query' if self.query else 'gallery') 23 | 24 | @property 25 | def image_root_dir(self): 26 | return self.root 27 | 28 | @property 29 | def num_cls(self): 30 | return len(self.class_map) 31 | 32 | @property 33 | def num_instance(self): 34 | return len(self.instance_map) 35 | 36 | def _load(self): 37 | self.class_map = {} 38 | with open(os.path.join(self.root, self.split_file), 'r') as f: 39 | for line in f.read().splitlines()[2:]: 40 | image_name, item_id, evaluation_status = line.strip().split() 41 | skip = True 42 | if self.train: 43 | if evaluation_status == "train": 44 | # Train data points 45 | self.image_paths.append(os.path.join(self.image_root_dir, image_name)) 46 | class_id = item_id 47 | if class_id not in self.class_map: 48 | self.class_map[class_id] = len(self.class_map) 49 | self.class_labels.append(self.class_map[class_id]) 50 | else: 51 | if evaluation_status != "train": 52 | # Test data points 53 | 54 | # Keep class ids consistent amongst query and gallery data points. The class id set is 55 | # the same for query and gallery. 
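                        # evaluation_status is the third column of list_eval_partition.txt
                        # (train / query / gallery); for this dataset the file is produced by
                        # data1/data/symo/create_list_eval_partition.py.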
56 | class_id = item_id 57 | if class_id not in self.class_map: 58 | self.class_map[class_id] = len(self.class_map) 59 | 60 | if evaluation_status == "query" and self.query: 61 | self.image_paths.append(os.path.join(self.image_root_dir, image_name)) 62 | self.class_labels.append(self.class_map[class_id]) 63 | elif evaluation_status == "gallery" and not self.query: 64 | self.image_paths.append(os.path.join(self.image_root_dir, image_name)) 65 | self.class_labels.append(self.class_map[class_id]) 66 | 67 | # same thing for in-shop 68 | self.instance_labels = self.class_labels 69 | self.instance_map = self.class_map 70 | 71 | 72 | if __name__ == '__main__': 73 | inshop_train_set = InShop('/data1/data/inshop') 74 | for i in random.sample(range(0, len(inshop_train_set)), 5): 75 | image_id, class_id, instance_id, idx = inshop_train_set[i] 76 | assert idx == i 77 | print("Image {} has class label {}, instance label {}".format(inshop_train_set.image_paths[i], 78 | instance_id, 79 | class_id)) 80 | 81 | inshop_query_set = InShop('/data1/data/inshop', train=False, query=True) 82 | for i in random.sample(range(0, len(inshop_query_set)), 5): 83 | image_id, class_id, instance_id, idx = inshop_query_set[i] 84 | assert idx == i 85 | print("Image {} has class label {}, instance label {}".format(inshop_query_set.image_paths[i], 86 | instance_id, 87 | class_id)) 88 | 89 | inshop_index_set = InShop('/data1/data/inshop', train=False, query=False) 90 | for i in random.sample(range(0, len(inshop_index_set)), 5): 91 | image_id, class_id, instance_id, idx = inshop_index_set[i] 92 | assert idx == i 93 | print("Image {} has class label {}, instance label {}".format(inshop_index_set.image_paths[i], 94 | instance_id, 95 | class_id)) 96 | -------------------------------------------------------------------------------- /data/inshop.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os.path 3 | import random 4 | 5 | from data.dataset import Dataset 6 | 7 | 8 | class InShop(Dataset): 9 | def __init__(self, root, train=True, query=True, transform=None, return_path=False): 10 | self.return_path = return_path 11 | self.query = query 12 | self.split_file = "list_eval_partition.txt" 13 | super(InShop, self).__init__(root, train, transform) 14 | print("Loaded {} samples for dataset {}, {} classes, {} instances".format(len(self), 15 | self.name, 16 | self.num_cls, 17 | self.num_instance)) 18 | 19 | @property 20 | def name(self): 21 | return 'inshop_{}_{}'.format('train' if self.train else 'test', 22 | 'query' if self.query else 'gallery') 23 | 24 | @property 25 | def image_root_dir(self): 26 | return self.root 27 | 28 | @property 29 | def num_cls(self): 30 | return len(self.class_map) 31 | 32 | @property 33 | def num_instance(self): 34 | return len(self.instance_map) 35 | 36 | def _load(self): 37 | self.class_map = {} 38 | with open(os.path.join(self.root, self.split_file), 'r') as f: 39 | for line in f.read().splitlines()[2:]: 40 | image_name, item_id, evaluation_status = line.strip().split() 41 | skip = True 42 | if self.train: 43 | if evaluation_status == "train": 44 | # Train data points 45 | self.image_paths.append(os.path.join(self.image_root_dir, image_name)) 46 | class_id = item_id 47 | if class_id not in self.class_map: 48 | self.class_map[class_id] = len(self.class_map) 49 | self.class_labels.append(self.class_map[class_id]) 50 | else: 51 | if evaluation_status != "train": 52 | # Test data points 53 | 54 | # Keep class ids consistent amongst query and gallery 
data points. The class id set is 55 | # the same for query and gallery. 56 | class_id = item_id 57 | if class_id not in self.class_map: 58 | self.class_map[class_id] = len(self.class_map) 59 | 60 | if evaluation_status == "query" and self.query: 61 | self.image_paths.append(os.path.join(self.image_root_dir, image_name)) 62 | self.class_labels.append(self.class_map[class_id]) 63 | elif evaluation_status == "gallery" and not self.query: 64 | self.image_paths.append(os.path.join(self.image_root_dir, image_name)) 65 | self.class_labels.append(self.class_map[class_id]) 66 | 67 | # same thing for in-shop 68 | self.instance_labels = self.class_labels 69 | self.instance_map = self.class_map 70 | 71 | 72 | if __name__ == '__main__': 73 | inshop_train_set = InShop('/data1/data/inshop') 74 | for i in random.sample(range(0, len(inshop_train_set)), 5): 75 | image_id, class_id, instance_id, idx = inshop_train_set[i] 76 | assert idx == i 77 | print("Image {} has class label {}, instance label {}".format(inshop_train_set.image_paths[i], 78 | instance_id, 79 | class_id)) 80 | 81 | inshop_query_set = InShop('/data1/data/inshop', train=False, query=True) 82 | for i in random.sample(range(0, len(inshop_query_set)), 5): 83 | image_id, class_id, instance_id, idx = inshop_query_set[i] 84 | assert idx == i 85 | print("Image {} has class label {}, instance label {}".format(inshop_query_set.image_paths[i], 86 | instance_id, 87 | class_id)) 88 | 89 | inshop_index_set = InShop('/data1/data/inshop', train=False, query=False) 90 | for i in random.sample(range(0, len(inshop_index_set)), 5): 91 | image_id, class_id, instance_id, idx = inshop_index_set[i] 92 | assert idx == i 93 | print("Image {} has class label {}, instance label {}".format(inshop_index_set.image_paths[i], 94 | instance_id, 95 | class_id)) 96 | -------------------------------------------------------------------------------- /scripts/bin/pdoc: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | # 3 | # This command launches a docker container and mounts this directory. 4 | # This tool simply makes the interface to docker a bit easier 5 | # 6 | # Arguments: 7 | # 8 | # -h 9 | # Display help. 10 | # -l 11 | # Run local docker image, without pulling from ECR. Useful for testing 12 | # local Docker builds before publishing, ie via docker/docker-build.sh 13 | # CMD ARG0 ARG1 ... ARGN 14 | # The command you'd like to run within the Docker context. 15 | # 16 | # Usage: 17 | # 18 | # pdoc [-h] [-l] [CMD ARG0 ARG1 ... ARGN] 19 | # 20 | # Examples: 21 | # 22 | # # Enter interactive console within Docker context as root 23 | # export PATH=$HOME/code/multitask_visual_embeddings/bin:$PATH 24 | # cd $HOME/code/multitask_visual_embeddings 25 | # pdoc bash 26 | # 27 | # # Run local Docker image (don't pull from ECR) 28 | # pdoc -l bash 29 | # 30 | # # Command with multiple arguments 31 | # pdoc ls / 32 | # 33 | # Note: only tested on linux 34 | # 35 | 36 | OS=$(uname) 37 | 38 | READLINK="readlink" 39 | if [ "$OS" = "Darwin" ]; then 40 | GREADLINK=$(which greadlink) 41 | if [ "${GREADLINK}" = "" ]; then 42 | brew install coreutils 43 | fi 44 | READLINK="greadlink" 45 | fi 46 | 47 | _SELF_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 48 | 49 | # BEGIN parse arguments 50 | # https://stackoverflow.com/questions/192249/how-do-i-parse-command-line-arguments-in-bash 51 | # A POSIX variable 52 | OPTIND=1 # Reset in case getopts has been used previously in the shell. 
53 | 54 | # Initialize our own variables: 55 | DEFAULT_DOCKERFILE="$(dirname $(${READLINK} -f $0))/Dockerfile" 56 | DOCKERFILE=$DEFAULT_DOCKERFILE 57 | RUN_LOCAL=0 58 | 59 | while getopts "h?ld:" opt; do 60 | case "$opt" in 61 | h|\?) 62 | echo "Usage: pdoc [-h] [-l] [CMD ARG0 ARG1 ... ARGN]" 63 | exit 0 64 | ;; 65 | l) RUN_LOCAL=1 66 | ;; 67 | esac 68 | done 69 | 70 | shift $((OPTIND-1)) 71 | 72 | [ "$1" = "--" ] && shift 73 | 74 | # If command is "pdoc ls /", then USER_CMD_AND_ARGS is: "ls /"" 75 | USER_CMD_AND_ARGS=$@ 76 | # END parse arguments 77 | 78 | PROD_VERSION=$(cat DOCKER_TAG) 79 | 80 | # Custom arguments 81 | PROD_IMAGE=pinterestdocker/visualembedding:$PROD_VERSION 82 | 83 | # https://stackoverflow.com/questions/7444504/explanation-of-colon-operator-in-foo-value 84 | # Essentially defaults for environment variables that can be overwritten 85 | : ${DOCKER_IMAGE=$PROD_IMAGE} 86 | : ${DOCKER_INTERACTIVE=-i} 87 | 88 | RED='\033[0;31m' 89 | YELLOW='\033[1;33m' 90 | GREEN='\033[0;32m' 91 | NC='\033[0m' # No Color 92 | if [ -x "$(command -v nvidia-docker)" ]; then 93 | echo -e "${GREEN}nvidia-docker is being used (GPU + CPU)${NC}\n" 94 | DOCKER_CMD="nvidia-docker run" 95 | elif [ -x "$(command -v docker)" ]; then 96 | echo -e "${YELLOW}docker is being used (CPU only)${NC}\n" 97 | DOCKER_CMD="docker run --runtime=nvidia" 98 | else 99 | echo -e "${RED}No docker/nvidia-docker installed!!!${NC}" 100 | exit 1 101 | fi 102 | 103 | # Mount commonly used directories 104 | MOUNT_ARGS="-v /tmp:/tmp" 105 | if [ -d /data1 ]; then 106 | MOUNT_ARGS="${MOUNT_ARGS} -v /data1:/data1" 107 | fi 108 | 109 | KNOX_BIN_DIR=/usr/bin/knox 110 | KNOX_LIB_DIR=/var/lib/knox 111 | 112 | if [ -e ${KNOX_BIN_DIR} ]; then 113 | MOUNT_ARGS="${MOUNT_ARGS} -v ${KNOX_BIN_DIR}:${KNOX_BIN_DIR}" 114 | fi 115 | if [ -e ${KNOX_LIB_DIR} ]; then 116 | MOUNT_ARGS="${MOUNT_ARGS} -v ${KNOX_LIB_DIR}:${KNOX_LIB_DIR}" 117 | fi 118 | 119 | # Get the directory of this script. 120 | DIR=$(dirname $(${READLINK} -f $0)) 121 | MVE_DIR=$(${READLINK} -f $DIR/../..) 122 | DOCKER_MVE_DIR=/multitask_visual_embeddings 123 | ABS_PWD=$(${READLINK} -f $PWD | sed "s#$MVE_DIR#$DOCKER_MVE_DIR#g") 124 | 125 | echo "" > $MVE_DIR/docker/.docker_log 126 | ENV_FILE=$MVE_DIR/docker/.docker_env 127 | 128 | cat > $ENV_FILE << EOL 129 | DOCKER_MVE_DIR=$DOCKER_MVE_DIR 130 | DOCKER_LOG=$DOCKER_MVE_DIR/docker/.docker_log 131 | LC_ALL=C 132 | AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY 133 | AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID 134 | AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN 135 | PYTHONPATH=$DOCKER_MVE_DIR 136 | 137 | https_proxy=$https_proxy 138 | http_proxy=$http_proxy 139 | no_proxy=$no_proxy 140 | EOL 141 | 142 | if [ $DOCKER_IMAGE = $PROD_IMAGE ] && [ $RUN_LOCAL -eq 0 ]; then 143 | echo "Pulling docker image version ${PROD_VERSION}" 144 | docker pull $DOCKER_IMAGE 145 | fi 146 | 147 | # run inside bash to allow ^C handling correctly 148 | # https://github.com/bodil/pulp/issues/87 149 | # https://github.com/docker/docker/pull/12228 150 | # https://github.com/docker/docker/issues/12022 151 | # 152 | # ipc=host needed to ensure inter process communication has enough 153 | # shared memory space. https://github.com/pytorch/pytorch/issues/1158 154 | # 155 | # Run the docker image. 
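# For example, `pdoc nvidia-smi` on a GPU host roughly expands to (illustrative):
#   nvidia-docker run --env-file=$MVE_DIR/docker/.docker_env --network=host --ipc=host \
#       -t -i -w $ABS_PWD -v $MVE_DIR:$DOCKER_MVE_DIR -v $HOME/docker_root:/root \
#       -v /tmp:/tmp pinterestdocker/visualembedding:$PROD_VERSION bash -c "nvidia-smi"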
156 | $DOCKER_CMD \ 157 | --env-file=$ENV_FILE \ 158 | --network=host \ 159 | --ipc=host \ 160 | -t ${DOCKER_INTERACTIVE} \ 161 | -w $ABS_PWD \ 162 | -v $MVE_DIR:$DOCKER_MVE_DIR \ 163 | -v $HOME/docker_root:/root \ 164 | $MOUNT_ARGS \ 165 | $DOCKER_IMAGE bash -c "$USER_CMD_AND_ARGS" 166 | -------------------------------------------------------------------------------- /data/cub200.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import csv 4 | 5 | from dataset import Dataset 6 | 7 | 8 | class Cub200(Dataset): 9 | def __init__(self, root, train=True, transform=None, benchmark=True): 10 | self.image_id_2_relfile = os.path.join(root, 'images.txt') 11 | self.image_id_2_cls_id_file = os.path.join(root, 'image_class_labels.txt') 12 | self.class_name_file = os.path.join(root, 'classes.txt') 13 | self.benchmark = benchmark 14 | super(Cub200, self).__init__(root, train, transform) 15 | print "Loaded {} samples for dataset {}, {} classes, {} instances".format(len(self), self.name, self.num_cls, self.num_instance) 16 | 17 | @property 18 | def name(self): 19 | return 'cub200_{}_{}'.format('benchmark' if self.benchmark else 'random', 'train' if self.train else 'test') 20 | 21 | @property 22 | def image_root_dir(self): 23 | return os.path.join(self.root, 'images') 24 | 25 | @property 26 | def num_cls(self): 27 | return len(self.class_map) 28 | 29 | @property 30 | def num_instance(self): 31 | return len(self.instance_map) 32 | 33 | def _load(self): 34 | 35 | meta_data = self._load_meta_data() 36 | self.instance_names = {} 37 | 38 | # load train/test split 39 | instance_id_to_load = self._load_split(meta_data, benchmark=self.benchmark) 40 | 41 | self.class_map = {} 42 | self.instance_map = {} 43 | 44 | for image_id, instance_id in meta_data['id2cls'].items(): 45 | 46 | if str(instance_id) not in instance_id_to_load: 47 | continue 48 | 49 | self.image_paths.append(os.path.join(self.image_root_dir, meta_data['id2file'][image_id])) 50 | # consolidate the ids into continuous labels from 0 to num_instance 51 | if instance_id not in self.instance_map: 52 | self.instance_map[instance_id] = len(self.instance_map) 53 | self.instance_labels.append(self.instance_map[instance_id]) 54 | self.instance_names[self.instance_map[instance_id]] = meta_data['class_names'][instance_id] 55 | 56 | # Set the class_id to instance id for now 57 | if instance_id not in self.class_map: 58 | self.class_map[instance_id] = len(self.class_map) 59 | self.class_labels.append(self.class_map[instance_id]) 60 | 61 | def _load_meta_data(self): 62 | # all the ids are 1-indexed. 
convert them into 0-indexed 63 | meta_data = {} 64 | meta_data['id2file'] = {} 65 | meta_data['id2cls'] = {} 66 | meta_data['class_names'] = {} 67 | with open(self.image_id_2_relfile) as rf: 68 | csvreader = csv.reader(rf, delimiter=' ') 69 | for row in csvreader: 70 | meta_data['id2file'][int(row[0])] = row[1] 71 | 72 | with open(self.image_id_2_cls_id_file) as rf: 73 | csvreader = csv.reader(rf, delimiter=' ') 74 | for row in csvreader: 75 | meta_data['id2cls'][int(row[0])] = int(row[1]) 76 | 77 | with open(self.class_name_file) as rf: 78 | csvreader = csv.reader(rf, delimiter=' ') 79 | for row in csvreader: 80 | meta_data['class_names'][int(row[0])] = row[1] 81 | 82 | self.class_names = meta_data['class_names'] 83 | return meta_data 84 | 85 | def _load_split(self, meta_data, benchmark=True): 86 | split_file = 'cub_{}_{}_cls_split.txt'.format('benchmark' if benchmark else 'random', 87 | 'train' if self.train else 'test') 88 | split = os.path.join(self.root, split_file) 89 | if not os.path.exists(split): 90 | # split the classes into 50:50 train:test split 91 | 92 | num_total_classes = len(meta_data['class_names']) 93 | shuffled_idxs = range(num_total_classes) 94 | 95 | if not benchmark: 96 | import random 97 | random.seed(2018) 98 | random.shuffle(shuffled_idxs) 99 | 100 | with open(os.path.join(self.root, 101 | 'cub_{}_train_cls_split.txt'.format('benchmark' if benchmark else 'random')) 102 | , 'wb') as wf: 103 | for i in shuffled_idxs[:num_total_classes//2]: 104 | # make the class id 1-indexed to be consistent with the dataset 105 | wf.write(str(i+1) + '\n') 106 | 107 | with open(os.path.join(self.root, 108 | 'cub_{}_test_cls_split.txt'.format('benchmark' if benchmark else 'random')) 109 | , 'wb') as wf: 110 | for i in shuffled_idxs[num_total_classes//2:]: 111 | # make the class id 1-indexed to be consistent with the dataset 112 | wf.write(str(i+1) + '\n') 113 | 114 | with open(split) as f: 115 | lines = f.readlines() 116 | 117 | return set([x.strip() for x in lines]) 118 | 119 | 120 | if __name__ == '__main__': 121 | 122 | train_set = Cub200('/data1/data/cub200/CUB_200_2011') 123 | for i in random.sample(range(0,len(train_set)), 5): 124 | image_id, class_id, instance_id = train_set[i] 125 | print "Image {} has label {}, name {}".format(train_set.image_paths[i], instance_id, train_set.instance_names[instance_id]) 126 | 127 | test_set = Cub200('/data1/data/cub200/CUB_200_2011', train=False) 128 | for i in random.sample(range(0,len(test_set)), 5): 129 | image_id, class_id, instance_id = test_set[i] 130 | print "Image {} has label {}, name {}".format(test_set.image_paths[i], instance_id, test_set.instance_names[instance_id]) 131 | 132 | -------------------------------------------------------------------------------- /data/cars196.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | from scipy.io import loadmat 4 | 5 | from dataset import Dataset 6 | 7 | 8 | class Cars196(Dataset): 9 | def __init__(self, root, train=True, transform=None, benchmark=True): 10 | self.meta_file = 'cars_annos.mat' 11 | self.benchmark = benchmark 12 | super(Cars196, self).__init__(root, train, transform) 13 | print "Loaded {} samples for dataset {}, {} classes, {} instances".format(len(self), self.name, self.num_cls, self.num_instance) 14 | 15 | @property 16 | def name(self): 17 | return 'cars196_{}_{}'.format('benchmark' if self.benchmark else 'random', 'train' if self.train else 'test') 18 | 19 | @property 20 | def image_root_dir(self): 
21 | return self.root 22 | 23 | @property 24 | def num_cls(self): 25 | return len(self.class_map) 26 | 27 | @property 28 | def num_instance(self): 29 | return len(self.instance_map) 30 | 31 | def _load(self): 32 | 33 | meta_data = loadmat(os.path.join(self.root, self.meta_file), squeeze_me=True) 34 | self.instance2class = [] 35 | self.instance_names = {} 36 | self.class_names = {} 37 | 38 | # load train/test split 39 | instance_id_to_load = self._load_split(meta_data) 40 | 41 | self.class_map = {} 42 | self.instance_map = {} 43 | 44 | annotations = meta_data['annotations'] 45 | for entry in annotations: 46 | # convert Matlab 1-indexing to python 0-indexing 47 | instance_id = int(entry['class']) - 1 48 | if str(instance_id) not in instance_id_to_load: 49 | continue 50 | 51 | class_name = meta_data['class_names'][instance_id] 52 | # The annotations are typically with the format of (Make, Model, Type, Year) 53 | make = class_name.split(' ')[0] 54 | type = class_name.split(' ')[-2] 55 | vehicle_type = ' '.join([make, type]) 56 | if vehicle_type not in self.class_map: 57 | self.class_map[vehicle_type] = len(self.class_map) 58 | 59 | self.class_labels.append(self.class_map[vehicle_type]) 60 | self.class_names[self.class_map[vehicle_type]] = vehicle_type 61 | 62 | self.image_paths.append(os.path.join(self.image_root_dir, entry['relative_im_path'])) 63 | # consolidate the ids into continuous labels 64 | if instance_id not in self.instance_map: 65 | self.instance2class.append(self.class_map[vehicle_type]) 66 | self.instance_map[instance_id] = len(self.instance_map) 67 | 68 | self.instance_labels.append(self.instance_map[instance_id]) 69 | self.instance_names[self.instance_map[instance_id]] = class_name 70 | 71 | def _load_split(self, meta_data, benchmark=True): 72 | split_file = 'cars_{}_{}_cls_split.txt'.format('benchmark' if benchmark else 'random', 73 | 'train' if self.train else 'test') 74 | split = os.path.join(self.root, split_file) 75 | if not os.path.exists(split): 76 | # split the classes into 50:50 train:test split 77 | 78 | num_total_classes = meta_data['class_names'].size 79 | shuffled_idxs = range(num_total_classes) 80 | 81 | if not benchmark: 82 | import random 83 | random.seed(2018) 84 | random.shuffle(shuffled_idxs) 85 | 86 | with open(os.path.join(self.root, 87 | 'cars_{}_train_cls_split.txt'.format('benchmark' if benchmark else 'random')) 88 | , 'wb') as wf: 89 | for i in shuffled_idxs[:num_total_classes//2]: 90 | wf.write(str(i) + '\n') 91 | 92 | with open(os.path.join(self.root, 93 | 'cars_{}_test_cls_split.txt'.format('benchmark' if benchmark else 'random')) 94 | , 'wb') as wf: 95 | for i in shuffled_idxs[num_total_classes//2:]: 96 | wf.write(str(i) + '\n') 97 | 98 | with open(split) as f: 99 | lines = f.readlines() 100 | 101 | return set([x.strip() for x in lines]) 102 | 103 | 104 | if __name__ == '__main__': 105 | 106 | cars196_train_set = Cars196('/data1/data/cars196') 107 | print "Loaded {} samples for dataset {}".format(len(cars196_train_set), cars196_train_set.name) 108 | for i in random.sample(range(0,len(cars196_train_set)), 5): 109 | image_id, class_id, instance_id = cars196_train_set[i] 110 | print "Image {} has label {}, name {}, label {}, name {}".format(cars196_train_set.image_paths[i], instance_id, 111 | cars196_train_set.instance_names[instance_id], 112 | class_id, 113 | cars196_train_set.class_names[class_id]) 114 | 115 | cars196_test_set = Cars196('/data1/data/cars196', train=False) 116 | print "Loaded {} samples for dataset {}".format(len(cars196_test_set), 
cars196_test_set.name) 117 | for i in [2022, 1668, 2041, 1710, 2233, 2160, 3970, 3800]: 118 | # for i in random.sample(range(0,len(cars196_train_set)), 5): 119 | image_id, class_id, instance_id = cars196_test_set[i] 120 | 121 | print "Image {} has label {}, name {}, label {}, name {}".format(cars196_test_set.image_paths[i], instance_id, 122 | cars196_test_set.instance_names[instance_id], 123 | class_id, cars196_test_set.class_names[class_id]) 124 | 125 | for i, c in enumerate(cars196_test_set.instance2class): 126 | print "Instace name {} has class name {}".format(cars196_test_set.instance_names[i], 127 | cars196_test_set.class_names[c]) 128 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import sys 4 | from deep_utils import dump_pickle, load_pickle 5 | import time 6 | from itertools import chain 7 | from argparse import ArgumentParser 8 | import torch 9 | from pretrainedmodels.utils import ToRange255 10 | from pretrainedmodels.utils import ToSpaceBGR 11 | from scipy.spatial.distance import cdist 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data.dataloader import default_collate 14 | from torchvision import transforms 15 | from data.inshop import InShop 16 | from metric_learning.util import SimpleLogger 17 | from metric_learning.sampler import ClassBalancedBatchSampler 18 | from PIL import Image 19 | import metric_learning.modules.featurizer as featurizer 20 | import metric_learning.modules.losses as losses 21 | import numpy as np 22 | from evaluation.retrieval import evaluate_float_binary_embedding_faiss, _retrieve_knn_faiss_gpu_inner_product 23 | 24 | dataset = "InShop" 25 | dataset_root = "" 26 | batch_size = 64 27 | model_name = "resnet50" 28 | lr = 0.01 29 | gamma = 0.1 30 | class_balancing = True 31 | images_per_class = 5 32 | lr_mult = 1 33 | dim = 2048 34 | 35 | test_every_n_epochs = 2 36 | epochs_per_step = 4 37 | pretrain_epochs = 1 38 | num_steps = 3 39 | output = "data1/output" 40 | 41 | 42 | def adjust_learning_rate(optimizer, epoch, epochs_per_step, gamma=0.1): 43 | """Sets the learning rate to the initial LR decayed by 10 every epochs""" 44 | # Skip gamma update on first epoch. 
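    # Assuming this helper is called once at the start of every epoch beginning at
    # epoch 0 (as main() below does), the loop that follows implements a plain step
    # decay of the learning rate:
    #     lr(epoch) = initial_lr * gamma ** (epoch // epochs_per_step)
    # A roughly equivalent alternative, if preferred, is PyTorch's built-in scheduler,
    # stepped once per epoch instead of calling this function:
    #     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=epochs_per_step, gamma=gamma)
    #     ...
    #     scheduler.step()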
45 | if epoch != 0 and epoch % epochs_per_step == 0: 46 | for param_group in optimizer.param_groups: 47 | param_group['lr'] *= gamma 48 | print("learning rate adjusted: {}".format(param_group['lr'])) 49 | 50 | 51 | def main(): 52 | torch.cuda.set_device(0) 53 | gpu_device = torch.device('cuda') 54 | 55 | output_directory = os.path.join(output, dataset, str(dim), 56 | '_'.join([model_name, str(batch_size)])) 57 | if not os.path.exists(output_directory): 58 | os.makedirs(output_directory) 59 | out_log = os.path.join(output_directory, "train.log") 60 | sys.stdout = SimpleLogger(out_log, sys.stdout) 61 | 62 | # Select model 63 | model_factory = getattr(featurizer, model_name) 64 | model = model_factory(dim) 65 | weights = torch.load( 66 | '/home/ai/projects/symo/classification_metric_learning/data1/output/InShop/2048/resnet50_75/epoch_30.pth') 67 | model.load_state_dict(weights) 68 | eval_transform = transforms.Compose([ 69 | transforms.Resize((256, 256)), 70 | transforms.CenterCrop(max(model.input_size)), 71 | transforms.ToTensor(), 72 | ToSpaceBGR(model.input_space == 'BGR'), 73 | ToRange255(max(model.input_range) == 255), 74 | transforms.Normalize(mean=model.mean, std=model.std) 75 | ]) 76 | 77 | # Setup dataset 78 | 79 | # train_dataset = InShop('../data1/data/inshop', transform=train_transform) 80 | query_dataset = InShop('data1/data/inshop', train=False, query=True, transform=eval_transform) 81 | index_dataset = InShop('data1/data/inshop', train=False, query=False, transform=eval_transform) 82 | 83 | query_loader = DataLoader(query_dataset, 84 | batch_size=batch_size, 85 | drop_last=False, 86 | shuffle=False, 87 | pin_memory=True, 88 | num_workers=0) 89 | 90 | model.to(device='cuda') 91 | model.eval() 92 | query_image = Image.open( 93 | "/home/ai/Pictures/im3.png").convert( 94 | 'RGB') 95 | with torch.no_grad(): 96 | query_image = model(eval_transform(query_image).to('cuda').unsqueeze(0))[0].cpu().numpy() 97 | 98 | index_dataset = InShop('data1/data/inshop', train=False, query=False, transform=eval_transform) 99 | index_loader = DataLoader(index_dataset, 100 | batch_size=75, 101 | drop_last=False, 102 | shuffle=False, 103 | pin_memory=True, 104 | num_workers=0) 105 | # db_list = extract_feature(model, index_loader, 'cuda') 106 | db_list = load_pickle('db.pkl') 107 | # db_dirs = [ 108 | # "/home/ai/projects/symo/classification_metric_learning/data1/data/inshop/img/WOMEN/Blouses_Shirts/id_00000001", 109 | # "/home/ai/projects/symo/classification_metric_learning/data1/data/inshop/img/WOMEN/Blouses_Shirts/id_00000004", 110 | # "/home/ai/projects/symo/classification_metric_learning/data1/data/inshop/img/WOMEN/Blouses_Shirts/id_00000038", 111 | # "/home/ai/projects/symo/classification_metric_learning/data1/data/inshop/img/WOMEN/Blouses_Shirts/id_00000067", 112 | # ] 113 | # db_list = {} 114 | # with torch.no_grad(): 115 | # 116 | # for dir_ in db_dirs: 117 | # for n in os.listdir(dir_): 118 | # img_path = os.path.join(dir_, n) 119 | # img = Image.open(img_path) 120 | # db_list[img_path] = model(eval_transform(img).unsqueeze(0)).cpu().numpy()[0] 121 | v = get_most_similar(query_image, db_list) 122 | print(v) 123 | 124 | 125 | def get_most_similar(feature, features_dict, n=10, distance='cosine'): 126 | features = list(features_dict.values()) 127 | ids = list(features_dict.keys()) 128 | p = cdist(np.array(features), 129 | np.expand_dims(feature, axis=0), 130 | metric=distance)[:, 0] 131 | group = zip(p, ids.copy()) 132 | res = sorted(group, key=lambda x: x[0]) 133 | r = res[:n] 134 | return r 135 
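# A minimal sketch of how get_most_similar above can be exercised; the image paths
# and 2-d embeddings are hypothetical toy values, not real InShop features (those are
# 2048-d). Smaller cosine distances mean closer matches, so the returned list is
# ordered best-match first. Relies on `np` imported at the top of this file.
def _demo_get_most_similar():
    # Two fake database entries and one query, all unit length, so cosine distance
    # equals 1 minus the inner product.
    db = {
        'img/a.jpg': np.array([1.0, 0.0], dtype=np.float32),
        'img/b.jpg': np.array([0.6, 0.8], dtype=np.float32),
    }
    query = np.array([0.8, 0.6], dtype=np.float32)
    # Expected result: [(~0.04, 'img/b.jpg'), (0.2, 'img/a.jpg')]
    return get_most_similar(query, db, n=2)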
| 136 | 137 | def extract_feature(model, loader, gpu_device): 138 | """ 139 | Extract embeddings from given `model` for given `loader` dataset on `gpu_device`. 140 | """ 141 | model.eval() 142 | model.to(gpu_device) 143 | db_dict = {} 144 | log_every_n_step = 10 145 | 146 | with torch.no_grad(): 147 | for i, (im, class_label, instance_label, index) in enumerate(loader): 148 | im = im.to(device=gpu_device) 149 | embedding = model(im) 150 | for i, em in zip(index, embedding): 151 | db_dict[loader.dataset.image_paths[int(i)]] = em.detach().cpu().numpy() 152 | if (i + 1) % log_every_n_step == 0: 153 | print('Process Iteration {} / {}:'.format(i, len(loader))) 154 | dump_pickle('db.pkl', db_dict) 155 | return db_dict 156 | 157 | 158 | if __name__ == '__main__': 159 | main() 160 | -------------------------------------------------------------------------------- /evaluation/retrieval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def _retrieve_knn_faiss_gpu_inner_product(query_embeddings, db_embeddings, k, gpu_id=0): 4 | """ 5 | Retrieve k nearest neighbor based on inner product 6 | 7 | Args: 8 | query_embeddings: numpy array of size [NUM_QUERY_IMAGES x EMBED_SIZE] 9 | db_embeddings: numpy array of size [NUM_DB_IMAGES x EMBED_SIZE] 10 | k: number of nn results to retrieve excluding query 11 | gpu_id: gpu device id to use for nearest neighbor (if possible for `metric` chosen) 12 | 13 | Returns: 14 | dists: numpy array of size [NUM_QUERY_IMAGES x k], distances of k nearest neighbors 15 | for each query 16 | retrieved_db_indices: numpy array of size [NUM_QUERY_IMAGES x k], indices of k nearest neighbors 17 | for each query 18 | """ 19 | import faiss 20 | 21 | res = faiss.StandardGpuResources() 22 | flat_config = faiss.GpuIndexFlatConfig() 23 | flat_config.device = gpu_id 24 | 25 | # Evaluate with inner product 26 | index = faiss.GpuIndexFlatIP(res, db_embeddings.shape[1], flat_config) 27 | index.add(db_embeddings) 28 | # retrieved k+1 results in case that query images are also in the db 29 | dists, retrieved_result_indices = index.search(query_embeddings, k + 1) 30 | 31 | return dists, retrieved_result_indices 32 | 33 | 34 | def _retrieve_knn_faiss_gpu_euclidean(query_embeddings, db_embeddings, k, gpu_id=0): 35 | """ 36 | Retrieve k nearest neighbor based on inner product 37 | 38 | Args: 39 | query_embeddings: numpy array of size [NUM_QUERY_IMAGES x EMBED_SIZE] 40 | db_embeddings: numpy array of size [NUM_DB_IMAGES x EMBED_SIZE] 41 | k: number of nn results to retrieve excluding query 42 | gpu_id: gpu device id to use for nearest neighbor (if possible for `metric` chosen) 43 | 44 | Returns: 45 | dists: numpy array of size [NUM_QUERY_IMAGES x k], distances of k nearest neighbors 46 | for each query 47 | retrieved_db_indices: numpy array of size [NUM_QUERY_IMAGES x k], indices of k nearest neighbors 48 | for each query 49 | """ 50 | import faiss 51 | 52 | res = faiss.StandardGpuResources() 53 | flat_config = faiss.GpuIndexFlatConfig() 54 | flat_config.device = gpu_id 55 | 56 | # Evaluate with inner product 57 | index = faiss.GpuIndexFlatL2(res, db_embeddings.shape[1], flat_config) 58 | index.add(db_embeddings) 59 | # retrieved k+1 results in case that query images are also in the db 60 | dists, retrieved_result_indices = index.search(query_embeddings, k + 1) 61 | 62 | return dists, retrieved_result_indices 63 | 64 | 65 | def evaluate_recall_at_k(dists, results, query_labels, db_labels, k): 66 | """ 67 | Evaluate Recall@k based on 
retrieval results 68 | 69 | Args: 70 | dists: numpy array of size [NUM_QUERY_IMAGES x k], distances of k nearest neighbors for each query 71 | results: numpy array of size [NUM_QUERY_IMAGES x k], indices of k nearest neighbors for each query 72 | query_labels: list of labels for each query 73 | db_labels: list of labels for each db 74 | k: number of nn results to evaluate 75 | 76 | Returns: 77 | recall_at_k: Recall@k in percentage 78 | """ 79 | 80 | self_retrieval = False 81 | 82 | if query_labels is db_labels: 83 | self_retrieval = True 84 | 85 | expected_result_size = k + 1 if self_retrieval else k 86 | 87 | assert results.shape[1] >= expected_result_size, \ 88 | "Not enough retrieved results to evaluate Recall@{}".format(k) 89 | 90 | recall_at_k = np.zeros((k,)) 91 | 92 | for i in range(len(query_labels)): 93 | pos = 0 # keep track recall at pos 94 | j = 0 # looping through results 95 | while pos < k: 96 | if self_retrieval and i == results[i, j]: 97 | # Only skip the document when query and index sets are the exact same 98 | j += 1 99 | continue 100 | if query_labels[i] == db_labels[results[i, j]]: 101 | recall_at_k[pos:] += 1 102 | break 103 | j += 1 104 | pos += 1 105 | 106 | return recall_at_k/float(len(query_labels))*100.0 107 | 108 | 109 | def evaluate_float_binary_embedding_faiss(query_embeddings, db_embeddings, query_labels, db_labels, 110 | output, k=1000, gpu_id=0): 111 | """ 112 | Wrapper function to evaluate Recall@k for float and binary embeddings 113 | output recall@k strings for Cars, CUBS, Stanford Online Product, and InShop datasets 114 | """ 115 | 116 | # ======================== float embedding evaluation ========================================================== 117 | # knn retrieval from embeddings (l2 normalized embedding + inner product = cosine similarity) 118 | dists, retrieved_result_indices = _retrieve_knn_faiss_gpu_inner_product(query_embeddings, 119 | db_embeddings, 120 | k, 121 | gpu_id=gpu_id) 122 | 123 | # evaluate recall@k 124 | r_at_k_f = evaluate_recall_at_k(dists, retrieved_result_indices, query_labels, db_labels, k) 125 | 126 | output_file = output + '_identity.eval' 127 | cars_cub_eval_str = "R@1, R@2, R@4, R@8: {:.2f} & {:.2f} & {:.2f} & {:.2f} \\\\".format( 128 | r_at_k_f[0], r_at_k_f[1], r_at_k_f[3], r_at_k_f[7]) 129 | sop_eval_str = "R@1, R@10, R@100, R@1000: {:.2f} & {:.2f} & {:.2f} & {:.2f} \\\\".format( 130 | r_at_k_f[0], r_at_k_f[9], r_at_k_f[99], r_at_k_f[999]) 131 | in_shop_eval_str = "R@1, R@10, R@20, R@30, R@40, R@50: {:.2f} & {:.2f} & {:.2f} & {:.2f} & {:.2f} & {:.2f} \\\\".format( 132 | r_at_k_f[0], r_at_k_f[9], r_at_k_f[19], r_at_k_f[29], r_at_k_f[39], r_at_k_f[49]) 133 | 134 | print(cars_cub_eval_str) 135 | print(sop_eval_str) 136 | print(in_shop_eval_str) 137 | with open(output_file, 'w') as of: 138 | of.write(cars_cub_eval_str + '\n') 139 | of.write(sop_eval_str + '\n') 140 | of.write(in_shop_eval_str + '\n') 141 | 142 | # ======================== binary embedding evaluation ========================================================= 143 | binary_query_embeddings = np.require(query_embeddings > 0, dtype='float32') 144 | binary_db_embeddings = np.require(db_embeddings > 0, dtype='float32') 145 | 146 | # knn retrieval from embeddings (binary embeddings + euclidean = hamming distance) 147 | dists, retrieved_result_indices = _retrieve_knn_faiss_gpu_euclidean(binary_query_embeddings, 148 | binary_db_embeddings, 149 | k, 150 | gpu_id=gpu_id) 151 | # evaluate recall@k 152 | r_at_k_b = evaluate_recall_at_k(dists, retrieved_result_indices, 
query_labels, db_labels, k) 153 | 154 | output_file = output + '_binary.eval' 155 | 156 | cars_cub_eval_str = "R@1, R@2, R@4, R@8: {:.2f} & {:.2f} & {:.2f} & {:.2f} \\\\".format( 157 | r_at_k_b[0], r_at_k_b[1], r_at_k_b[3], r_at_k_b[7]) 158 | sop_eval_str = "R@1, R@10, R@100, R@1000: {:.2f} & {:.2f} & {:.2f} & {:.2f} \\\\".format( 159 | r_at_k_b[0], r_at_k_b[9], r_at_k_b[99], r_at_k_b[999]) 160 | in_shop_eval_str = "R@1, R@10, R@20, R@30, R@40, R@50: {:.2f} & {:.2f} & {:.2f} & {:.2f} & {:.2f} & {:.2f} \\\\".format( 161 | r_at_k_b[0], r_at_k_b[9], r_at_k_b[19], r_at_k_b[29], r_at_k_b[39], r_at_k_b[49]) 162 | 163 | print(cars_cub_eval_str) 164 | print(sop_eval_str) 165 | print(in_shop_eval_str) 166 | with open(output_file, 'w') as of: 167 | of.write(cars_cub_eval_str + '\n') 168 | of.write(sop_eval_str + '\n') 169 | of.write(in_shop_eval_str + '\n') 170 | 171 | return max(r_at_k_f[0], r_at_k_b[0]) 172 | 173 | 174 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "1964f058", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import glob\n", 11 | "import os\n", 12 | "import sys\n", 13 | "from deep_utils import dump_pickle, load_pickle\n", 14 | "import time\n", 15 | "from itertools import chain\n", 16 | "from argparse import ArgumentParser\n", 17 | "import torch\n", 18 | "from pretrainedmodels.utils import ToRange255\n", 19 | "from pretrainedmodels.utils import ToSpaceBGR\n", 20 | "from scipy.spatial.distance import cdist\n", 21 | "from torch.utils.data import DataLoader\n", 22 | "from torch.utils.data.dataloader import default_collate\n", 23 | "from torchvision import transforms\n", 24 | "from data.inshop import InShop\n", 25 | "from metric_learning.util import SimpleLogger\n", 26 | "from metric_learning.sampler import ClassBalancedBatchSampler\n", 27 | "from PIL import Image\n", 28 | "import metric_learning.modules.featurizer as featurizer\n", 29 | "import metric_learning.modules.losses as losses\n", 30 | "import numpy as np\n", 31 | "from evaluation.retrieval import evaluate_float_binary_embedding_faiss, _retrieve_knn_faiss_gpu_inner_product\n", 32 | "from PIL import Image\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "\n", 35 | "\n", 36 | "def adjust_learning_rate(optimizer, epoch, epochs_per_step, gamma=0.1):\n", 37 | " \"\"\"Sets the learning rate to the initial LR decayed by 10 every epochs\"\"\"\n", 38 | " # Skip gamma update on first epoch.\n", 39 | " if epoch != 0 and epoch % epochs_per_step == 0:\n", 40 | " for param_group in optimizer.param_groups:\n", 41 | " param_group['lr'] *= gamma\n", 42 | " print(\"learning rate adjusted: {}\".format(param_group['lr']))\n", 43 | " " 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "e733af10", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "dataset = \"InShop\"\n", 54 | "dataset_root = \"\"\n", 55 | "batch_size = 64\n", 56 | "model_name = \"resnet50\"\n", 57 | "lr = 0.01\n", 58 | "gamma = 0.1\n", 59 | "class_balancing = True\n", 60 | "images_per_class = 5\n", 61 | "lr_mult = 1\n", 62 | "dim = 2048\n", 63 | "\n", 64 | "test_every_n_epochs = 2\n", 65 | "epochs_per_step = 4\n", 66 | "pretrain_epochs = 1\n", 67 | "num_steps = 3\n", 68 | "output = \"data1/output\"\n", 69 | "create_pkl = False\n", 70 | "model_path = 
'/home/ai/projects/symo/classification_metric_learning/data1/output/InShop/2048/resnet50_75/epoch_30.pth'" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "fc0d1e8b", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "\n", 81 | "\n", 82 | "\n", 83 | "def get_most_similar(feature, features_dict, n=10, distance='cosine'):\n", 84 | " features = list(features_dict.values())\n", 85 | " ids = list(features_dict.keys())\n", 86 | " p = cdist(np.array(features),\n", 87 | " np.expand_dims(feature, axis=0),\n", 88 | " metric=distance)[:, 0]\n", 89 | " group = zip(p, ids.copy())\n", 90 | " res = sorted(group, key=lambda x: x[0])\n", 91 | " r = res[:n]\n", 92 | " return r\n", 93 | "\n", 94 | "\n", 95 | "def extract_feature(model, loader, gpu_device):\n", 96 | " \"\"\"\n", 97 | " Extract embeddings from given `model` for given `loader` dataset on `gpu_device`.\n", 98 | " \"\"\"\n", 99 | " model.eval()\n", 100 | " model.to(gpu_device)\n", 101 | " db_dict = {}\n", 102 | " log_every_n_step = 10\n", 103 | "\n", 104 | " with torch.no_grad():\n", 105 | " for i, (im, class_label, instance_label, index) in enumerate(loader):\n", 106 | " im = im.to(device=gpu_device)\n", 107 | " embedding = model(im)\n", 108 | " for i,em in zip(index, embedding):\n", 109 | " db_dict[loader.dataset.image_paths[int(i)]] = em.detach().cpu().numpy()\n", 110 | " if (i + 1) % log_every_n_step == 0:\n", 111 | " print('Process Iteration {} / {}:'.format(i, len(loader)))\n", 112 | " dump_pickle('db.pkl', db_dict)\n", 113 | " return db_dict" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "853eccfc", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "def main(query_img):\n", 124 | " torch.cuda.set_device(0)\n", 125 | " gpu_device = torch.device('cuda')\n", 126 | "\n", 127 | " output_directory = os.path.join(output, dataset, str(dim),\n", 128 | " '_'.join([model_name, str(batch_size)]))\n", 129 | " if not os.path.exists(output_directory):\n", 130 | " os.makedirs(output_directory)\n", 131 | " out_log = os.path.join(output_directory, \"train.log\")\n", 132 | " sys.stdout = SimpleLogger(out_log, sys.stdout)\n", 133 | "\n", 134 | " # Select model\n", 135 | " model_factory = getattr(featurizer, model_name)\n", 136 | " model = model_factory(dim)\n", 137 | " weights = torch.load(model_path)\n", 138 | " model.load_state_dict(weights)\n", 139 | " eval_transform = transforms.Compose([\n", 140 | " transforms.Resize((256, 256)),\n", 141 | " transforms.CenterCrop(max(model.input_size)),\n", 142 | " transforms.ToTensor(),\n", 143 | " ToSpaceBGR(model.input_space == 'BGR'),\n", 144 | " ToRange255(max(model.input_range) == 255),\n", 145 | " transforms.Normalize(mean=model.mean, std=model.std)\n", 146 | " ])\n", 147 | "\n", 148 | " # Setup dataset\n", 149 | "\n", 150 | " # train_dataset = InShop('../data1/data/inshop', transform=train_transform)\n", 151 | " query_dataset = InShop('data1/data/inshop', train=False, query=True, transform=eval_transform)\n", 152 | " index_dataset = InShop('data1/data/inshop', train=False, query=False, transform=eval_transform)\n", 153 | "\n", 154 | " query_loader = DataLoader(query_dataset,\n", 155 | " batch_size=batch_size,\n", 156 | " drop_last=False,\n", 157 | " shuffle=False,\n", 158 | " pin_memory=True,\n", 159 | " num_workers=0)\n", 160 | "\n", 161 | " model.to(device='cuda')\n", 162 | " model.eval()\n", 163 | " query_image = Image.open(query_img).convert('RGB')\n", 164 | " with torch.no_grad():\n", 165 | 
" query_image = model(eval_transform(query_image).to('cuda').unsqueeze(0))[0].cpu().numpy()\n", 166 | "\n", 167 | " index_dataset = InShop('data1/data/inshop', train=False, query=False, transform=eval_transform)\n", 168 | " index_loader = DataLoader(index_dataset,\n", 169 | " batch_size=75,\n", 170 | " drop_last=False,\n", 171 | " shuffle=False,\n", 172 | " pin_memory=True,\n", 173 | " num_workers=0)\n", 174 | " \n", 175 | " if create_pkl:\n", 176 | " db_list = extract_feature(model, index_loader, 'cuda')\n", 177 | " else:\n", 178 | " db_list = load_pickle('db.pkl')\n", 179 | " return get_most_similar(query_image, db_list)\n" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "d5228a15", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "def visualize(query_img, images):\n", 190 | " img = Image.open(query_img)\n", 191 | " plt.imshow(img)\n", 192 | " plt.title('main_image')\n", 193 | " plt.show()\n", 194 | " for score, img_path in images:\n", 195 | " img = Image.open(img_path)\n", 196 | " plt.imshow(img)\n", 197 | " plt.title(str(score))\n", 198 | " plt.show()" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "71e62a55", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "query_img = \"/home/ai/Pictures/im3.png\"\n", 209 | "images = main(query_img)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "id": "b69a7102", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "visualize(query_img, images)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "0246369a", 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [] 229 | } 230 | ], 231 | "metadata": { 232 | "kernelspec": { 233 | "display_name": "Python 3", 234 | "language": "python", 235 | "name": "python3" 236 | }, 237 | "language_info": { 238 | "codemirror_mode": { 239 | "name": "ipython", 240 | "version": 3 241 | }, 242 | "file_extension": ".py", 243 | "mimetype": "text/x-python", 244 | "name": "python", 245 | "nbconvert_exporter": "python", 246 | "pygments_lexer": "ipython3", 247 | "version": "3.8.10" 248 | } 249 | }, 250 | "nbformat": 4, 251 | "nbformat_minor": 5 252 | } 253 | -------------------------------------------------------------------------------- /metric_learning/train_classification.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import time 5 | from itertools import chain 6 | 7 | from argparse import ArgumentParser 8 | 9 | import torch 10 | from pretrainedmodels.utils import ToRange255 11 | from pretrainedmodels.utils import ToSpaceBGR 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data.dataloader import default_collate 14 | from torchvision import transforms 15 | 16 | from data.inshop import InShop 17 | from data.symo import Symo 18 | from metric_learning.util import SimpleLogger 19 | from metric_learning.sampler import ClassBalancedBatchSampler 20 | 21 | import metric_learning.modules.featurizer as featurizer 22 | import metric_learning.modules.losses as losses 23 | 24 | from extract_features import extract_feature 25 | from evaluation.retrieval import evaluate_float_binary_embedding_faiss 26 | 27 | 28 | def parse_args(): 29 | """ 30 | Helper function parsing the command line options 31 | """ 32 | parser = ArgumentParser(description="PyTorch metric learning training script") 33 | # Optional arguments for the 
launch helper 34 | parser.add_argument("--dataset", type=str, default="StanfordOnlineProducts", 35 | help="The dataset for training") 36 | parser.add_argument("--dataset_root", type=str, default="", 37 | help="The root directory to the dataset") 38 | parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training") 39 | parser.add_argument("--model_name", type=str, default="resnet50", help="The model name") 40 | parser.add_argument("--lr", type=float, default=0.01, help="The base lr") 41 | parser.add_argument("--gamma", type=float, default=0.1, help="Gamma applied to learning rate") 42 | parser.add_argument("--class_balancing", default=False, action='store_true', help="Use class balancing") 43 | parser.add_argument("--images_per_class", type=int, default=5, help="Images per class") 44 | parser.add_argument("--lr_mult", type=float, default=1, help="lr_mult for new params") 45 | parser.add_argument("--dim", type=int, default=2048, help="The dimension of the embedding") 46 | 47 | parser.add_argument("--test_every_n_epochs", type=int, default=2, help="Tests every N epochs") 48 | parser.add_argument("--epochs_per_step", type=int, default=4, help="Epochs for learning rate step") 49 | parser.add_argument("--pretrain_epochs", type=int, default=1, help="Epochs for pretraining") 50 | parser.add_argument("--num_steps", type=int, default=3, help="Num steps to take") 51 | parser.add_argument("--output", type=str, default="../data1/output", help="The output folder for training") 52 | 53 | return parser.parse_args() 54 | 55 | 56 | def adjust_learning_rate(optimizer, epoch, epochs_per_step, gamma=0.1): 57 | """Sets the learning rate to the initial LR decayed by 10 every epochs""" 58 | # Skip gamma update on first epoch. 59 | if epoch != 0 and epoch % epochs_per_step == 0: 60 | for param_group in optimizer.param_groups: 61 | param_group['lr'] *= gamma 62 | print("learning rate adjusted: {}".format(param_group['lr'])) 63 | 64 | 65 | def main(): 66 | args = parse_args() 67 | torch.cuda.set_device(0) 68 | gpu_device = torch.device('cuda') 69 | 70 | output_directory = os.path.join(args.output, args.dataset, str(args.dim), 71 | '_'.join([args.model_name, str(args.batch_size)])) 72 | os.makedirs(output_directory, exist_ok=True) 73 | out_log = os.path.join(output_directory, "train.log") 74 | sys.stdout = SimpleLogger(out_log, sys.stdout) 75 | 76 | # Select model 77 | model_factory = getattr(featurizer, args.model_name) 78 | model = model_factory(args.dim) 79 | 80 | # Setup train and eval transformations 81 | train_transform = transforms.Compose([ 82 | transforms.Resize((256, 256)), 83 | transforms.RandomCrop(max(model.input_size)), 84 | transforms.RandomHorizontalFlip(), 85 | transforms.ToTensor(), 86 | ToSpaceBGR(model.input_space == 'BGR'), 87 | ToRange255(max(model.input_range) == 255), 88 | transforms.Normalize(mean=model.mean, std=model.std) 89 | ]) 90 | eval_transform = transforms.Compose([ 91 | transforms.Resize((256, 256)), 92 | transforms.CenterCrop(max(model.input_size)), 93 | transforms.ToTensor(), 94 | ToSpaceBGR(model.input_space == 'BGR'), 95 | ToRange255(max(model.input_range) == 255), 96 | transforms.Normalize(mean=model.mean, std=model.std) 97 | ]) 98 | if args.dataset == "InShop": 99 | train_dataset = InShop('../data1/data/inshop', transform=train_transform) 100 | query_dataset = InShop('../data1/data/inshop', train=False, query=True, transform=eval_transform) 101 | index_dataset = InShop('../data1/data/inshop', train=False, query=False, transform=eval_transform) 102 | 
elif args.dataset == 'symo': 103 | train_dataset = Symo('../data1/data/symo', transform=train_transform) 104 | query_dataset = Symo('../data1/data/symo', train=False, query=True, transform=eval_transform) 105 | index_dataset = Symo('../data1/data/symo', train=False, query=False, transform=eval_transform) 106 | else: 107 | print("Dataset {} is not supported yet... Abort".format(args.dataset)) 108 | return 109 | 110 | # Setup dataset loader 111 | if args.class_balancing: 112 | print("Class Balancing") 113 | sampler = ClassBalancedBatchSampler(train_dataset.instance_labels, args.batch_size, args.images_per_class) 114 | train_loader = DataLoader(train_dataset, 115 | batch_sampler=sampler, num_workers=4, 116 | pin_memory=True, drop_last=False, collate_fn=default_collate) 117 | else: 118 | print("No class balancing") 119 | train_loader = DataLoader(train_dataset, 120 | batch_size=args.batch_size, 121 | drop_last=False, 122 | shuffle=True, 123 | pin_memory=True, 124 | num_workers=4) 125 | 126 | query_loader = DataLoader(query_dataset, 127 | batch_size=args.batch_size, 128 | drop_last=False, 129 | shuffle=False, 130 | pin_memory=True, 131 | num_workers=4) 132 | index_loader = DataLoader(index_dataset, 133 | batch_size=args.batch_size, 134 | drop_last=False, 135 | shuffle=False, 136 | pin_memory=True, 137 | num_workers=4) 138 | 139 | # Setup loss function 140 | loss_fn = losses.NormSoftmaxLoss(args.dim, train_dataset.num_instance) 141 | 142 | model.to(device=gpu_device) 143 | loss_fn.to(device=gpu_device) 144 | 145 | # Training mode 146 | model.train() 147 | 148 | # Start with pretraining where we finetune only new parameters to warm up 149 | opt = torch.optim.SGD(list(loss_fn.parameters()) + list(set(model.parameters()) - 150 | set(model.feature.parameters())), 151 | lr=args.lr * args.lr_mult, momentum=0.9, weight_decay=1e-4) 152 | 153 | log_every_n_step = 10 154 | for epoch in range(args.pretrain_epochs): 155 | for i, (im, _, instance_label, index) in enumerate(train_loader): 156 | data = time.time() 157 | opt.zero_grad() 158 | 159 | im = im.to(device=gpu_device, non_blocking=True) 160 | instance_label = instance_label.to(device=gpu_device, non_blocking=True) 161 | 162 | forward = time.time() 163 | embedding = model(im) 164 | loss = loss_fn(embedding, instance_label) 165 | 166 | back = time.time() 167 | loss.backward() 168 | opt.step() 169 | 170 | end = time.time() 171 | 172 | if (i + 1) % log_every_n_step == 0: 173 | print('Epoch {}, LR {}, Iteration {} / {}:\t{}'.format( 174 | args.pretrain_epochs - epoch, opt.param_groups[0]['lr'], i, len(train_loader), loss.item())) 175 | 176 | print('Data: {}\tForward: {}\tBackward: {}\tBatch: {}'.format( 177 | forward - data, back - forward, end - back, end - forward)) 178 | 179 | eval_file = os.path.join(output_directory, 'epoch_{}'.format(args.pretrain_epochs - epoch)) 180 | query_embeddings, query_labels = extract_feature(model, query_loader, gpu_device) 181 | index_embeddings, index_labels = extract_feature(model, index_loader, gpu_device) 182 | evaluate_float_binary_embedding_faiss(query_embeddings, index_embeddings, query_labels, index_labels, 183 | eval_file, k=1000, gpu_id=0) 184 | 185 | # Full end-to-end finetune of all parameters 186 | opt = torch.optim.SGD(chain(model.parameters(), loss_fn.parameters()), lr=args.lr, momentum=0.9, weight_decay=1e-4) 187 | 188 | for epoch in range(args.epochs_per_step * args.num_steps): 189 | print('Output Directory: {}'.format(output_directory)) 190 | adjust_learning_rate(opt, epoch, args.epochs_per_step, 
gamma=args.gamma) 191 | 192 | for i, (im, _, instance_label, index) in enumerate(train_loader): 193 | data = time.time() 194 | 195 | opt.zero_grad() 196 | 197 | im = im.to(device=gpu_device, non_blocking=True) 198 | instance_label = instance_label.to(device=gpu_device, non_blocking=True) 199 | 200 | forward = time.time() 201 | embedding = model(im) 202 | loss = loss_fn(embedding, instance_label) 203 | 204 | back = time.time() 205 | loss.backward() 206 | opt.step() 207 | 208 | end = time.time() 209 | 210 | if (i + 1) % log_every_n_step == 0: 211 | print('Epoch {}, LR {}, Iteration {} / {}:\t{}'.format( 212 | epoch, opt.param_groups[0]['lr'], i, len(train_loader), loss.item())) 213 | print('Data: {}\tForward: {}\tBackward: {}\tBatch: {}'.format( 214 | forward - data, back - forward, end - back, end - data)) 215 | 216 | snapshot_path = os.path.join(output_directory, 'epoch_{}.pth'.format(epoch + 1)) 217 | torch.save(model.state_dict(), snapshot_path) 218 | 219 | if (epoch + 1) % args.test_every_n_epochs == 0: 220 | eval_file = os.path.join(output_directory, 'epoch_{}'.format(epoch + 1)) 221 | query_embeddings, query_labels = extract_feature(model, query_loader, gpu_device) 222 | index_embeddings, index_labels = extract_feature(model, index_loader, gpu_device) 223 | evaluate_float_binary_embedding_faiss(query_embeddings, index_embeddings, query_labels, index_labels, 224 | eval_file, 225 | k=1000, gpu_id=0) 226 | 227 | 228 | if __name__ == '__main__': 229 | main() 230 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 
41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | --------------------------------------------------------------------------------