├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── data_engine.py
├── datasets
│   └── split_train_test.py
├── evaluation.py
├── imgs
│   ├── SEC.png
│   └── experiment_results.png
├── learner.py
├── loss.py
├── models
│   └── bninception.py
├── mytrain.py
├── mytrain.sh
├── myutils.py
├── test_sop.py
└── test_sop.sh
/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | output*/ 4 | ckpts/ 5 | *.pth 6 | *.t7 7 | *.png 8 | *.jpg 9 | tmp*.py 10 | # run*.sh 11 | *.pdf 12 | 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Dyfine 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SphericalEmbedding 2 | 3 | This repository is the official implementation of [Deep Metric Learning with Spherical Embedding](https://arxiv.org/abs/2011.02785) for the deep metric learning (DML) task. 4 | 5 | >📋 This repo supports training with a vanilla triplet loss / semihard triplet loss / normalized N-pair loss (tuplet loss) / multi-similarity loss on the CUB200-2011 / Cars196 / SOP / In-Shop datasets. 6 | 7 | ![SEC](imgs/SEC.png)
8 | 9 | 10 | ## Requirements 11 | 12 | This repo was tested with Ubuntu 16.04.1 LTS, Python 3.6, PyTorch 1.1.0, and CUDA 10.1. 13 | 14 | Requirements: torch==1.1.0, tensorboardX 15 | 16 | ## Training 17 | 18 | 1. Prepare the datasets and the pretrained BN-Inception. 19 | 20 | Download the datasets: [CUB200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), [Cars196](https://ai.stanford.edu/~jkrause/cars/car_dataset.html), [SOP](https://cvgl.stanford.edu/projects/lifted_struct/), [In-Shop](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html), then unzip and organize them as follows. 21 | 22 | ``` 23 | └───datasets 24 | └───split_train_test.py 25 | └───CUB_200_2011 26 | | └───images.txt 27 | | └───images 28 | | └───001.Black_footed_Albatross 29 | | └───... 30 | └───CARS196 31 | | └───cars_annos.mat 32 | | └───car_ims 33 | | └───000001.jpg 34 | | └───... 35 | └───SOP 36 | | └───Stanford_Online_Products 37 | | └───Ebay_train.txt 38 | | └───Ebay_test.txt 39 | | └───bicycle_final 40 | | └───... 41 | └───Inshop 42 | | └───list_eval_partition.txt 43 | | └───img 44 | | └───MEN 45 | | └───WOMEN 46 | | └───... 47 | ``` 48 | 49 | Then run ```split_train_test.py``` to generate the training and testing lists. 50 | 51 | Download the ImageNet-pretrained [BN-Inception](http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-52deb4733.pth) and put it into ```./pretrained_models```. 52 | 53 | 2. To train the model(s) in the paper, run the following commands or use ```sh mytrain.sh```. 54 | 55 | Train models with the vanilla triplet loss. 56 | 57 | ```train 58 | CUDA_VISIBLE_DEVICES=0 python mytrain.py --use_dataset CUB --instances 3 --lr 0.5e-5 --lr_p 0.25e-5 \ 59 | --lr_gamma 0.1 --use_loss triplet 60 | ``` 61 | 62 | Train models with the vanilla triplet loss + SEC. 63 | 64 | ```train 65 | CUDA_VISIBLE_DEVICES=0 python mytrain.py --use_dataset CUB --instances 3 --lr 0.5e-5 --lr_p 0.25e-5 \ 66 | --lr_gamma 0.1 --use_loss triplet --sec_wei 1.0 67 | ``` 68 | 69 | Train models with the vanilla triplet loss + L2-reg. 70 | 71 | ```train 72 | CUDA_VISIBLE_DEVICES=0 python mytrain.py --use_dataset CUB --instances 3 --lr 0.5e-5 --lr_p 0.25e-5 \ 73 | --lr_gamma 0.1 --use_loss triplet --l2reg_wei 1e-4 74 | ``` 75 | 76 | Similarly, we set ```--use_loss``` to ```semihtriplet```/```n-npair```/```ms``` and ```--instances``` to ```3```/```2```/```5``` to train models with the semihard triplet loss / normalized N-pair loss / multi-similarity loss, and we set ```--use_dataset``` to ```Cars```/```SOP```/```Inshop``` to train models on the other datasets. 77 | 78 | >📋 The detailed settings of the above hyper-parameters are provided in Appendix B of our paper (with two exceptions to the lr settings, listed below). 79 | > 80 | >(a) multi-similarity loss without SEC/L2-reg on CUB: 1e-5/0.5e-5/0.1@3k, 6k 81 | > 82 | >(b) multi-similarity loss without SEC/L2-reg on Cars: 2e-5/2e-5/0.1@2k 83 | > 84 | >(We find that using a larger learning rate harms the original loss function.) 85 | > 86 | >When training on a different dataset or with a different loss function, we only need to modify the hyper-parameters in the above commands and the head settings (only when using the multi-similarity loss without SEC/L2-reg do we need to set need_bn=False, i.e. 87 | > 88 | >``` 89 | >self.model = torch.nn.DataParallel(BNInception(need_bn=False)).cuda() 90 | >``` 91 | > 92 | >in line 24 of learner.py). 93 | 94 | >📋 Additionally, to use SEC with the EMA method, we need to set ```--norm_momentum```, where norm_momentum denotes $\rho$ in Appendix D of our paper.
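For reference, here is a minimal sketch of how the SEC term is combined with the base loss; it mirrors the `loss_sec` computation in learner.py (variable names follow that file, and `norm_momentum` is the EMA coefficient $\rho$ mentioned above):

```python
# Minimal sketch of the SEC term (mirrors learner.py); fea holds the
# unnormalized embeddings of the current batch. record_norm_mean is
# initialized to the first batch's norm_mean.
fea_norm = fea.norm(p=2, dim=1)            # per-sample embedding norms
norm_mean = fea_norm.mean()                # batch mean of the norms
# EMA of the mean norm; with --norm_momentum 1.0 this reduces to the batch mean
record_norm_mean = (1 - norm_momentum) * record_norm_mean + \
                   norm_momentum * norm_mean.detach()
loss_sec = ((fea_norm - record_norm_mean) ** 2).mean()  # penalize norm spread
loss = loss + sec_wei * loss_sec           # --sec_wei weights the SEC term
```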
95 | 96 | ## Testing 97 | 98 | Testing NMI and F1 on SOP is time-consuming, so we conduct it only after the training process (only R@K is tested during training). In particular, run: 99 | 100 | ```eval 101 | CUDA_VISIBLE_DEVICES=0 python test_sop.py --use_dataset SOP --test_sop_model SOP_xxxx_xxxx 102 | ``` 103 | 104 | or use ```sh test_sop.sh``` for a complete test of NMI, F1, and R@K on SOP. Here ```SOP_xxxx_xxxx``` is the model to be tested, which can be found in ```./work_space```. 105 | 106 | For the other three datasets, the test of NMI, F1, and R@K is conducted during the training process. 107 | 108 | ## Results 109 | 110 | Our model achieves the following performance on the CUB200-2011, Cars196, SOP, and In-Shop datasets: 111 | 112 | ![experiment results](imgs/experiment_results.png)
113 | 114 | ## Citation 115 | 116 | If you find this repo useful for your research, please consider citing this paper 117 | 118 | @article{zhang2020deep, 119 | title={Deep Metric Learning with Spherical Embedding}, 120 | author={Zhang, Dingyi and Li, Yingming and Zhang, Zhongfei}, 121 | journal={arXiv preprint arXiv:2011.02785}, 122 | year={2020} 123 | } 124 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import myutils 3 | import argparse 4 | import torch 5 | from torchvision import transforms 6 | import datetime 7 | from easydict import EasyDict as edict 8 | import os, logging, sys 9 | 10 | def get_config(): 11 | 12 | parser = argparse.ArgumentParser('argument for training') 13 | 14 | parser.add_argument('--use_dataset', type=str, default='Cars', choices=['CUB', 'Cars', 'SOP', 'Inshop']) 15 | # batch 16 | parser.add_argument('--batch_size', type=int, default=120) 17 | parser.add_argument('--instances', type=int, default=3) 18 | # optimization 19 | parser.add_argument('--lr', type=float, default=0.0) 20 | parser.add_argument('--lr_p', type=float, default=0.0) 21 | parser.add_argument('--lr_gamma', type=float, default=0.0) 22 | # model dataset 23 | parser.add_argument('--freeze_bn', type=int, default=1) 24 | # method 25 | parser.add_argument('--use_loss', type=str, default='triplet', choices=['triplet', 'n-npair', 'semihtriplet', 'ms']) 26 | parser.add_argument('--sec_wei', type=float, default=0.0) 27 | parser.add_argument('--norm_momentum', type=float, default=1.0) 28 | parser.add_argument('--l2reg_wei', type=float, default=0.0) 29 | 30 | parser.add_argument('--test_sop_model', type=str, default='') 31 | 32 | conf = parser.parse_args() 33 | 34 | conf.num_devs = 1 35 | 36 | if conf.use_dataset == 'CUB': 37 | conf.lr = 1.0e-5 if conf.lr==0 else conf.lr 38 | conf.lr_p = 0.5e-5 if conf.lr_p==0 else conf.lr_p 39 | conf.weight_decay = 0.5 * 5e-3 40 | 41 | conf.start_step = 0 42 | conf.lr_gamma = 0.1 if conf.lr_gamma==0 else conf.lr_gamma 43 | if conf.use_loss=='ms': 44 | conf.step_milestones = [3000, 6000, 9000] 45 | else: 46 | conf.step_milestones = [5000, 9000, 9000] 47 | conf.steps = 8000 48 | 49 | elif conf.use_dataset == 'Cars': 50 | conf.lr = 1e-5 if conf.lr==0 else conf.lr 51 | conf.lr_p = 1e-5 if conf.lr_p==0 else conf.lr_p 52 | conf.weight_decay = 0.5 * 5e-3 53 | 54 | conf.start_step = 0 55 | if conf.lr_gamma == 0.1: 56 | conf.step_milestones = [2000, 9000, 9000] 57 | elif conf.lr_gamma == 0.5: 58 | conf.step_milestones = [4000, 6000, 9000] 59 | conf.steps = 8000 60 | 61 | elif conf.use_dataset == 'SOP': 62 | conf.lr = 2.5e-4 if conf.lr==0 else conf.lr 63 | conf.lr_p = 0.5e-4 if conf.lr_p==0 else conf.lr_p 64 | conf.weight_decay = 1e-5 65 | 66 | conf.start_step = 0 67 | conf.lr_gamma = 0.1 if conf.lr_gamma==0 else conf.lr_gamma 68 | conf.step_milestones = [6e3, 18e3, 35e3] 69 | conf.steps = 12e3 70 | 71 | elif conf.use_dataset == 'Inshop': 72 | conf.lr = 5e-4 if conf.lr==0 else conf.lr 73 | conf.lr_p = 1e-4 if conf.lr_p==0 else conf.lr_p 74 | conf.weight_decay = 1e-5 75 | 76 | conf.start_step = 0 77 | conf.lr_gamma = 0.1 if conf.lr_gamma==0 else conf.lr_gamma 78 | conf.step_milestones = [6e3, 18e3, 35e3] 79 | conf.steps = 12e3 80 | 81 | conf.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 82 | 83 | now_time = datetime.datetime.now().strftime('%m%d_%H%M') 84 | conf_work_path = 'work_space/' + 
conf.use_dataset + '_' + now_time 85 | myutils.mkdir_p(conf_work_path, delete=True) 86 | myutils.set_file_logger(work_dir=conf_work_path, log_level=logging.DEBUG) 87 | sys.stdout = myutils.Logger(conf_work_path + '/log-prt') 88 | sys.stderr = myutils.Logger(conf_work_path + '/log-prt-err') 89 | 90 | path0, path1 = conf_work_path.split('/') 91 | conf.log_path = Path(path0) / 'logs' / path1 / 'log' 92 | conf.work_path = Path(conf_work_path) 93 | conf.model_path = conf.work_path / 'models' 94 | conf.save_path = conf.work_path / 'save' 95 | 96 | conf.start_eval = False 97 | 98 | conf.num_workers = 8 99 | 100 | conf.bninception_pretrained_model_path = './pretrained_models/bn_inception-52deb4733.pth' 101 | 102 | conf.transform_dict = {} 103 | conf.use_simple_aug = False 104 | 105 | conf.transform_dict['rand-crop'] = \ 106 | transforms.Compose([ 107 | transforms.Resize(size=(256, 256)) if conf.use_simple_aug else transforms.Resize(size=256), 108 | transforms.RandomCrop((227, 227)) if conf.use_simple_aug else transforms.RandomResizedCrop( 109 | scale=[0.16, 1], 110 | size=227 111 | ), 112 | transforms.RandomHorizontalFlip(), 113 | transforms.ToTensor(), 114 | transforms.Normalize(mean=[123 / 255.0, 117 / 255.0, 104 / 255.0], 115 | std=[1.0 / 255, 1.0 / 255, 1.0 / 255]), 116 | transforms.Lambda(lambda x: x[[2, 1, 0], ...]) # to BGR 117 | ]) 118 | conf.transform_dict['center-crop'] = \ 119 | transforms.Compose([ 120 | transforms.Resize(size=(256, 256)) if conf.use_simple_aug else transforms.Resize(size=256), 121 | transforms.CenterCrop(227), 122 | transforms.ToTensor(), 123 | transforms.Normalize(mean=[123 / 255.0, 117 / 255.0, 104 / 255.0], 124 | std=[1.0 / 255, 1.0 / 255, 1.0 / 255]), 125 | transforms.Lambda(lambda x: x[[2, 1, 0], ...]) # to BGR 126 | ]) 127 | 128 | 129 | 130 | return conf 131 | -------------------------------------------------------------------------------- /data_engine.py: -------------------------------------------------------------------------------- 1 | import myutils 2 | import os 3 | import torch 4 | import numpy as np 5 | import os.path as osp 6 | from PIL import Image 7 | from torch.utils.data.sampler import Sampler 8 | from collections import defaultdict 9 | import re 10 | 11 | class MSBaseDataSet(torch.utils.data.Dataset): 12 | """ 13 | Basic dataset that reads image paths and labels from img_source, 14 | a text file listing one img_path,label pair per line. 15 | """ 16 | def __init__(self, conf, img_source, transform=None, mode="RGB"): 17 | self.mode = mode 18 | 19 | self.root = os.path.dirname(img_source) 20 | assert os.path.exists(img_source), f"{img_source} NOT found."
21 | self.img_source = img_source 22 | 23 | self.label_list = list() 24 | self.path_list = list() 25 | self._load_data() 26 | self.label_index_dict = self._build_label_index_dict() 27 | 28 | self.num_cls = len(self.label_index_dict.keys()) 29 | self.num_train = len(self.label_list) 30 | 31 | self.transform = transform 32 | 33 | def __len__(self): 34 | return len(self.label_list) 35 | 36 | def __repr__(self): 37 | return self.__str__() 38 | 39 | def __str__(self): 40 | return f"| Dataset Info |datasize: {self.__len__()}|num_labels: {len(set(self.label_list))}|" 41 | 42 | def _load_data(self): 43 | with open(self.img_source, 'r') as f: 44 | for line in f: 45 | _path, _label = re.split(r",| ", line.strip()) 46 | self.path_list.append(_path) 47 | self.label_list.append(_label) 48 | 49 | def _build_label_index_dict(self): 50 | index_dict = defaultdict(list) 51 | for i, label in enumerate(self.label_list): 52 | index_dict[label].append(i) 53 | return index_dict 54 | 55 | def read_image(self, img_path, mode='RGB'): 56 | """Keep reading the image until it succeeds. 57 | This avoids transient IOErrors incurred by heavy IO load.""" 58 | got_img = False 59 | if not osp.exists(img_path): 60 | raise IOError(f"{img_path} does not exist") 61 | while not got_img: 62 | try: 63 | img = Image.open(img_path).convert("RGB") 64 | if mode == "BGR": 65 | r, g, b = img.split() 66 | img = Image.merge("RGB", (b, g, r)) 67 | got_img = True 68 | except IOError: 69 | print(f"IOError incurred when reading '{img_path}'. Will redo.") 70 | pass 71 | return img 72 | 73 | def __getitem__(self, index): 74 | path = self.path_list[index] 75 | img_path = os.path.join(self.root, path) 76 | label = self.label_list[index] 77 | 78 | img = self.read_image(img_path, mode=self.mode) 79 | 80 | if self.transform is not None: 81 | img = self.transform(img) 82 | return {'image': img, 'label': int(label), 'index': index} 83 | 84 | 85 | class RandomIdSampler(Sampler): 86 | def __init__(self, conf, label_index_dict): 87 | self.label_index_dict = label_index_dict 88 | self.num_train = 0 89 | for k in self.label_index_dict.keys(): 90 | self.num_train += len(self.label_index_dict[k]) 91 | 92 | self.num_instances = conf.instances 93 | self.batch_size = conf.batch_size 94 | assert self.batch_size % self.num_instances == 0 95 | self.num_pids_per_batch = self.batch_size // self.num_instances 96 | 97 | self.ids = list(self.label_index_dict.keys()) 98 | 99 | self.length = self.num_train//self.batch_size * self.batch_size 100 | self.conf = conf 101 | 102 | def __len__(self): 103 | return self.length 104 | 105 | def get_batch_ids(self): 106 | # sample num_pids_per_batch distinct class ids for one batch 107 | 108 | pids = np.random.choice(self.ids, 109 | size=self.num_pids_per_batch, 110 | replace=False) 111 | return pids 112 | 113 | def get_batch_idxs(self): 114 | pids = self.get_batch_ids() 115 | 116 | # for each class, draw num_instances samples (with replacement if the class is small) 117 | cnt = 0 118 | for pid in pids: 119 | index_list = self.label_index_dict[pid] 120 | if len(index_list) >= self.num_instances: 121 | t = np.random.choice(index_list, size=self.num_instances, replace=False) 122 | else: 123 | t = np.random.choice(index_list, size=self.num_instances, replace=True) 124 | # yield the chosen indices, stopping once the batch is full 125 | for ind in t: 126 | yield ind 127 | cnt += 1 128 | if cnt == self.batch_size: 129 | break 130 | if cnt == self.batch_size: 131 | break 132 | 133 | def __iter__(self): 134 | cnt = 0 135 | while cnt < len(self): 136 | for ind in self.get_batch_idxs(): 137 | cnt += 1 138 | yield ind 139 | 140 | 141 | --------------------------------------------------------------------------------
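As a usage note: `RandomIdSampler` builds each batch from `batch_size // instances` randomly chosen classes with `instances` images per class (e.g. 120 = 40 classes x 3 instances with the defaults), which is what the pair-based losses rely on. A minimal sketch of wiring it to a `DataLoader`, assuming a `conf` object from `get_config()` (this mirrors how learner.py constructs its training loader):

```python
from torch.utils.data import DataLoader

from config import get_config
from data_engine import MSBaseDataSet, RandomIdSampler

conf = get_config()  # parses the command-line flags shown in the README
dataset = MSBaseDataSet(conf, './datasets/CUB_200_2011/cub_train.txt',
                        transform=conf.transform_dict['rand-crop'], mode='RGB')
# batch_size must be divisible by instances; each batch then contains
# batch_size // instances classes with instances images each
loader = DataLoader(dataset, batch_size=conf.batch_size, num_workers=conf.num_workers,
                    shuffle=False, sampler=RandomIdSampler(conf, dataset.label_index_dict),
                    drop_last=True, pin_memory=True)
```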
/datasets/split_train_test.py: -------------------------------------------------------------------------------- 1 | from scipy.io import loadmat 2 | 3 | 4 | # CUB200-2011 5 | with open('./CUB_200_2011/images.txt', 'r') as src: 6 | srclines = src.readlines() 7 | 8 | with open('./CUB_200_2011/cub_train.txt', 'w') as tf: 9 | for line in srclines: 10 | i, fname = line.strip().split() 11 | label = int(fname.split('.', 1)[0]) 12 | if label <= 100: 13 | print('images/{},{}'.format(fname, label-1), file=tf) 14 | 15 | with open('./CUB_200_2011/cub_test.txt', 'w') as tf: 16 | for line in srclines: 17 | i, fname = line.strip().split() 18 | label = int(fname.split('.', 1)[0]) 19 | if label > 100: 20 | print('images/{},{}'.format(fname, label-1), file=tf) 21 | 22 | 23 | # Cars196 24 | file = loadmat('./CARS196/cars_annos.mat') 25 | annos = file['annotations'] 26 | 27 | with open('./CARS196/cars_train.txt', 'w') as tf: 28 | for i in range(16185): 29 | if annos[0,i][-2] <= 98: 30 | print('{},{}'.format(annos[0,i][0][0], annos[0,i][-2][0][0]-1), file=tf) 31 | 32 | with open('./CARS196/cars_test.txt', 'w') as tf: 33 | for i in range(16185): 34 | if annos[0,i][-2] > 98: 35 | print('{},{}'.format(annos[0,i][0][0], annos[0,i][-2][0][0]-1), file=tf) 36 | 37 | 38 | # SOP 39 | with open('./SOP/Stanford_Online_Products/Ebay_train.txt', 'r') as src: 40 | srclines = src.readlines() 41 | 42 | with open('./SOP/sop_train.txt', 'w') as tf: 43 | for i in range(1, len(srclines)): 44 | line = srclines[i] 45 | line_split = line.strip().split(' ') 46 | cls_id = str(int(line_split[1]) - 1) 47 | img_path = 'Stanford_Online_Products/'+line_split[3] 48 | print(img_path+','+cls_id, file=tf) 49 | 50 | with open('./SOP/Stanford_Online_Products/Ebay_test.txt', 'r') as src: 51 | srclines = src.readlines() 52 | 53 | with open('./SOP/sop_test.txt', 'w') as tf: 54 | for i in range(1, len(srclines)): 55 | line = srclines[i] 56 | line_split = line.strip().split(' ') 57 | cls_id = str(int(line_split[1]) - 1) 58 | img_path = 'Stanford_Online_Products/'+line_split[3] 59 | print(img_path+','+cls_id, file=tf) 60 | 61 | 62 | # In-Shop 63 | with open('./Inshop/list_eval_partition.txt', 'r') as file_to_read: 64 | lines = file_to_read.readlines() 65 | 66 | with open('./Inshop/inshop_train.txt', 'w') as tf: 67 | cls_name2idx = {} 68 | cls_num = 0 69 | for line in lines: 70 | words = line.strip().split() 71 | if len(words)==3: 72 | if words[-1]=='train': 73 | path = words[0] 74 | cls_name = words[1] 75 | if cls_name not in cls_name2idx.keys(): 76 | cls_name2idx[cls_name] = cls_num 77 | cls_num += 1 78 | print('{},{}'.format(path, cls_name2idx[cls_name]), file=tf) 79 | 80 | with open('./Inshop/inshop_query.txt', 'w') as tf: 81 | test_cls_name2idx = {} 82 | cls_num = 0 83 | for line in lines: 84 | words = line.strip().split() 85 | if len(words)==3: 86 | if words[-1]=='query': 87 | path = words[0] 88 | cls_name = words[1] 89 | if cls_name not in test_cls_name2idx.keys(): 90 | test_cls_name2idx[cls_name] = cls_num 91 | cls_num += 1 92 | print('{},{}'.format(path, test_cls_name2idx[cls_name]), file=tf) 93 | 94 | with open('./Inshop/inshop_gallery.txt', 'w') as tf: 95 | for line in lines: 96 | words = line.strip().split() 97 | if len(words)==3: 98 | if words[-1]=='gallery': 99 | path = words[0] 100 | cls_name = words[1] 101 | if cls_name not in test_cls_name2idx.keys(): 102 | print('error!') 103 | break 104 | print('{},{}'.format(path, test_cls_name2idx[cls_name]), file=tf) 105 | 106 | 107 | 108 | 109 | 110 | 
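Note on the generated lists: every block above writes one `relative_image_path,zero-based_label` pair per line, which is exactly the format `MSBaseDataSet._load_data` parses (it splits on a comma or a space). Illustrative lines for the four datasets follow; the bracketed file names are hypothetical placeholders, not actual paths:

```
images/001.Black_footed_Albatross/<image_name>.jpg,0
car_ims/000001.jpg,0
Stanford_Online_Products/bicycle_final/<image_name>.JPG,0
img/MEN/<category>/<item_id>/<image_name>.jpg,0
```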
-------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | from sklearn.metrics.cluster import normalized_mutual_info_score 3 | import numpy as np 4 | import logging 5 | import torch 6 | import myutils 7 | import math 8 | from scipy.special import comb 9 | 10 | def NMI_F1(X, ground_truth, n_cluster): 11 | X = [x.cpu().numpy() for x in X] 12 | # list to numpy 13 | X = np.array(X) 14 | 15 | ground_truth = np.array(ground_truth) 16 | 17 | kmeans = KMeans(n_clusters=n_cluster, n_jobs=4, random_state=0).fit(X) 18 | 19 | logging.info('K-means done') 20 | nmi, f1 = compute_clustering_metric(np.asarray(kmeans.labels_), ground_truth) 21 | 22 | return nmi, f1 23 | 24 | def normalize(x): 25 | norm = x.norm(dim=1, p=2, keepdim=True) 26 | x = x.div(norm.expand_as(x)) 27 | return x 28 | 29 | def pairwise_similarity(x, y=None): 30 | if y is None: 31 | y = x 32 | 33 | y = normalize(y) 34 | x = normalize(x) 35 | 36 | similarity = torch.mm(x, y.t()) 37 | return similarity 38 | 39 | 40 | def Recall_at_ks(sim_mat, data_name=None, query_ids=None, gallery_ids=None): 41 | # start_time = time.time() 42 | # print(start_time) 43 | """ 44 | :param sim_mat: query-by-gallery similarity matrix 45 | :param query_ids: labels of the query images 46 | :param gallery_ids: labels of the gallery images 47 | :param data_name: dataset name, which selects the K values 48 | 49 | Compute Recall@K for the dataset-specific K values below 50 | """ 51 | 52 | ks_dict = dict() 53 | ks_dict['CUB'] = [1, 2, 4, 8, 16, 32] 54 | ks_dict['Cars'] = [1, 2, 4, 8, 16, 32] 55 | ks_dict['SOP'] = [1, 10, 100, 1000] 56 | ks_dict['Inshop'] = [1, 10, 20, 30, 40, 50] 57 | 58 | assert data_name in ['CUB', 'Cars', 'SOP', 'Inshop'] 59 | k_s = ks_dict[data_name] 60 | 61 | sim_mat = sim_mat.cpu().numpy() 62 | m, n = sim_mat.shape 63 | 64 | 65 | gallery_ids = np.asarray(gallery_ids) 66 | if query_ids is None: 67 | query_ids = gallery_ids 68 | else: 69 | query_ids = np.asarray(query_ids) 70 | 71 | 72 | num_valid = np.zeros(len(k_s)) 73 | neg_nums = np.zeros(m) 74 | for i in range(m): 75 | x = sim_mat[i] 76 | 77 | pos_max = np.max(x[gallery_ids == query_ids[i]]) 78 | 79 | neg_num = np.sum(x > pos_max) 80 | neg_nums[i] = neg_num 81 | 82 | for i, k in enumerate(k_s): 83 | if i == 0: 84 | temp = np.sum(neg_nums < k) 85 | num_valid[i] = temp 86 | else: 87 | temp = np.sum(neg_nums < k) 88 | num_valid[i] = temp 89 | 90 | return num_valid / float(m) 91 | 92 | 93 | def compute_clustering_metric(idx, item_ids): 94 | 95 | N = len(idx) 96 | 97 | # cluster centers 98 | centers = np.unique(idx) 99 | num_cluster = len(centers) 100 | 101 | # count the number of objects in each cluster 102 | count_cluster = np.zeros(num_cluster) 103 | for i in range(num_cluster): 104 | count_cluster[i] = len(np.where(idx == centers[i])[0]) 105 | 106 | # build a mapping from item_id to item index 107 | keys = np.unique(item_ids) 108 | num_item = len(keys) 109 | values = range(num_item) 110 | item_map = dict() 111 | for i in range(len(keys)): 112 | item_map.update([(keys[i], values[i])]) 113 | 114 | # count the number of objects of each item 115 | count_item = np.zeros(num_item) 116 | for i in range(N): 117 | index = item_map[item_ids[i]] 118 | count_item[index] = count_item[index] + 1 119 | 120 | # compute purity 121 | purity = 0 122 | for i in range(num_cluster): 123 | member = np.where(idx == centers[i])[0] 124 | member_ids = item_ids[member] 125 | 126 | count = np.zeros(num_item) 127 | for j in range(len(member)): 128 | index = item_map[member_ids[j]] 129 | count[index] = count[index] + 1 130 | 
purity = purity + max(count) 131 | 132 | # compute Normalized Mutual Information (NMI) 133 | count_cross = np.zeros((num_cluster, num_item)) 134 | for i in range(N): 135 | index_cluster = np.where(idx[i] == centers)[0] 136 | index_item = item_map[item_ids[i]] 137 | count_cross[index_cluster, index_item] = count_cross[index_cluster, index_item] + 1 138 | 139 | # mutual information 140 | I = 0 141 | for k in range(num_cluster): 142 | for j in range(num_item): 143 | if count_cross[k, j] > 0: 144 | s = count_cross[k, j] / N * math.log(N * count_cross[k, j] / (count_cluster[k] * count_item[j])) 145 | I = I + s 146 | 147 | # entropy 148 | H_cluster = 0 149 | for k in range(num_cluster): 150 | s = -count_cluster[k] / N * math.log(count_cluster[k] / float(N)) 151 | H_cluster = H_cluster + s 152 | 153 | H_item = 0 154 | for j in range(num_item): 155 | s = -count_item[j] / N * math.log(count_item[j] / float(N)) 156 | H_item = H_item + s 157 | 158 | NMI = 2 * I / (H_cluster + H_item) 159 | 160 | # compute True Positive (TP) plus False Positive (FP) 161 | tp_fp = 0 162 | for k in range(num_cluster): 163 | if count_cluster[k] > 1: 164 | tp_fp = tp_fp + comb(count_cluster[k], 2) 165 | 166 | # compute True Positive (TP) 167 | tp = 0 168 | for k in range(num_cluster): 169 | member = np.where(idx == centers[k])[0] 170 | member_ids = item_ids[member] 171 | 172 | count = np.zeros(num_item) 173 | for j in range(len(member)): 174 | index = item_map[member_ids[j]] 175 | count[index] = count[index] + 1 176 | 177 | for i in range(num_item): 178 | if count[i] > 1: 179 | tp = tp + comb(count[i], 2) 180 | 181 | # False Positive (FP) 182 | fp = tp_fp - tp 183 | 184 | # compute False Negative (FN) 185 | count = 0 186 | for j in range(num_item): 187 | if count_item[j] > 1: 188 | count = count + comb(count_item[j], 2) 189 | 190 | fn = count - tp 191 | 192 | # compute F measure 193 | P = tp / (tp + fp) 194 | R = tp / (tp + fn) 195 | beta = 1 196 | F = (beta*beta + 1) * P * R / (beta*beta * P + R) 197 | 198 | return NMI, F 199 | 200 | -------------------------------------------------------------------------------- /imgs/SEC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dyfine/SphericalEmbedding/f118c0ee05cfd3a0905a67cae2a5813a1e061647/imgs/SEC.png -------------------------------------------------------------------------------- /imgs/experiment_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dyfine/SphericalEmbedding/f118c0ee05cfd3a0905a67cae2a5813a1e061647/imgs/experiment_results.png -------------------------------------------------------------------------------- /learner.py: -------------------------------------------------------------------------------- 1 | import myutils 2 | import os 3 | import torch 4 | from loss import NpairLoss, TripletSemihardLoss, TripletLoss, MultiSimilarityLoss 5 | import logging 6 | import numpy as np 7 | from models.bninception import BNInception 8 | 9 | from torch.utils.data import DataLoader 10 | from torch import optim 11 | from tensorboardX import SummaryWriter 12 | from torch.utils.data.sampler import Sampler 13 | from datetime import datetime 14 | from evaluation import NMI_F1, pairwise_similarity, Recall_at_ks 15 | from data_engine import MSBaseDataSet, RandomIdSampler 16 | 17 | def get_time(): 18 | return (str(datetime.now())[:-10]).replace(' ', '-').replace(':', '-') 19 | 20 | class metric_learner(object): 21 | def __init__(self, 
conf, inference=False): 22 | 23 | logging.info(f'metric learner use {conf}') 24 | self.model = torch.nn.DataParallel(BNInception()).cuda() 25 | logging.info(f'model generated') 26 | 27 | if not inference: 28 | 29 | if conf.use_dataset == 'CUB': 30 | self.dataset = MSBaseDataSet(conf, './datasets/CUB_200_2011/cub_train.txt', 31 | transform=conf.transform_dict['rand-crop'], mode='RGB') 32 | elif conf.use_dataset == 'Cars': 33 | self.dataset = MSBaseDataSet(conf, './datasets/CARS196/cars_train.txt', 34 | transform=conf.transform_dict['rand-crop'], mode='RGB') 35 | elif conf.use_dataset == 'SOP': 36 | self.dataset = MSBaseDataSet(conf, './datasets/SOP/sop_train.txt', 37 | transform=conf.transform_dict['rand-crop'], mode='RGB') 38 | elif conf.use_dataset == 'Inshop': 39 | self.dataset = MSBaseDataSet(conf, './datasets/Inshop/inshop_train.txt', 40 | transform=conf.transform_dict['rand-crop'], mode='RGB') 41 | 42 | self.loader = DataLoader( 43 | self.dataset, batch_size=conf.batch_size, num_workers=conf.num_workers, 44 | shuffle=False, sampler=RandomIdSampler(conf, self.dataset.label_index_dict), drop_last=True, 45 | pin_memory=True, 46 | ) 47 | 48 | self.class_num = self.dataset.num_cls 49 | self.img_num = self.dataset.num_train 50 | 51 | myutils.mkdir_p(conf.log_path, delete=True) 52 | self.writer = SummaryWriter(str(conf.log_path)) 53 | self.step = 0 54 | 55 | self.head_npair = NpairLoss().to(conf.device) 56 | self.head_semih_triplet = TripletSemihardLoss().to(conf.device) 57 | self.head_triplet = TripletLoss(instance=conf.instances).to(conf.device) 58 | self.head_multisimiloss = MultiSimilarityLoss().to(conf.device) 59 | logging.info('model heads generated') 60 | 61 | backbone_bn_para, backbone_wo_bn_para = [ 62 | [p for k, p in self.model.named_parameters() if 63 | ('bn' in k) == is_bn and ('head' in k) == False] for is_bn in [True, False]] 64 | 65 | head_bn_para, head_wo_bn_para = [ 66 | [p for k, p in self.model.module.head.named_parameters() if 67 | ('bn' in k) == is_bn] for is_bn in [True, False]] 68 | 69 | self.optimizer = optim.Adam([ 70 | {'params': backbone_bn_para if conf.freeze_bn==False else [], 'lr': conf.lr_p}, 71 | {'params': backbone_wo_bn_para, 'weight_decay': conf.weight_decay, 'lr': conf.lr_p}, 72 | {'params': head_bn_para, 'lr': conf.lr}, 73 | {'params': head_wo_bn_para, 'weight_decay': conf.weight_decay, 'lr': conf.lr}, 74 | ]) 75 | 76 | logging.info(f'{self.optimizer}, optimizers generated') 77 | 78 | if conf.use_dataset=='CUB' or conf.use_dataset=='Cars': 79 | self.board_loss_every = 20 80 | self.evaluate_every = 100 81 | self.save_every = 1000 82 | elif conf.use_dataset=='Inshop': 83 | self.board_loss_every = 20 84 | self.evaluate_every = 200 85 | self.save_every = 2000 86 | else: 87 | self.board_loss_every = 20 88 | self.evaluate_every = 500 89 | self.save_every = 2000 90 | 91 | 92 | def train(self, conf): 93 | self.model.train() 94 | self.train_with_fixed_bn(conf) 95 | 96 | myutils.timer.since_last_check('start train') 97 | data_time = myutils.AverageMeter(20) 98 | loss_time = myutils.AverageMeter(20) 99 | loss_meter = myutils.AverageMeter(20) 100 | 101 | self.step = conf.start_step 102 | 103 | if self.step == 0 and conf.start_eval: 104 | nmi, f1, recall_ks = self.test(conf) 105 | self.writer.add_scalar('{}/test_nmi'.format(conf.use_dataset), nmi, self.step) 106 | self.writer.add_scalar('{}/test_f1'.format(conf.use_dataset), f1, self.step) 107 | self.writer.add_scalar('{}/test_recall_at_1'.format(conf.use_dataset), recall_ks[0], self.step) 108 | logging.info(f'test 
on {conf.use_dataset}: nmi is {nmi}, f1 is {f1}, recalls are {recall_ks[0]}, {recall_ks[1]}, {recall_ks[2]}, {recall_ks[3:]} ') 109 | 110 | nmi, f1, recall_ks = self.validate(conf) 111 | self.writer.add_scalar('{}/train_nmi'.format(conf.use_dataset), nmi, self.step) 112 | self.writer.add_scalar('{}/train_f1'.format(conf.use_dataset), f1, self.step) 113 | self.writer.add_scalar('{}/train_recall_at_1'.format(conf.use_dataset), recall_ks[0], self.step) 114 | logging.info(f'val on {conf.use_dataset}: nmi is {nmi}, f1 is {f1}, recall_at_1 is {recall_ks[0]} ') 115 | 116 | self.train_with_fixed_bn(conf) 117 | 118 | 119 | while self.step < conf.steps: 120 | 121 | loader_enum = enumerate(self.loader) 122 | while True: 123 | if self.step > conf.steps: 124 | break 125 | try: 126 | ind_data, data = loader_enum.__next__() 127 | except StopIteration as e: 128 | logging.info(f'one epoch finish {e} {ind_data}') 129 | break 130 | data_time.update(myutils.timer.since_last_check(verbose=False)) 131 | 132 | if self.step == conf.step_milestones[0]: 133 | self.schedule_lr(conf) 134 | if self.step == conf.step_milestones[1]: 135 | self.schedule_lr(conf) 136 | if self.step == conf.step_milestones[2]: 137 | self.schedule_lr(conf) 138 | 139 | imgs = data['image'].to(conf.device) 140 | labels = data['label'].to(conf.device) 141 | index = data['index'] 142 | 143 | self.optimizer.zero_grad() 144 | 145 | fea = self.model(imgs, normalized=False) 146 | 147 | fea_norm = fea.norm(p=2, dim=1) 148 | norm_mean = fea_norm.mean() 149 | norm_var = ((fea_norm - norm_mean) ** 2).mean() 150 | 151 | 152 | if self.step==0: 153 | self.record_norm_mean = norm_mean.detach() 154 | else: 155 | self.record_norm_mean = (1 - conf.norm_momentum) * self.record_norm_mean + \ 156 | conf.norm_momentum * norm_mean.detach() 157 | 158 | 159 | if conf.use_loss == 'triplet': 160 | loss, avg_ap, avg_an = self.head_triplet(fea, labels, normalized=True) 161 | elif conf.use_loss == 'n-npair': 162 | loss, avg_ap, avg_an = self.head_npair(fea, labels, normalized=True) 163 | elif conf.use_loss == 'semihtriplet': 164 | loss, avg_ap, avg_an = self.head_semih_triplet(fea, labels, normalized=True) 165 | elif conf.use_loss == 'ms': 166 | loss, avg_ap, avg_an = self.head_multisimiloss(fea, labels) 167 | 168 | 169 | 170 | loss_sec = ((fea_norm - self.record_norm_mean) ** 2).mean() 171 | loss_l2reg = (fea_norm ** 2).mean() 172 | 173 | if conf.sec_wei != 0: 174 | loss = loss + conf.sec_wei * loss_sec 175 | if conf.l2reg_wei != 0: 176 | loss = loss + conf.l2reg_wei * loss_l2reg 177 | 178 | loss.backward() 179 | 180 | self.writer.add_scalar('info/norm_var', norm_var.detach().item(), self.step) 181 | self.writer.add_scalar('info/norm_mean', norm_mean.detach().item(), self.step) 182 | self.writer.add_scalar('info/loss_sec', loss_sec.item(), self.step) 183 | self.writer.add_scalar('info/loss_l2reg', loss_l2reg.item(), self.step) 184 | self.writer.add_scalar('info/avg_ap', avg_ap.item(), self.step) 185 | self.writer.add_scalar('info/avg_an', avg_an.item(), self.step) 186 | self.writer.add_scalar('info/record_norm_mean', self.record_norm_mean.item(), self.step) 187 | self.writer.add_scalar('info/lr', self.optimizer.param_groups[2]['lr'], self.step) 188 | 189 | loss_meter.update(loss.item()) 190 | 191 | self.optimizer.step() 192 | 193 | if self.step % self.evaluate_every ==0 and self.step != 0: 194 | nmi, f1, recall_ks = self.test(conf) 195 | self.writer.add_scalar('{}/test_nmi'.format(conf.use_dataset), nmi, self.step) 196 | 
self.writer.add_scalar('{}/test_f1'.format(conf.use_dataset), f1, self.step) 197 | self.writer.add_scalar('{}/test_recall_at_1'.format(conf.use_dataset), recall_ks[0], self.step) 198 | logging.info(f'test on {conf.use_dataset}: nmi is {nmi}, f1 is {f1}, recalls are {recall_ks[0]}, {recall_ks[1]}, {recall_ks[2]}, {recall_ks[3:]} ') 199 | 200 | nmi, f1, recall_ks = self.validate(conf) 201 | self.writer.add_scalar('{}/train_nmi'.format(conf.use_dataset), nmi, self.step) 202 | self.writer.add_scalar('{}/train_f1'.format(conf.use_dataset), f1, self.step) 203 | self.writer.add_scalar('{}/train_recall_at_1'.format(conf.use_dataset), recall_ks[0], self.step) 204 | logging.info(f'val on {conf.use_dataset}: nmi is {nmi}, f1 is {f1}, recall_at_1 is {recall_ks[0]} ') 205 | 206 | self.train_with_fixed_bn(conf) 207 | 208 | if self.step % self.board_loss_every == 0 and self.step != 0: 209 | # record lr 210 | self.writer.add_scalar('train_loss', loss_meter.avg, self.step) 211 | 212 | logging.info(f'step {self.step}: ' + 213 | f'loss: {loss_meter.avg:.3f} ' + 214 | f'data time: {data_time.avg:.2f} ' + 215 | f'loss time: {loss_time.avg:.2f} ' + 216 | f'speed: {conf.batch_size/(data_time.avg+loss_time.avg):.2f} imgs/s ' + 217 | f'norm_mean: {norm_mean.item():.2f} ' + 218 | f'norm_var: {norm_var.item():.2f}') 219 | 220 | if self.step % self.save_every == 0 and self.step != 0: 221 | self.save_state(conf) 222 | 223 | self.step += 1 224 | 225 | loss_time.update(myutils.timer.since_last_check(verbose=False)) 226 | 227 | self.save_state(conf, to_save_folder=True) 228 | 229 | def train_with_fixed_bn(self, conf): 230 | def fix_bn(m): 231 | classname = m.__class__.__name__ 232 | if classname.find('BatchNorm') != -1: 233 | m.eval() 234 | if conf.freeze_bn: 235 | self.model.apply(fix_bn) 236 | self.model.module.head.train() 237 | else: 238 | pass 239 | 240 | def validate(self, conf): 241 | logging.info('start eval') 242 | self.model.eval() 243 | 244 | if conf.use_dataset == 'CUB' or conf.use_dataset == 'Cars' or conf.use_dataset == 'SOP': 245 | 246 | loader = DataLoader(self.dataset, batch_size=conf.batch_size, num_workers=conf.num_workers, 247 | shuffle=False, pin_memory=True, drop_last=False) 248 | 249 | loader_enum = enumerate(loader) 250 | feas = torch.tensor([]) 251 | labels = np.array([]) 252 | with torch.no_grad(): 253 | while True: 254 | try: 255 | ind_data, data = loader_enum.__next__() 256 | except StopIteration as e: 257 | break 258 | 259 | imgs = data['image'] 260 | label = data['label'] 261 | 262 | output1 = self.model(imgs, normalized=False) 263 | norm = output1.norm(dim=1, p=2, keepdim=True) 264 | output1 = output1.div(norm.expand_as(output1)) 265 | feas = torch.cat((feas, output1.cpu()), 0) 266 | labels = np.append(labels, label.cpu().numpy()) 267 | 268 | if conf.use_dataset == 'SOP': 269 | nmi = 0 270 | f1 = 0 271 | else: 272 | pids = np.unique(labels) 273 | nmi, f1 = NMI_F1(feas, labels, n_cluster=len(pids)) 274 | 275 | sim_mat = pairwise_similarity(feas) 276 | sim_mat = sim_mat - torch.eye(sim_mat.size(0)) 277 | recall_ks = Recall_at_ks(sim_mat, data_name=conf.use_dataset, gallery_ids=labels) 278 | 279 | elif conf.use_dataset=='Inshop': 280 | nmi = 0 281 | f1 = 0 282 | recall_ks = [0.0, 0.0] 283 | 284 | self.model.train() 285 | logging.info('eval end') 286 | return nmi, f1, recall_ks 287 | 288 | def test(self, conf): 289 | logging.info('start test') 290 | self.model.eval() 291 | 292 | if conf.use_dataset=='CUB' or conf.use_dataset=='Cars' or conf.use_dataset=='SOP': 293 | 294 | if conf.use_dataset == 
'CUB': 295 | dataset = MSBaseDataSet(conf, './datasets/CUB_200_2011/cub_test.txt', 296 | transform=conf.transform_dict['center-crop'], mode='RGB') 297 | elif conf.use_dataset == 'Cars': 298 | dataset = MSBaseDataSet(conf, './datasets/CARS196/cars_test.txt', 299 | transform=conf.transform_dict['center-crop'], mode='RGB') 300 | elif conf.use_dataset == 'SOP': 301 | dataset = MSBaseDataSet(conf, './datasets/SOP/sop_test.txt', 302 | transform=conf.transform_dict['center-crop'], mode='RGB') 303 | 304 | loader = DataLoader(dataset, batch_size=conf.batch_size, num_workers=conf.num_workers, 305 | shuffle=False, pin_memory=True, drop_last=False) 306 | 307 | loader_enum = enumerate(loader) 308 | feas = torch.tensor([]) 309 | labels = np.array([]) 310 | with torch.no_grad(): 311 | while True: 312 | try: 313 | ind_data, data = loader_enum.__next__() 314 | except StopIteration as e: 315 | break 316 | 317 | imgs = data['image'] 318 | label = data['label'] 319 | 320 | output1 = self.model(imgs, normalized=False) 321 | norm = output1.norm(dim=1, p=2, keepdim=True) 322 | output1 = output1.div(norm.expand_as(output1)) 323 | feas = torch.cat((feas, output1.cpu()), 0) 324 | labels = np.append(labels, label.cpu().numpy()) 325 | 326 | if conf.use_dataset == 'SOP': 327 | nmi = 0 328 | f1 = 0 329 | else: 330 | pids = np.unique(labels) 331 | nmi, f1 = NMI_F1(feas, labels, n_cluster=len(pids)) 332 | 333 | sim_mat = pairwise_similarity(feas) 334 | sim_mat = sim_mat - torch.eye(sim_mat.size(0)) 335 | recall_ks = Recall_at_ks(sim_mat, data_name=conf.use_dataset, gallery_ids=labels) 336 | 337 | elif conf.use_dataset=='Inshop': 338 | nmi = 0 339 | f1 = 0 340 | # query 341 | dataset_query = MSBaseDataSet(conf, './datasets/Inshop/inshop_query.txt', 342 | transform=conf.transform_dict['center-crop'], mode='RGB') 343 | loader_query = DataLoader(dataset_query, batch_size=conf.batch_size, num_workers=conf.num_workers, 344 | shuffle=False, pin_memory=True, drop_last=False) 345 | loader_query_enum = enumerate(loader_query) 346 | feas_query = torch.tensor([]) 347 | labels_query = np.array([]) 348 | with torch.no_grad(): 349 | while True: 350 | try: 351 | ind_data, data = loader_query_enum.__next__() 352 | except StopIteration as e: 353 | break 354 | 355 | imgs = data['image'] 356 | label = data['label'] 357 | 358 | output1 = self.model(imgs, normalized=False) 359 | norm = output1.norm(dim=1, p=2, keepdim=True) 360 | output1 = output1.div(norm.expand_as(output1)) 361 | feas_query = torch.cat((feas_query, output1.cpu()), 0) 362 | labels_query = np.append(labels_query, label.cpu().numpy()) 363 | # gallery 364 | dataset_gallery = MSBaseDataSet(conf, './datasets/Inshop/inshop_gallery.txt', 365 | transform=conf.transform_dict['center-crop'], mode='RGB') 366 | loader_gallery = DataLoader(dataset_gallery, batch_size=conf.batch_size, num_workers=conf.num_workers, 367 | shuffle=False, pin_memory=True, drop_last=False) 368 | loader_gallery_enum = enumerate(loader_gallery) 369 | feas_gallery = torch.tensor([]) 370 | labels_gallery = np.array([]) 371 | with torch.no_grad(): 372 | while True: 373 | try: 374 | ind_data, data = loader_gallery_enum.__next__() 375 | except StopIteration as e: 376 | break 377 | 378 | imgs = data['image'] 379 | label = data['label'] 380 | 381 | output1 = self.model(imgs, normalized=False) 382 | norm = output1.norm(dim=1, p=2, keepdim=True) 383 | output1 = output1.div(norm.expand_as(output1)) 384 | feas_gallery = torch.cat((feas_gallery, output1.cpu()), 0) 385 | labels_gallery = np.append(labels_gallery, 
label.cpu().numpy()) 386 | # test 387 | sim_mat = pairwise_similarity(feas_query, feas_gallery) 388 | recall_ks = Recall_at_ks(sim_mat, data_name=conf.use_dataset, query_ids=labels_query, gallery_ids=labels_gallery) 389 | 390 | self.model.train() 391 | logging.info('test end') 392 | 393 | return nmi, f1, recall_ks 394 | 395 | def test_sop_complete(self, conf): 396 | assert conf.use_dataset == 'SOP' 397 | 398 | logging.info('start complete sop test') 399 | self.model.eval() 400 | 401 | dataset = MSBaseDataSet(conf, './datasets/SOP/sop_test.txt', 402 | transform=conf.transform_dict['center-crop'], mode='RGB') 403 | loader = DataLoader(dataset, batch_size=conf.batch_size, num_workers=conf.num_workers, 404 | shuffle=False, pin_memory=True, drop_last=False) 405 | 406 | loader_enum = enumerate(loader) 407 | feas = torch.tensor([]) 408 | labels = np.array([]) 409 | with torch.no_grad(): 410 | while True: 411 | try: 412 | ind_data, data = loader_enum.__next__() 413 | except StopIteration as e: 414 | break 415 | 416 | imgs = data['image'] 417 | label = data['label'] 418 | 419 | output1 = self.model(imgs, normalized=False) 420 | norm = output1.norm(dim=1, p=2, keepdim=True) 421 | output1 = output1.div(norm.expand_as(output1)) 422 | feas = torch.cat((feas, output1.cpu()), 0) 423 | labels = np.append(labels, label.cpu().numpy()) 424 | 425 | pids = np.unique(labels) 426 | nmi, f1 = NMI_F1(feas, labels, n_cluster=len(pids)) 427 | 428 | print(f'nmi: {nmi}, f1: {f1}') 429 | 430 | sim_mat = pairwise_similarity(feas) 431 | sim_mat = sim_mat - torch.eye(sim_mat.size(0)) 432 | recall_ks = Recall_at_ks(sim_mat, data_name=conf.use_dataset, gallery_ids=labels) 433 | 434 | self.model.train() 435 | logging.info('test end') 436 | 437 | return nmi, f1, recall_ks 438 | 439 | def load_bninception_pretrained(self, conf): 440 | model_dict = self.model.state_dict() 441 | my_dict = {'module.'+k: v for k, v in torch.load(conf.bninception_pretrained_model_path).items() if 'module.'+k in model_dict.keys()} 442 | print('################################## do not have pretrained:') 443 | for k in model_dict: 444 | if k not in my_dict.keys(): 445 | print(k) 446 | print('##################################') 447 | model_dict.update(my_dict) 448 | self.model.load_state_dict(model_dict) 449 | 450 | def schedule_lr(self, conf): 451 | for params in self.optimizer.param_groups: 452 | params['lr'] = params['lr'] * conf.lr_gamma 453 | logging.info(f'{self.optimizer}') 454 | 455 | def save_state(self, conf, to_save_folder=False, model_only=False): 456 | if to_save_folder: 457 | save_path = conf.save_path 458 | else: 459 | save_path = conf.model_path 460 | 461 | myutils.mkdir_p(save_path, delete=False) 462 | 463 | torch.save( 464 | self.model.state_dict(), 465 | save_path / 466 | ('model_{}_step:{}.pth'.format(get_time(), self.step))) 467 | if not model_only: 468 | torch.save( 469 | self.optimizer.state_dict(), 470 | save_path / 471 | ('optimizer_{}_step:{}.pth'.format(get_time(), self.step))) 472 | 473 | def load_state(self, conf, resume_path, fixed_str=None, load_optimizer=False): 474 | from pathlib import Path 475 | 476 | save_path = Path(resume_path) 477 | modelp = save_path / 'model_{}'.format(fixed_str) 478 | if not os.path.exists(modelp): 479 | fixed_strs = [t.name for t in save_path.glob('model*_*.pth')] 480 | step = [fixed_str.split('_')[-1].split(':')[-1].split('.')[-2] for fixed_str in fixed_strs] 481 | step = np.asarray(step, dtype=int) 482 | step_ind = step.argmax() 483 | fixed_str = fixed_strs[step_ind].replace('model_', '') 
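# fallback: the glob/argmax above picks the checkpoint with the largest training
# step when the requested model_{fixed_str} file is not found under resume_path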
484 | modelp = save_path / 'model_{}'.format(fixed_str) 485 | 486 | print(fixed_str) 487 | 488 | model_dict = self.model.state_dict() 489 | my_dict = {k: v for k, v in torch.load(modelp).items() if k in model_dict.keys()} 490 | print('################################## do not have pretrained:') 491 | for k in model_dict: 492 | if k not in my_dict.keys(): 493 | print(k) 494 | print('##################################') 495 | model_dict.update(my_dict) 496 | self.model.load_state_dict(model_dict) 497 | 498 | if load_optimizer: 499 | self.optimizer.load_state_dict(torch.load(save_path / 'optimizer_{}'.format(fixed_str))) 500 | print(self.optimizer) 501 | 502 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import myutils 2 | from torch.nn import Module, Parameter 3 | import torch.nn.functional as F 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | class TripletLoss(Module): 9 | def __init__(self, instance, margin=1.0): 10 | super(TripletLoss, self).__init__() 11 | self.margin = margin 12 | self.instance = instance 13 | 14 | def forward(self, inputs, targets, normalized=True): 15 | norm_temp = inputs.norm(dim=1, p=2, keepdim=True) 16 | if normalized: 17 | inputs = inputs.div(norm_temp.expand_as(inputs)) 18 | 19 | nB = inputs.size(0) 20 | idx_ = torch.arange(0, nB, dtype=torch.long) 21 | 22 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(nB, nB) 23 | dist = dist + dist.t() 24 | # use squared 25 | dist.addmm_(1, -2, inputs, inputs.t()).clamp_(min=1e-12) 26 | 27 | adjacency = targets.expand(nB, nB).eq(targets.expand(nB, nB).t()) 28 | adjacency_not = ~adjacency 29 | mask_ap = (adjacency.float() - torch.eye(nB).cuda()).long() 30 | mask_an = adjacency_not.long() 31 | 32 | dist_ap = (dist[mask_ap == 1]).view(-1, 1) 33 | dist_an = (dist[mask_an == 1]).view(nB, -1) 34 | dist_an = dist_an.repeat(1, self.instance - 1) 35 | dist_an = dist_an.view(nB * (self.instance - 1), nB - self.instance) 36 | num_loss = dist_an.size(0) * dist_an.size(1) 37 | 38 | triplet_loss = torch.sum( 39 | torch.max(torch.tensor(0, dtype=torch.float).cuda(), self.margin + dist_ap - dist_an)) / num_loss 40 | final_loss = triplet_loss * 1.0 41 | 42 | with torch.no_grad(): 43 | assert normalized == True 44 | cos_theta = torch.mm(inputs, inputs.t()) 45 | mask = targets.expand(nB, nB).eq(targets.expand(nB, nB).t()) 46 | avg_ap = cos_theta[(mask.float() - torch.eye(nB).cuda()) == 1].mean() 47 | avg_an = cos_theta[mask.float() == 0].mean() 48 | 49 | return final_loss, avg_ap, avg_an 50 | 51 | class TripletSemihardLoss(Module): 52 | def __init__(self, margin=0.2): 53 | super(TripletSemihardLoss, self).__init__() 54 | self.margin = margin 55 | 56 | def forward(self, inputs, targets, normalized=True): 57 | norm_temp = inputs.norm(dim=1, p=2, keepdim=True) 58 | if normalized: 59 | inputs = inputs.div(norm_temp.expand_as(inputs)) 60 | 61 | nB = inputs.size(0) 62 | idx_ = torch.arange(0, nB, dtype=torch.long) 63 | 64 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(nB, nB) 65 | dist = dist + dist.t() 66 | # use squared 67 | dist.addmm_(1, -2, inputs, inputs.t()).clamp_(min=1e-12) 68 | 69 | temp_euclidean_score = dist * 1.0 70 | 71 | adjacency = targets.expand(nB, nB).eq(targets.expand(nB, nB).t()) 72 | adjacency_not = ~ adjacency 73 | 74 | dist_tile = dist.repeat(nB, 1) 75 | mask = (adjacency_not.repeat(nB, 1)) * (dist_tile > (dist.transpose(0, 
1).contiguous().view(-1, 1))) 76 | mask_final = (mask.float().sum(dim=1, keepdim=True) > 0).view(nB, nB).transpose(0, 1) 77 | 78 | # negatives_outside: smallest D_an where D_an > D_ap 79 | temp1 = (dist_tile - dist_tile.max(dim=1, keepdim=True)[0]) * (mask.float()) 80 | negtives_outside = temp1.min(dim=1, keepdim=True)[0] + dist_tile.max(dim=1, keepdim=True)[0] 81 | negtives_outside = negtives_outside.view(nB, nB).transpose(0, 1) 82 | 83 | # negatives_inside: largest D_an 84 | temp2 = (dist - dist.min(dim=1, keepdim=True)[0]) * (adjacency_not.float()) 85 | negtives_inside = temp2.max(dim=1, keepdim=True)[0] + dist.min(dim=1, keepdim=True)[0] 86 | negtives_inside = negtives_inside.repeat(1, nB) 87 | 88 | semi_hard_negtives = torch.where(mask_final, negtives_outside, negtives_inside) 89 | 90 | loss_mat = self.margin + dist - semi_hard_negtives 91 | 92 | mask_positives = adjacency.float() - torch.eye(nB).cuda() 93 | mask_positives = mask_positives.detach() 94 | num_positives = torch.sum(mask_positives) 95 | 96 | triplet_loss = torch.sum( 97 | torch.max(torch.tensor(0, dtype=torch.float).cuda(), loss_mat * mask_positives)) / num_positives 98 | final_loss = triplet_loss * 1.0 99 | 100 | with torch.no_grad(): 101 | assert normalized == True 102 | cos_theta = torch.mm(inputs, inputs.t()) 103 | mask = targets.expand(nB, nB).eq(targets.expand(nB, nB).t()) 104 | avg_ap = cos_theta[(mask.float() - torch.eye(nB).cuda()) == 1].mean() 105 | avg_an = cos_theta[mask.float() == 0].mean() 106 | 107 | return final_loss, avg_ap, avg_an 108 | 109 | def cross_entropy(logits, target, size_average=True): 110 | if size_average: 111 | return torch.mean(torch.sum(- target * F.log_softmax(logits, -1), -1)) 112 | else: 113 | return torch.sum(torch.sum(- target * F.log_softmax(logits, -1), -1)) 114 | 115 | class NpairLoss(Module): 116 | def __init__(self): 117 | super(NpairLoss, self).__init__() 118 | 119 | def forward(self, inputs, targets, normalized=False): 120 | nB = inputs.size(0) 121 | 122 | norm_temp = inputs.norm(p=2, dim=1, keepdim=True) 123 | 124 | inputs_n = inputs.div(norm_temp.expand_as(inputs)) 125 | mm_logits = torch.mm(inputs_n, inputs_n.t()).detach() 126 | mask = targets.expand(nB, nB).eq(targets.expand(nB, nB).t()) 127 | 128 | cos_ap = mm_logits[(mask.float() - torch.eye(nB).float().cuda()) == 1].view(nB, -1) 129 | cos_an = mm_logits[mask != 1].view(nB, -1) 130 | 131 | avg_ap = torch.mean(cos_ap) 132 | avg_an = torch.mean(cos_an) 133 | 134 | if normalized: 135 | inputs = inputs.div(norm_temp.expand_as(inputs)) 136 | inputs = inputs * 5.0 137 | 138 | labels = targets.view(-1).cpu().numpy() 139 | pids = np.unique(labels) 140 | 141 | anchor_idx = [] 142 | positive_idx = [] 143 | for i in pids: 144 | ap_idx = np.where(labels == i)[0] 145 | anchor_idx.append(ap_idx[0]) 146 | positive_idx.append(ap_idx[1]) 147 | 148 | anchor = inputs[anchor_idx, :] 149 | positive = inputs[positive_idx, :] 150 | 151 | batch_size = anchor.size(0) 152 | 153 | target = torch.from_numpy(pids).cuda() 154 | target = target.view(target.size(0), 1) 155 | 156 | target = (target == torch.transpose(target, 0, 1)).float() 157 | target = target / torch.sum(target, dim=1, keepdim=True).float() 158 | 159 | logit = torch.matmul(anchor, torch.transpose(positive, 0, 1)) 160 | 161 | loss_ce = cross_entropy(logit, target) 162 | loss = loss_ce * 1.0 163 | 164 | return loss, avg_ap, avg_an 165 | 166 | class MultiSimilarityLoss(Module): 167 | def __init__(self): 168 | super(MultiSimilarityLoss, self).__init__() 169 | self.thresh = 0.5 170 | 
self.margin = 0.1 171 | self.scale_pos = 2.0 172 | self.scale_neg = 40.0 173 | 174 | def forward(self, feats, labels): 175 | 176 | norm = feats.norm(dim=1, p=2, keepdim=True) 177 | feats = feats.div(norm.expand_as(feats)) 178 | 179 | labels = labels.view(-1) 180 | assert feats.size(0) == labels.size(0), \ 181 | f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}" 182 | batch_size = feats.size(0) 183 | sim_mat = torch.matmul(feats, torch.t(feats)) 184 | 185 | epsilon = 1e-5 186 | loss = list() 187 | 188 | avg_aps = list() 189 | avg_ans = list() 190 | 191 | for i in range(batch_size): 192 | pos_pair_ = sim_mat[i][labels == labels[i]] 193 | pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon] 194 | neg_pair_ = sim_mat[i][labels != labels[i]] 195 | 196 | if len(neg_pair_) < 1 or len(pos_pair_) < 1: 197 | continue 198 | 199 | avg_aps.append(pos_pair_.mean()) 200 | avg_ans.append(neg_pair_.mean()) 201 | 202 | neg_pair = neg_pair_[neg_pair_ + self.margin > torch.min(pos_pair_)] 203 | pos_pair = pos_pair_[pos_pair_ - self.margin < torch.max(neg_pair_)] 204 | 205 | if len(neg_pair) < 1 or len(pos_pair) < 1: 206 | continue 207 | 208 | # weighting step 209 | pos_loss = 1.0 / self.scale_pos * torch.log( 210 | 1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh)))) 211 | neg_loss = 1.0 / self.scale_neg * torch.log( 212 | 1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh)))) 213 | loss.append(pos_loss + neg_loss) 214 | 215 | if len(loss) == 0: 216 | print('with ms loss = 0 !') 217 | loss = torch.zeros([], requires_grad=True).cuda() 218 | else: 219 | loss = sum(loss) / batch_size 220 | loss = loss.view(-1) 221 | 222 | avg_ap = sum(avg_aps) / batch_size 223 | avg_an = sum(avg_ans) / batch_size 224 | 225 | return loss, avg_ap, avg_an 226 | 227 | -------------------------------------------------------------------------------- /models/bninception.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.parameter import Parameter 6 | 7 | class Flatten(nn.Module): 8 | def forward(self, input): 9 | return input.view(input.size(0), -1) 10 | 11 | 12 | class BNInception(nn.Module): 13 | 14 | def __init__(self, need_bn = True): 15 | super(BNInception, self).__init__() 16 | inplace = True 17 | self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)) 18 | self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, affine=True) 19 | self.conv1_relu_7x7 = nn.ReLU(inplace) 20 | self.pool1_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 21 | self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) 22 | self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 23 | self.conv2_relu_3x3_reduce = nn.ReLU(inplace) 24 | self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 25 | self.conv2_3x3_bn = nn.BatchNorm2d(192, affine=True) 26 | self.conv2_relu_3x3 = nn.ReLU(inplace) 27 | self.pool2_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 28 | self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 29 | self.inception_3a_1x1_bn = nn.BatchNorm2d(64, affine=True) 30 | self.inception_3a_relu_1x1 = nn.ReLU(inplace) 31 | self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 32 | self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 
33 | self.inception_3a_relu_3x3_reduce = nn.ReLU(inplace) 34 | self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 35 | self.inception_3a_3x3_bn = nn.BatchNorm2d(64, affine=True) 36 | self.inception_3a_relu_3x3 = nn.ReLU(inplace) 37 | self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 38 | self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 39 | self.inception_3a_relu_double_3x3_reduce = nn.ReLU(inplace) 40 | self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 41 | self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) 42 | self.inception_3a_relu_double_3x3_1 = nn.ReLU(inplace) 43 | self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 44 | self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) 45 | self.inception_3a_relu_double_3x3_2 = nn.ReLU(inplace) 46 | self.inception_3a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 47 | self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) 48 | self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, affine=True) 49 | self.inception_3a_relu_pool_proj = nn.ReLU(inplace) 50 | self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 51 | self.inception_3b_1x1_bn = nn.BatchNorm2d(64, affine=True) 52 | self.inception_3b_relu_1x1 = nn.ReLU(inplace) 53 | self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 54 | self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 55 | self.inception_3b_relu_3x3_reduce = nn.ReLU(inplace) 56 | self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 57 | self.inception_3b_3x3_bn = nn.BatchNorm2d(96, affine=True) 58 | self.inception_3b_relu_3x3 = nn.ReLU(inplace) 59 | self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 60 | self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 61 | self.inception_3b_relu_double_3x3_reduce = nn.ReLU(inplace) 62 | self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 63 | self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) 64 | self.inception_3b_relu_double_3x3_1 = nn.ReLU(inplace) 65 | self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 66 | self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) 67 | self.inception_3b_relu_double_3x3_2 = nn.ReLU(inplace) 68 | self.inception_3b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 69 | self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 70 | self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, affine=True) 71 | self.inception_3b_relu_pool_proj = nn.ReLU(inplace) 72 | self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1)) 73 | self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 74 | self.inception_3c_relu_3x3_reduce = nn.ReLU(inplace) 75 | self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 76 | self.inception_3c_3x3_bn = nn.BatchNorm2d(160, affine=True) 77 | self.inception_3c_relu_3x3 = nn.ReLU(inplace) 78 | self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1)) 79 | 
self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 80 | self.inception_3c_relu_double_3x3_reduce = nn.ReLU(inplace) 81 | self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 82 | self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) 83 | self.inception_3c_relu_double_3x3_1 = nn.ReLU(inplace) 84 | self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 85 | self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) 86 | self.inception_3c_relu_double_3x3_2 = nn.ReLU(inplace) 87 | self.inception_3c_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 88 | self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1)) 89 | self.inception_4a_1x1_bn = nn.BatchNorm2d(224, affine=True) 90 | self.inception_4a_relu_1x1 = nn.ReLU(inplace) 91 | self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1)) 92 | self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 93 | self.inception_4a_relu_3x3_reduce = nn.ReLU(inplace) 94 | self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 95 | self.inception_4a_3x3_bn = nn.BatchNorm2d(96, affine=True) 96 | self.inception_4a_relu_3x3 = nn.ReLU(inplace) 97 | self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 98 | self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) 99 | self.inception_4a_relu_double_3x3_reduce = nn.ReLU(inplace) 100 | self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 101 | self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True) 102 | self.inception_4a_relu_double_3x3_1 = nn.ReLU(inplace) 103 | self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 104 | self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True) 105 | self.inception_4a_relu_double_3x3_2 = nn.ReLU(inplace) 106 | self.inception_4a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 107 | self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 108 | self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 109 | self.inception_4a_relu_pool_proj = nn.ReLU(inplace) 110 | self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1)) 111 | self.inception_4b_1x1_bn = nn.BatchNorm2d(192, affine=True) 112 | self.inception_4b_relu_1x1 = nn.ReLU(inplace) 113 | self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 114 | self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) 115 | self.inception_4b_relu_3x3_reduce = nn.ReLU(inplace) 116 | self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 117 | self.inception_4b_3x3_bn = nn.BatchNorm2d(128, affine=True) 118 | self.inception_4b_relu_3x3 = nn.ReLU(inplace) 119 | self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 120 | self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) 121 | self.inception_4b_relu_double_3x3_reduce = nn.ReLU(inplace) 122 | self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 123 | self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True) 124 | self.inception_4b_relu_double_3x3_1 = nn.ReLU(inplace) 125 | 
self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 126 | self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True) 127 | self.inception_4b_relu_double_3x3_2 = nn.ReLU(inplace) 128 | self.inception_4b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 129 | self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 130 | self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 131 | self.inception_4b_relu_pool_proj = nn.ReLU(inplace) 132 | self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1)) 133 | self.inception_4c_1x1_bn = nn.BatchNorm2d(160, affine=True) 134 | self.inception_4c_relu_1x1 = nn.ReLU(inplace) 135 | self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 136 | self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 137 | self.inception_4c_relu_3x3_reduce = nn.ReLU(inplace) 138 | self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 139 | self.inception_4c_3x3_bn = nn.BatchNorm2d(160, affine=True) 140 | self.inception_4c_relu_3x3 = nn.ReLU(inplace) 141 | self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 142 | self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 143 | self.inception_4c_relu_double_3x3_reduce = nn.ReLU(inplace) 144 | self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 145 | self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, affine=True) 146 | self.inception_4c_relu_double_3x3_1 = nn.ReLU(inplace) 147 | self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 148 | self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, affine=True) 149 | self.inception_4c_relu_double_3x3_2 = nn.ReLU(inplace) 150 | self.inception_4c_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 151 | self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 152 | self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 153 | self.inception_4c_relu_pool_proj = nn.ReLU(inplace) 154 | self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1)) 155 | self.inception_4d_1x1_bn = nn.BatchNorm2d(96, affine=True) 156 | self.inception_4d_relu_1x1 = nn.ReLU(inplace) 157 | self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 158 | self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 159 | self.inception_4d_relu_3x3_reduce = nn.ReLU(inplace) 160 | self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 161 | self.inception_4d_3x3_bn = nn.BatchNorm2d(192, affine=True) 162 | self.inception_4d_relu_3x3 = nn.ReLU(inplace) 163 | self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1)) 164 | self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True) 165 | self.inception_4d_relu_double_3x3_reduce = nn.ReLU(inplace) 166 | self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 167 | self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, affine=True) 168 | self.inception_4d_relu_double_3x3_1 = nn.ReLU(inplace) 169 | self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 170 | 
self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, affine=True) 171 | self.inception_4d_relu_double_3x3_2 = nn.ReLU(inplace) 172 | self.inception_4d_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 173 | self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 174 | self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 175 | self.inception_4d_relu_pool_proj = nn.ReLU(inplace) 176 | self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 177 | self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 178 | self.inception_4e_relu_3x3_reduce = nn.ReLU(inplace) 179 | self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 180 | self.inception_4e_3x3_bn = nn.BatchNorm2d(192, affine=True) 181 | self.inception_4e_relu_3x3 = nn.ReLU(inplace) 182 | self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1)) 183 | self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) 184 | self.inception_4e_relu_double_3x3_reduce = nn.ReLU(inplace) 185 | self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 186 | self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, affine=True) 187 | self.inception_4e_relu_double_3x3_1 = nn.ReLU(inplace) 188 | self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 189 | self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, affine=True) 190 | self.inception_4e_relu_double_3x3_2 = nn.ReLU(inplace) 191 | self.inception_4e_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 192 | self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1)) 193 | self.inception_5a_1x1_bn = nn.BatchNorm2d(352, affine=True) 194 | self.inception_5a_relu_1x1 = nn.ReLU(inplace) 195 | self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1)) 196 | self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) 197 | self.inception_5a_relu_3x3_reduce = nn.ReLU(inplace) 198 | self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 199 | self.inception_5a_3x3_bn = nn.BatchNorm2d(320, affine=True) 200 | self.inception_5a_relu_3x3 = nn.ReLU(inplace) 201 | self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1)) 202 | self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True) 203 | self.inception_5a_relu_double_3x3_reduce = nn.ReLU(inplace) 204 | self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 205 | self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True) 206 | self.inception_5a_relu_double_3x3_1 = nn.ReLU(inplace) 207 | self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 208 | self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True) 209 | self.inception_5a_relu_double_3x3_2 = nn.ReLU(inplace) 210 | self.inception_5a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 211 | self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1)) 212 | self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 213 | self.inception_5a_relu_pool_proj = nn.ReLU(inplace) 214 | self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1)) 215 | 
self.inception_5b_1x1_bn = nn.BatchNorm2d(352, affine=True) 216 | self.inception_5b_relu_1x1 = nn.ReLU(inplace) 217 | self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) 218 | self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) 219 | self.inception_5b_relu_3x3_reduce = nn.ReLU(inplace) 220 | self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 221 | self.inception_5b_3x3_bn = nn.BatchNorm2d(320, affine=True) 222 | self.inception_5b_relu_3x3 = nn.ReLU(inplace) 223 | self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) 224 | self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) 225 | self.inception_5b_relu_double_3x3_reduce = nn.ReLU(inplace) 226 | self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 227 | self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True) 228 | self.inception_5b_relu_double_3x3_1 = nn.ReLU(inplace) 229 | self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 230 | self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True) 231 | self.inception_5b_relu_double_3x3_2 = nn.ReLU(inplace) 232 | self.inception_5b_pool = nn.MaxPool2d((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True) 233 | self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1)) 234 | self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 235 | self.inception_5b_relu_pool_proj = nn.ReLU(inplace) 236 | 237 | if need_bn: 238 | self.head = nn.Sequential(OrderedDict([ 239 | ('avgpool', nn.AdaptiveAvgPool2d(1)), 240 | ('bn', nn.BatchNorm2d(1024, eps=1e-5)), 241 | ('flatten', Flatten()), 242 | ('fc', nn.Linear(in_features=1024, out_features=512)), 243 | ])) 244 | else: 245 | self.head = nn.Sequential(OrderedDict([ 246 | ('avgpool', nn.AdaptiveAvgPool2d(1)), 247 | ('flatten', Flatten()), 248 | ('fc', nn.Linear(in_features=1024, out_features=512)), 249 | ])) 250 | 251 | def features(self, input): 252 | conv1_7x7_s2_out = self.conv1_7x7_s2(input) 253 | conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out) 254 | conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out) 255 | pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_relu_7x7_out) 256 | conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out) 257 | conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out) 258 | conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out) 259 | conv2_3x3_out = self.conv2_3x3(conv2_relu_3x3_reduce_out) 260 | conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out) 261 | conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out) 262 | pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_relu_3x3_out) 263 | 264 | inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out) 265 | inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out) 266 | inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out) 267 | inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out) 268 | inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out) 269 | inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out) 270 | inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_relu_3x3_reduce_out) 271 | inception_3a_3x3_bn_out = 
self.inception_3a_3x3_bn(inception_3a_3x3_out) 272 | inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out) 273 | inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out) 274 | inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn( 275 | inception_3a_double_3x3_reduce_out) 276 | inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce( 277 | inception_3a_double_3x3_reduce_bn_out) 278 | inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_relu_double_3x3_reduce_out) 279 | inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out) 280 | inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out) 281 | inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_relu_double_3x3_1_out) 282 | inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out) 283 | inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out) 284 | inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out) 285 | inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out) 286 | inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out) 287 | inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out) 288 | inception_3a_output_out = torch.cat( 289 | [inception_3a_relu_1x1_out, inception_3a_relu_3x3_out, inception_3a_relu_double_3x3_2_out, 290 | inception_3a_relu_pool_proj_out], 1) 291 | 292 | inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out) 293 | inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out) 294 | inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out) 295 | inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out) 296 | inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out) 297 | inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out) 298 | inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_relu_3x3_reduce_out) 299 | inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out) 300 | inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out) 301 | inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out) 302 | inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn( 303 | inception_3b_double_3x3_reduce_out) 304 | inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce( 305 | inception_3b_double_3x3_reduce_bn_out) 306 | inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_relu_double_3x3_reduce_out) 307 | inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out) 308 | inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out) 309 | inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_relu_double_3x3_1_out) 310 | inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out) 311 | inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out) 312 | inception_3b_pool_out = 
self.inception_3b_pool(inception_3a_output_out) 313 | inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out) 314 | inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out) 315 | inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out) 316 | inception_3b_output_out = torch.cat( 317 | [inception_3b_relu_1x1_out, inception_3b_relu_3x3_out, inception_3b_relu_double_3x3_2_out, 318 | inception_3b_relu_pool_proj_out], 1) 319 | 320 | inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out) 321 | inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out) 322 | inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out) 323 | inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_relu_3x3_reduce_out) 324 | inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out) 325 | inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out) 326 | inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out) 327 | inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn( 328 | inception_3c_double_3x3_reduce_out) 329 | inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce( 330 | inception_3c_double_3x3_reduce_bn_out) 331 | inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_relu_double_3x3_reduce_out) 332 | inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out) 333 | inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out) 334 | inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_relu_double_3x3_1_out) 335 | inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out) 336 | inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out) 337 | inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out) 338 | inception_3c_output_out = torch.cat( 339 | [inception_3c_relu_3x3_out, inception_3c_relu_double_3x3_2_out, inception_3c_pool_out], 1) 340 | 341 | inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out) 342 | inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out) 343 | inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out) 344 | inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out) 345 | inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out) 346 | inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out) 347 | inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_relu_3x3_reduce_out) 348 | inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out) 349 | inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out) 350 | inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out) 351 | inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn( 352 | inception_4a_double_3x3_reduce_out) 353 | inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce( 354 | inception_4a_double_3x3_reduce_bn_out) 355 | inception_4a_double_3x3_1_out = 
self.inception_4a_double_3x3_1(inception_4a_relu_double_3x3_reduce_out) 356 | inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out) 357 | inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out) 358 | inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_relu_double_3x3_1_out) 359 | inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out) 360 | inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out) 361 | inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out) 362 | inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out) 363 | inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out) 364 | inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out) 365 | inception_4a_output_out = torch.cat( 366 | [inception_4a_relu_1x1_out, inception_4a_relu_3x3_out, inception_4a_relu_double_3x3_2_out, 367 | inception_4a_relu_pool_proj_out], 1) 368 | 369 | inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out) 370 | inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out) 371 | inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out) 372 | inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out) 373 | inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out) 374 | inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out) 375 | inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_relu_3x3_reduce_out) 376 | inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out) 377 | inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out) 378 | inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out) 379 | inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn( 380 | inception_4b_double_3x3_reduce_out) 381 | inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce( 382 | inception_4b_double_3x3_reduce_bn_out) 383 | inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_relu_double_3x3_reduce_out) 384 | inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out) 385 | inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out) 386 | inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_relu_double_3x3_1_out) 387 | inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out) 388 | inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out) 389 | inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out) 390 | inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out) 391 | inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out) 392 | inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out) 393 | inception_4b_output_out = torch.cat( 394 | [inception_4b_relu_1x1_out, inception_4b_relu_3x3_out, inception_4b_relu_double_3x3_2_out, 395 | inception_4b_relu_pool_proj_out], 1) 396 | 397 | 
inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out) 398 | inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out) 399 | inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out) 400 | inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out) 401 | inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out) 402 | inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out) 403 | inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_relu_3x3_reduce_out) 404 | inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out) 405 | inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out) 406 | inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out) 407 | inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn( 408 | inception_4c_double_3x3_reduce_out) 409 | inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce( 410 | inception_4c_double_3x3_reduce_bn_out) 411 | inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_relu_double_3x3_reduce_out) 412 | inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out) 413 | inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out) 414 | inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_relu_double_3x3_1_out) 415 | inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out) 416 | inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out) 417 | inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out) 418 | inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out) 419 | inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out) 420 | inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out) 421 | inception_4c_output_out = torch.cat( 422 | [inception_4c_relu_1x1_out, inception_4c_relu_3x3_out, inception_4c_relu_double_3x3_2_out, 423 | inception_4c_relu_pool_proj_out], 1) 424 | 425 | inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out) 426 | inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out) 427 | inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out) 428 | inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out) 429 | inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out) 430 | inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out) 431 | inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_relu_3x3_reduce_out) 432 | inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out) 433 | inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out) 434 | inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out) 435 | inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn( 436 | inception_4d_double_3x3_reduce_out) 437 | inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce( 438 | inception_4d_double_3x3_reduce_bn_out) 439 | inception_4d_double_3x3_1_out = 
self.inception_4d_double_3x3_1(inception_4d_relu_double_3x3_reduce_out) 440 | inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out) 441 | inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out) 442 | inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_relu_double_3x3_1_out) 443 | inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out) 444 | inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out) 445 | inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out) 446 | inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out) 447 | inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out) 448 | inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out) 449 | inception_4d_output_out = torch.cat( 450 | [inception_4d_relu_1x1_out, inception_4d_relu_3x3_out, inception_4d_relu_double_3x3_2_out, 451 | inception_4d_relu_pool_proj_out], 1) 452 | 453 | inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out) 454 | inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out) 455 | inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out) 456 | inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_relu_3x3_reduce_out) 457 | inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out) 458 | inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out) 459 | inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out) 460 | inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn( 461 | inception_4e_double_3x3_reduce_out) 462 | inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce( 463 | inception_4e_double_3x3_reduce_bn_out) 464 | inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_relu_double_3x3_reduce_out) 465 | inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out) 466 | inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out) 467 | inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_relu_double_3x3_1_out) 468 | inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out) 469 | inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out) 470 | inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out) 471 | inception_4e_output_out = torch.cat( 472 | [inception_4e_relu_3x3_out, inception_4e_relu_double_3x3_2_out, inception_4e_pool_out], 1) 473 | 474 | inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out) 475 | inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out) 476 | inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out) 477 | inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out) 478 | inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out) 479 | inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out) 480 | inception_5a_3x3_out = 
self.inception_5a_3x3(inception_5a_relu_3x3_reduce_out) 481 | inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out) 482 | inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out) 483 | inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out) 484 | inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn( 485 | inception_5a_double_3x3_reduce_out) 486 | inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce( 487 | inception_5a_double_3x3_reduce_bn_out) 488 | inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_relu_double_3x3_reduce_out) 489 | inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out) 490 | inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out) 491 | inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_relu_double_3x3_1_out) 492 | inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out) 493 | inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out) 494 | inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out) 495 | inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out) 496 | inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out) 497 | inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out) 498 | inception_5a_output_out = torch.cat( 499 | [inception_5a_relu_1x1_out, inception_5a_relu_3x3_out, inception_5a_relu_double_3x3_2_out, 500 | inception_5a_relu_pool_proj_out], 1) 501 | 502 | inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out) 503 | inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out) 504 | inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out) 505 | inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out) 506 | inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out) 507 | inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out) 508 | inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_relu_3x3_reduce_out) 509 | inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out) 510 | inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out) 511 | inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out) 512 | inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn( 513 | inception_5b_double_3x3_reduce_out) 514 | inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce( 515 | inception_5b_double_3x3_reduce_bn_out) 516 | inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_relu_double_3x3_reduce_out) 517 | inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out) 518 | inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out) 519 | inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_relu_double_3x3_1_out) 520 | inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out) 521 | inception_5b_relu_double_3x3_2_out = 
self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out) 522 | inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out) 523 | inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out) 524 | inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out) 525 | inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out) 526 | inception_5b_output_out = torch.cat( 527 | [inception_5b_relu_1x1_out, inception_5b_relu_3x3_out, inception_5b_relu_double_3x3_2_out, 528 | inception_5b_relu_pool_proj_out], 1) 529 | 530 | return inception_5b_output_out 531 | 532 | def forward(self, input, normalized=False): 533 | bbout = self.features(input) 534 | x = self.head(bbout) 535 | 536 | if normalized: 537 | norm = x.norm(dim=1, p=2, keepdim=True) 538 | x = x.div(norm.expand_as(x)) 539 | 540 | return x 541 | -------------------------------------------------------------------------------- /mytrain.py: -------------------------------------------------------------------------------- 1 | import myutils 2 | from config import get_config 3 | from learner import metric_learner 4 | import argparse 5 | from pathlib import Path 6 | import numpy as np 7 | import torch 8 | 9 | if __name__ == '__main__': 10 | 11 | conf = get_config() 12 | 13 | learner = metric_learner(conf) 14 | 15 | learner.load_bninception_pretrained(conf) 16 | 17 | learner.train(conf) 18 | 19 | -------------------------------------------------------------------------------- /mytrain.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | export PYTHONPATH=. 3 | 4 | python mytrain.py \ 5 | --use_dataset CUB \ 6 | --instances 3 \ 7 | --use_loss triplet \ 8 | --lr 0.5e-5 \ 9 | --lr_p 0.25e-5 \ 10 | --lr_gamma 0.1 \ 11 | --sec_wei 1.0 12 | -------------------------------------------------------------------------------- /myutils.py: -------------------------------------------------------------------------------- 1 | import os, sys, \ 2 | subprocess, glob, re, \ 3 | numpy as np, \ 4 | logging, \ 5 | collections, copy, \ 6 | datetime 7 | from os import path as osp 8 | import time 9 | 10 | root_path = osp.normpath(osp.join(osp.abspath(osp.dirname(__file__)), )) + '/' 11 | sys.path.insert(0, root_path) 12 | 13 | def set_stream_logger(log_level=logging.DEBUG): 14 | import colorlog 15 | sh = colorlog.StreamHandler() 16 | sh.setLevel(log_level) 17 | sh.setFormatter( 18 | colorlog.ColoredFormatter( 19 | ' %(asctime)s %(filename)s [line:%(lineno)d] %(log_color)s%(levelname)s%(reset)s %(message)s')) 20 | logging.root.addHandler(sh) 21 | 22 | def set_file_logger(work_dir=None, log_level=logging.DEBUG): 23 | work_dir = work_dir or root_path 24 | fh = logging.FileHandler(os.path.join(work_dir, 'log-ing')) 25 | fh.setLevel(log_level) 26 | fh.setFormatter( 27 | logging.Formatter('%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s')) 28 | logging.root.addHandler(fh) 29 | 30 | logging.root.setLevel(logging.INFO) 31 | set_stream_logger(logging.DEBUG) 32 | 33 | def shell(cmd, block=True, return_msg=True, verbose=True, timeout=None): 34 | import os 35 | my_env = os.environ.copy() 36 | home = os.path.expanduser('~') 37 | my_env['http_proxy'] = '' 38 | my_env['https_proxy'] = '' 39 | if verbose: 40 | logging.info('cmd is ' + cmd) 41 | if block: 42 | 43 | task = subprocess.Popen(cmd, 44 | shell=True, 45 | stdout=subprocess.PIPE, 46 | stderr=subprocess.PIPE, 47 | env=my_env, 48 | 
preexec_fn=os.setsid 49 | ) 50 | if return_msg: 51 | msg = task.communicate(timeout) 52 | msg = [msg_.decode('utf-8') for msg_ in msg] 53 | if msg[0] != '' and verbose: 54 | logging.info('stdout {}'.format(msg[0])) 55 | if msg[1] != '' and verbose: 56 | logging.error('stderr {}'.format(msg[1])) 57 | return msg 58 | else: 59 | return task 60 | else: 61 | logging.debug('Non-block!') 62 | task = subprocess.Popen(cmd, 63 | shell=True, 64 | stdout=subprocess.PIPE, 65 | stderr=subprocess.PIPE, 66 | env=my_env, 67 | preexec_fn=os.setsid 68 | ) 69 | return task 70 | 71 | def rm(path, block=True): 72 | path = osp.abspath(path) 73 | if not osp.exists(path): 74 | logging.info(f'no need rm {path}'); return  # nothing to remove, so skip the shell calls below 75 | stdout, _ = shell('which trash', verbose=False) 76 | if 'trash' not in stdout: 77 | dst = glob.glob('{}.bak*'.format(path)) 78 | parsr = re.compile(r'{}.bak(\d+?)'.format(path)) 79 | used = [0, ] 80 | for d in dst: 81 | m = re.match(parsr, d) 82 | if not m: 83 | used.append(0) 84 | elif m.groups()[0] == '': 85 | used.append(0) 86 | else: 87 | used.append(int(m.groups()[0])) 88 | dst_path = '{}.bak{}'.format(path, max(used) + 1) 89 | cmd = 'mv {} {} '.format(path, dst_path) 90 | return shell(cmd, block=block) 91 | else: 92 | return shell(f'trash -r {path}', block=block) 93 | 94 | def mkdir_p(path, delete=True): 95 | path = str(path) 96 | if path == '': 97 | return 98 | if delete and osp.exists(path): 99 | rm(path) 100 | if not osp.exists(path): 101 | shell('mkdir -p ' + path) 102 | 103 | 104 | class Logger(object): 105 | def __init__(self, fpath=None, console=sys.stdout): 106 | self.console = console 107 | self.file = None 108 | if fpath is not None: 109 | mkdir_p(os.path.dirname(fpath), delete=False) 110 | 111 | self.file = open(fpath, 'a') 112 | 113 | def __del__(self): 114 | self.close() 115 | 116 | def __enter__(self): 117 | return self  # lets `with Logger(...) as log:` bind the logger instead of None 118 | 119 | def __exit__(self, *args): 120 | self.close() 121 | 122 | def write(self, msg): 123 | self.console.write(msg) 124 | if self.file is not None: 125 | self.file.write(msg) 126 | 127 | def flush(self): 128 | self.console.flush() 129 | if self.file is not None: 130 | self.file.flush() 131 | os.fsync(self.file.fileno()) 132 | 133 | def close(self): 134 | self.console.close() 135 | if self.file is not None: 136 | self.file.close() 137 | 138 | 139 | class Timer(object): 140 | """A flexible Timer class.
141 | 142 | :Example: 143 | 144 | >>> import time 145 | >>> import cvbase as cvb 146 | >>> with cvb.Timer(): 147 | >>> # simulate a code block that will run for 1s 148 | >>> time.sleep(1) 149 | 1.000 150 | >>> with cvb.Timer(print_tmpl='hey it takes {:.1f} seconds'): 151 | >>> # simulate a code block that will run for 1s 152 | >>> time.sleep(1) 153 | hey it takes 1.0 seconds 154 | >>> timer = cvb.Timer() 155 | >>> time.sleep(0.5) 156 | >>> print(timer.since_start()) 157 | 0.500 158 | >>> time.sleep(0.5) 159 | >>> print(timer.since_last_check()) 160 | 0.500 161 | >>> print(timer.since_start()) 162 | 1.000 163 | 164 | """ 165 | 166 | def __init__(self, start=True, print_tmpl=None): 167 | self._is_running = False 168 | self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}' 169 | if start: 170 | self.start() 171 | 172 | @property 173 | def is_running(self): 174 | """bool: indicates whether the timer is running""" 175 | return self._is_running 176 | 177 | def __enter__(self): 178 | self.start() 179 | return self 180 | 181 | def __exit__(self, type, value, traceback): 182 | print(self.print_tmpl.format(self.since_last_check())) 183 | self._is_running = False 184 | 185 | def start(self): 186 | """Start the timer.""" 187 | if not self._is_running: 188 | self._t_start = time.time() 189 | self._is_running = True 190 | self._t_last = time.time() 191 | 192 | def since_start(self, aux=''): 193 | """Total time since the timer was started. 194 | 195 | Returns(float): the time in seconds 196 | """ 197 | if not self._is_running: 198 | raise ValueError('timer is not running') 199 | self._t_last = time.time() 200 | logging.info(f'{aux} time {self.print_tmpl.format(self._t_last - self._t_start)}') 201 | return self._t_last - self._t_start 202 | 203 | def since_last_check(self, aux='', verbose=True): 204 | """Time since the last check. 205 | 206 | Either :func:`since_start` or :func:`since_last_check` is a checking operation.
207 | 208 | Returns(float): the time in seconds 209 | """ 210 | if not self._is_running: 211 | raise ValueError('timer is not running') 212 | dur = time.time() - self._t_last 213 | self._t_last = time.time() 214 | if verbose: 215 | logging.info(f'{aux} time {self.print_tmpl.format(dur)}') 216 | return dur 217 | 218 | 219 | class AverageMeter(object): 220 | """Computes and stores a sliding-window average and the current value""" 221 | 222 | def __init__(self, maxlen=100): 223 | 224 | self.val = 0 225 | self.avg = 0 226 | self.sum = 0 227 | self.count = 0 228 | self.mem = collections.deque(maxlen=maxlen)  # window of the last `maxlen` values 229 | 230 | def reset(self): 231 | self.val = 0 232 | self.avg = 0 233 | self.sum = 0 234 | self.count = 0 235 | 236 | def update(self, val, n=1):  # `n` is accepted for API compatibility but unused 237 | val = float(val) 238 | self.val = val; self.mem.append(val)  # store the current value, as the docstring promises 239 | self.avg = np.mean(list(self.mem))  # mean over the window, not over all history 240 | 241 | 242 | timer = Timer() 243 | logging.info('import myutils') 244 | -------------------------------------------------------------------------------- /test_sop.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from config import get_config 3 | from learner import metric_learner 4 | 5 | if __name__ == '__main__': 6 | conf = get_config() 7 | 8 | test_path = 'work_space/' + conf.test_sop_model + '/models' 9 | logging.info(test_path) 10 | test_name = 'SOP' 11 | 12 | conf.use_dataset = test_name 13 | learner = metric_learner(conf, inference=True) 14 | learner.load_state(conf, resume_path=test_path) 15 | 16 | nmi, f1, recall_ks = learner.test_sop_complete(conf) 17 | 18 | ks_dict = dict() 19 | ks_dict['CUB'] = [1, 2, 4, 8, 16, 32] 20 | ks_dict['Cars'] = [1, 2, 4, 8, 16, 32] 21 | ks_dict['SOP'] = [1, 10, 100, 1000, 10000] 22 | ks_dict['Inshop'] = [1, 10, 20, 30, 40, 50] 23 | k_s = ks_dict[test_name] 24 | 25 | logging.info(f'nmi: {nmi}') 26 | logging.info(f'f1: {f1}') 27 | for k, r in zip(k_s, recall_ks): 28 | logging.info(f'R{k} {r}') 29 | -------------------------------------------------------------------------------- /test_sop.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | export PYTHONPATH=. 3 | 4 | python test_sop.py \ 5 | --use_dataset SOP \ 6 | --test_sop_model SOP_0000_0000 7 | --------------------------------------------------------------------------------
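
A quick smoke test can confirm the multi-similarity loss in loss.py behaves as expected. The sketch below is illustrative only: the class name `MultiSimilarityLoss`, its `nn.Module` base, and the `thresh` attribute set earlier in `__init__` are all assumptions, since loss.py is shown above starting mid-`__init__`; on CPU the zero-loss `.cuda()` branch is never hit for a batch like this one.

# Minimal smoke test for the multi-similarity loss (sketch; class name assumed).
import torch
from loss import MultiSimilarityLoss  # name assumed from the file layout

criterion = MultiSimilarityLoss()
feats = torch.randn(32, 512, requires_grad=True)  # raw embeddings; forward() L2-normalizes them
labels = torch.randint(0, 8, (32,))               # 8 classes -> roughly 4 instances per class
loss, avg_ap, avg_an = criterion(feats, labels)
loss.backward()                                   # gradients flow back through the normalization
print(loss.item(), avg_ap.item(), avg_an.item())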
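The BNInception backbone above exposes a `normalized` flag that returns unit-norm embeddings suitable for the cosine similarities used by the losses. A minimal extraction sketch follows; the 224x224 input resolution is an assumption (the adaptive average pool in the head accepts other sizes).

# Sketch: extracting L2-normalized 512-d embeddings from the backbone above.
import torch
from models.bninception import BNInception

model = BNInception(need_bn=True).eval()  # BN + Linear head, 1024 -> 512
images = torch.randn(4, 3, 224, 224)      # input resolution is an assumption
with torch.no_grad():
    emb = model(images, normalized=True)  # each row has unit L2 norm
print(emb.shape)                          # torch.Size([4, 512])
print(emb.norm(dim=1))                    # all approximately 1.0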
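mytrain.sh passes `--sec_wei 1.0`, the weight on the spherical embedding constraint (SEC) from the paper. The repository's own implementation is not shown in this excerpt; the snippet below is a hedged reconstruction of the constraint as the paper describes it, penalizing the variance of embedding norms around their batch average, and is not code copied from this repo.

import torch

def sec_penalty(feats: torch.Tensor) -> torch.Tensor:
    """Reconstruction of the SEC term: variance of embedding norms within a batch."""
    norms = feats.norm(p=2, dim=1)  # per-sample embedding norms
    mu = norms.mean().detach()      # batch-average norm; the stop-gradient is an assumption
    return ((norms - mu) ** 2).mean()

# total = metric_loss + sec_wei * sec_penalty(feats)  # how --sec_wei would enter the objective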
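Finally, a short sketch of how the `Timer` and `AverageMeter` helpers in myutils.py compose in a training loop. Note that importing myutils configures the root logger as a side effect (and requires the `colorlog` package).

import time
from myutils import Timer, AverageMeter  # import configures logging; needs colorlog installed

meter = AverageMeter(maxlen=100)         # sliding-window mean over the last 100 updates
with Timer(print_tmpl='block took {:.3f}s'):  # prints once when the block exits
    for step in range(3):
        time.sleep(0.1)
        meter.update(0.5)                # e.g. a per-step loss value
print(meter.avg)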