├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── data_engine.py
├── datasets
└── split_train_test.py
├── evaluation.py
├── imgs
├── SEC.png
└── experiment_results.png
├── learner.py
├── loss.py
├── models
└── bninception.py
├── mytrain.py
├── mytrain.sh
├── myutils.py
├── test_sop.py
└── test_sop.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | data/
3 | output*/
4 | ckpts/
5 | *.pth
6 | *.t7
7 | *.png
8 | *.jpg
9 | tmp*.py
10 | # run*.sh
11 | *.pdf
12 |
13 |
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # celery beat schedule file
92 | celerybeat-schedule
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # Environments
98 | .env
99 | .venv
100 | env/
101 | venv/
102 | ENV/
103 | env.bak/
104 | venv.bak/
105 |
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 |
110 | # Rope project settings
111 | .ropeproject
112 |
113 | # mkdocs documentation
114 | /site
115 |
116 | # mypy
117 | .mypy_cache/
118 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Dyfine
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SphericalEmbedding
2 |
3 | This repository is the official implementation of [Deep Metric Learning with Spherical Embedding](https://arxiv.org/abs/2011.02785) for the deep metric learning (DML) task.
4 |
5 | >📋 It supports training with a vanilla triplet loss, semihard triplet loss, normalized N-pair loss (tuplet loss), or multi-similarity loss on the CUB200-2011, Cars196, SOP, and In-Shop datasets.
6 |
7 |
![SEC](imgs/SEC.png)
8 |
9 |
10 | ## Requirements
11 |
12 | This repo was tested with Ubuntu 16.04.1 LTS, Python 3.6, PyTorch 1.1.0, and CUDA 10.1.
13 |
14 | Requirements: torch==1.1.0, tensorboardX
15 |
16 | ## Training
17 |
18 | 1. Prepare the datasets and the pretrained BN-Inception model.
19 |
20 | Download datasets: [CUB200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), [Cars196](https://ai.stanford.edu/~jkrause/cars/car_dataset.html), [SOP](https://cvgl.stanford.edu/projects/lifted_struct/), [In-Shop](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html), unzip and organize them as follows.
21 |
22 | ```
23 | └───datasets
24 | └───split_train_test.py
25 | └───CUB_200_2011
26 | | └───images.txt
27 | | └───images
28 | | └───001.Black_footed_Albatross
29 | | └───...
30 | └───CARS196
31 | | └───cars_annos.mat
32 | | └───car_ims
33 | | └───000001.jpg
34 | | └───...
35 | └───SOP
36 | | └───Stanford_Online_Products
37 | | └───Ebay_train.txt
38 | | └───Ebay_test.txt
39 | | └───bicycle_final
40 | | └───...
41 | └───Inshop
42 | | └───list_eval_partition.txt
43 | | └───img
44 | | └───MEN
45 | | └───WOMEN
46 | | └───...
47 | ```
48 |
49 | Then run ```split_train_test.py``` to generate training and testing lists.
50 |
51 | Download the ImageNet-pretrained [BN-Inception](http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-52deb4733.pth) and put it into ```./pretrained_models```.
52 |
53 | 2. To train the model(s) in the paper, run the following commands or use ```sh mytrain.sh```.
54 |
55 | Train models with vanilla triplet loss.
56 |
57 | ```train
58 | CUDA_VISIBLE_DEVICES=0 python mytrain.py --use_dataset CUB --instances 3 --lr 0.5e-5 --lr_p 0.25e-5 \
59 | --lr_gamma 0.1 --use_loss triplet
60 | ```
61 |
62 | Train models with vanilla triplet loss + SEC.
63 |
64 | ```train
65 | CUDA_VISIBLE_DEVICES=0 python mytrain.py --use_dataset CUB --instances 3 --lr 0.5e-5 --lr_p 0.25e-5 \
66 | --lr_gamma 0.1 --use_loss triplet --sec_wei 1.0
67 | ```
68 |
69 | Train models with vanilla triplet loss + L2-reg.
70 |
71 | ```train
72 | CUDA_VISIBLE_DEVICES=0 python mytrain.py --use_dataset CUB --instances 3 --lr 0.5e-5 --lr_p 0.25e-5 \
73 | --lr_gamma 0.1 --use_loss triplet --l2reg_wei 1e-4
74 | ```
75 |
76 | Similarly, we set ```--use_loss``` to ```semihtriplet```/```n-npair```/```ms``` and ```--instances``` to ```3```/```2```/```5```, respectively, for training models with semihard triplet loss / normalized N-pair loss / multi-similarity loss, and we set ```--use_dataset``` to ```Cars```/```SOP```/```Inshop``` for training models on the other datasets.
77 |
78 | >📋 The detailed settings of the above hyper-parameters are provided in Appendix B of our paper (with two exceptions to the lr settings, listed below).
79 | >
80 | >(a) multi-similarity loss without SEC/L2-reg on CUB: 1e-5/0.5e-5/0.1@3k, 6k
81 | >
82 | >(b) multi-similarity loss without SEC/L2-reg on Cars: 2e-5/2e-5/0.1@2k
83 | >
84 | >(We find that using a larger learning rate harms the original loss function.)
85 | >
86 | >When training on a different dataset or with a different loss function, we only need to modify the hyper-parameters in the above commands and the head settings (only when using multi-similarity loss without SEC/L2-reg do we need to set need_bn=False,
87 | >
88 | >```
89 | >self.model = torch.nn.DataParallel(BNInception(need_bn=False)).cuda()
90 | >```
91 | >
92 | >in line 24 of learner.py).
93 |
94 | >📋 Additionally, to use SEC with the EMA method, we need to set ```--norm_momentum```, where norm_momentum denotes $\rho$ in Appendix D of our paper (see the sketch below).
95 |
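For reference, the interaction between SEC and the EMA-tracked norm mean reduces to the following sketch (it mirrors the update in ```learner.py```; names follow that file):

```python
# Sketch of the SEC term with an EMA-tracked norm mean (mirrors learner.py).
import torch

def sec_term(fea, record_norm_mean, norm_momentum):
    fea_norm = fea.norm(p=2, dim=1)   # per-sample embedding norms
    norm_mean = fea_norm.mean()
    # EMA update; norm_momentum = 1.0 (the default) reduces to the plain batch mean
    record_norm_mean = (1 - norm_momentum) * record_norm_mean + norm_momentum * norm_mean.detach()
    loss_sec = ((fea_norm - record_norm_mean) ** 2).mean()  # pull all norms toward the running mean
    return loss_sec, record_norm_mean

# total loss, as in learner.py: loss = base_loss + conf.sec_wei * loss_sec
```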
96 | ## Testing
97 |
98 | Computing NMI and F1 on SOP takes a long time, so we conduct this test only after training has finished (only R@K is evaluated during training). In particular, run:
99 |
100 | ```eval
101 | CUDA_VISIBLE_DEVICES=0 python test_sop.py --use_dataset SOP --test_sop_model SOP_xxxx_xxxx
102 | ```
103 |
104 | or use ```sh test_sop.sh``` for a complete test of NMI, F1, and R@K on SOP. Here ```SOP_xxxx_xxxx``` is the model to be tested, which can be found in ```./work_space```.
105 |
106 | For the other three datasets, the tests of NMI, F1, and R@K are conducted during the training process.
107 |
108 | ## Results
109 |
110 | Our model achieves the following performance on CUB200-2011, Cars196, SOP, and In-Shop datasets:
111 |
112 | ![experiment results](imgs/experiment_results.png)
113 |
114 | ## Citation
115 |
116 | If you find this repo useful for your research, please consider citing this paper:
117 |
118 | @article{zhang2020deep,
119 | title={Deep Metric Learning with Spherical Embedding},
120 | author={Zhang, Dingyi and Li, Yingming and Zhang, Zhongfei},
121 | journal={arXiv preprint arXiv:2011.02785},
122 | year={2020}
123 | }
124 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import myutils
3 | import argparse
4 | import torch
5 | from torchvision import transforms
6 | import datetime
7 | from easydict import EasyDict as edict
8 | import os, logging, sys
9 |
10 | def get_config():
11 |
12 | parser = argparse.ArgumentParser('argument for training')
13 |
14 | parser.add_argument('--use_dataset', type=str, default='Cars', choices=['CUB', 'Cars', 'SOP', 'Inshop'])
15 | # batch
16 | parser.add_argument('--batch_size', type=int, default=120)
17 | parser.add_argument('--instances', type=int, default=3)
18 | # optimization
19 | parser.add_argument('--lr', type=float, default=0.0)
20 | parser.add_argument('--lr_p', type=float, default=0.0)
21 | parser.add_argument('--lr_gamma', type=float, default=0.0)
22 | # model dataset
23 | parser.add_argument('--freeze_bn', type=int, default=1)
24 | # method
25 | parser.add_argument('--use_loss', type=str, default='triplet', choices=['triplet', 'n-npair', 'semihtriplet', 'ms'])
26 | parser.add_argument('--sec_wei', type=float, default=0.0)
27 | parser.add_argument('--norm_momentum', type=float, default=1.0)
28 | parser.add_argument('--l2reg_wei', type=float, default=0.0)
29 |
30 | parser.add_argument('--test_sop_model', type=str, default='')
31 |
32 | conf = parser.parse_args()
33 |
34 | conf.num_devs = 1
35 |
36 | if conf.use_dataset == 'CUB':
37 | conf.lr = 1.0e-5 if conf.lr==0 else conf.lr
38 | conf.lr_p = 0.5e-5 if conf.lr_p==0 else conf.lr_p
39 | conf.weight_decay = 0.5 * 5e-3
40 |
41 | conf.start_step = 0
42 | conf.lr_gamma = 0.1 if conf.lr_gamma==0 else conf.lr_gamma
43 | if conf.use_loss=='ms':
44 | conf.step_milestones = [3000, 6000, 9000]
45 | else:
46 | conf.step_milestones = [5000, 9000, 9000]
47 | conf.steps = 8000
48 |
49 | elif conf.use_dataset == 'Cars':
50 | conf.lr = 1e-5 if conf.lr==0 else conf.lr
51 | conf.lr_p = 1e-5 if conf.lr_p==0 else conf.lr_p
52 | conf.weight_decay = 0.5 * 5e-3
53 |
54 |         conf.start_step = 0
55 |         conf.lr_gamma = 0.1 if conf.lr_gamma == 0 else conf.lr_gamma  # default, as for the other datasets
56 |         # lr_gamma == 0.5 uses its own schedule; any other value falls back to the 0.1 schedule
57 |         conf.step_milestones = [4000, 6000, 9000] if conf.lr_gamma == 0.5 \
58 |             else [2000, 9000, 9000]
59 | conf.steps = 8000
60 |
61 | elif conf.use_dataset == 'SOP':
62 | conf.lr = 2.5e-4 if conf.lr==0 else conf.lr
63 | conf.lr_p = 0.5e-4 if conf.lr_p==0 else conf.lr_p
64 | conf.weight_decay = 1e-5
65 |
66 | conf.start_step = 0
67 | conf.lr_gamma = 0.1 if conf.lr_gamma==0 else conf.lr_gamma
68 | conf.step_milestones = [6e3, 18e3, 35e3]
69 | conf.steps = 12e3
70 |
71 | elif conf.use_dataset == 'Inshop':
72 | conf.lr = 5e-4 if conf.lr==0 else conf.lr
73 | conf.lr_p = 1e-4 if conf.lr_p==0 else conf.lr_p
74 | conf.weight_decay = 1e-5
75 |
76 | conf.start_step = 0
77 | conf.lr_gamma = 0.1 if conf.lr_gamma==0 else conf.lr_gamma
78 | conf.step_milestones = [6e3, 18e3, 35e3]
79 | conf.steps = 12e3
80 |
81 | conf.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
82 |
83 | now_time = datetime.datetime.now().strftime('%m%d_%H%M')
84 | conf_work_path = 'work_space/' + conf.use_dataset + '_' + now_time
85 | myutils.mkdir_p(conf_work_path, delete=True)
86 | myutils.set_file_logger(work_dir=conf_work_path, log_level=logging.DEBUG)
87 | sys.stdout = myutils.Logger(conf_work_path + '/log-prt')
88 | sys.stderr = myutils.Logger(conf_work_path + '/log-prt-err')
89 |
90 | path0, path1 = conf_work_path.split('/')
91 | conf.log_path = Path(path0) / 'logs' / path1 / 'log'
92 | conf.work_path = Path(conf_work_path)
93 | conf.model_path = conf.work_path / 'models'
94 | conf.save_path = conf.work_path / 'save'
95 |
96 | conf.start_eval = False
97 |
98 | conf.num_workers = 8
99 |
100 | conf.bninception_pretrained_model_path = './pretrained_models/bn_inception-52deb4733.pth'
101 |
102 | conf.transform_dict = {}
103 | conf.use_simple_aug = False
104 |
105 | conf.transform_dict['rand-crop'] = \
106 | transforms.Compose([
107 | transforms.Resize(size=(256, 256)) if conf.use_simple_aug else transforms.Resize(size=256),
108 | transforms.RandomCrop((227, 227)) if conf.use_simple_aug else transforms.RandomResizedCrop(
109 | scale=[0.16, 1],
110 | size=227
111 | ),
112 | transforms.RandomHorizontalFlip(),
113 | transforms.ToTensor(),
114 | transforms.Normalize(mean=[123 / 255.0, 117 / 255.0, 104 / 255.0],
115 | std=[1.0 / 255, 1.0 / 255, 1.0 / 255]),
116 | transforms.Lambda(lambda x: x[[2, 1, 0], ...]) #to BGR
117 | ])
118 | conf.transform_dict['center-crop'] = \
119 | transforms.Compose([
120 | transforms.Resize(size=(256, 256)) if conf.use_simple_aug else transforms.Resize(size=256),
121 | transforms.CenterCrop(227),
122 | transforms.ToTensor(),
123 | transforms.Normalize(mean=[123 / 255.0, 117 / 255.0, 104 / 255.0],
124 | std=[1.0 / 255, 1.0 / 255, 1.0 / 255]),
125 | transforms.Lambda(lambda x: x[[2, 1, 0], ...]) #to BGR
126 | ])
127 |
128 |
129 |
130 | return conf
131 |
--------------------------------------------------------------------------------
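Note on the transforms above: they reproduce the Caffe-style BN-Inception preprocessing, i.e. pixel values rescaled back to roughly [0, 255] with per-channel mean subtraction, followed by an RGB-to-BGR channel swap. A minimal standalone check of the ```center-crop``` pipeline (a sketch, assuming Pillow and torchvision are available):

```python
# Standalone check of the 'center-crop' pipeline from config.py.
import torch
from PIL import Image
from torchvision import transforms

tf = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.CenterCrop(227),
    transforms.ToTensor(),                                   # RGB, CHW, values in [0, 1]
    transforms.Normalize(mean=[123 / 255.0, 117 / 255.0, 104 / 255.0],
                         std=[1.0 / 255, 1.0 / 255, 1.0 / 255]),  # (x - mean) * 255
    transforms.Lambda(lambda x: x[[2, 1, 0], ...]),          # RGB -> BGR for the Caffe weights
])

img = Image.new('RGB', (300, 200), color=(255, 0, 0))        # dummy all-red image
x = tf(img)
print(x.shape)             # torch.Size([3, 227, 227])
print(x[2].mean().item())  # red is now the last (B,G,R) channel: 255 - 123 = 132.0
```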
/data_engine.py:
--------------------------------------------------------------------------------
1 | import myutils
2 | import os
3 | import torch
4 | import numpy as np
5 | import os.path as osp
6 | from PIL import Image
7 | from torch.utils.data.sampler import Sampler
8 | from collections import defaultdict
9 | import re
10 |
11 | class MSBaseDataSet(torch.utils.data.Dataset):
12 | """
13 |     Basic dataset that reads image paths and labels from img_source.
14 |     img_source: a text file of `img_path,label` (or space-separated) lines
15 | """
16 | def __init__(self, conf, img_source, transform=None, mode="RGB"):
17 | self.mode = mode
18 |
19 | self.root = os.path.dirname(img_source)
20 | assert os.path.exists(img_source), f"{img_source} NOT found."
21 | self.img_source = img_source
22 |
23 | self.label_list = list()
24 | self.path_list = list()
25 | self._load_data()
26 | self.label_index_dict = self._build_label_index_dict()
27 |
28 | self.num_cls = len(self.label_index_dict.keys())
29 | self.num_train = len(self.label_list)
30 |
31 | self.transform = transform
32 |
33 | def __len__(self):
34 | return len(self.label_list)
35 |
36 | def __repr__(self):
37 | return self.__str__()
38 |
39 | def __str__(self):
40 | return f"| Dataset Info |datasize: {self.__len__()}|num_labels: {len(set(self.label_list))}|"
41 |
42 | def _load_data(self):
43 | with open(self.img_source, 'r') as f:
44 | for line in f:
45 | _path, _label = re.split(r",| ", line.strip())
46 | self.path_list.append(_path)
47 | self.label_list.append(_label)
48 |
49 | def _build_label_index_dict(self):
50 | index_dict = defaultdict(list)
51 | for i, label in enumerate(self.label_list):
52 | index_dict[label].append(i)
53 | return index_dict
54 |
55 | def read_image(self, img_path, mode='RGB'):
56 |         """Keep re-reading the image until it succeeds.
57 |         This avoids transient IOErrors caused by heavy IO load."""
58 | got_img = False
59 | if not osp.exists(img_path):
60 | raise IOError(f"{img_path} does not exist")
61 | while not got_img:
62 | try:
63 | img = Image.open(img_path).convert("RGB")
64 | if mode == "BGR":
65 | r, g, b = img.split()
66 | img = Image.merge("RGB", (b, g, r))
67 | got_img = True
68 | except IOError:
69 | print(f"IOError incurred when reading '{img_path}'. Will redo.")
70 | pass
71 | return img
72 |
73 | def __getitem__(self, index):
74 | path = self.path_list[index]
75 | img_path = os.path.join(self.root, path)
76 | label = self.label_list[index]
77 |
78 | img = self.read_image(img_path, mode=self.mode)
79 |
80 | if self.transform is not None:
81 | img = self.transform(img)
82 | return {'image': img, 'label': int(label), 'index': index}
83 |
84 |
85 | class RandomIdSampler(Sampler):
86 | def __init__(self, conf, label_index_dict):
87 | self.label_index_dict = label_index_dict
88 | self.num_train = 0
89 | for k in self.label_index_dict.keys():
90 | self.num_train += len(self.label_index_dict[k])
91 |
92 | self.num_instances = conf.instances
93 | self.batch_size = conf.batch_size
94 | assert self.batch_size % self.num_instances == 0
95 | self.num_pids_per_batch = self.batch_size // self.num_instances
96 |
97 | self.ids = list(self.label_index_dict.keys())
98 |
99 | self.length = self.num_train//self.batch_size * self.batch_size
100 | self.conf = conf
101 |
102 | def __len__(self):
103 | return self.length
104 |
105 | def get_batch_ids(self):
106 |         # sample num_pids_per_batch distinct class ids for this batch
107 |
108 | pids = np.random.choice(self.ids,
109 | size=self.num_pids_per_batch,
110 | replace=False)
111 | return pids
112 |
113 | def get_batch_idxs(self):
114 | pids = self.get_batch_ids()
115 |
116 | inds = []
117 | cnt = 0
118 | for pid in pids:
119 | index_list = self.label_index_dict[pid]
120 | if len(index_list) >= self.num_instances:
121 | t = np.random.choice(index_list, size=self.num_instances, replace=False)
122 | else:
123 | t = np.random.choice(index_list, size=self.num_instances, replace=True)
124 |             t = t.astype(int)  # cast to plain int indices
125 |             for ind in t:
126 |                 yield ind
127 | cnt += 1
128 | if cnt == self.batch_size:
129 | break
130 | if cnt == self.batch_size:
131 | break
132 |
133 | def __iter__(self):
134 | cnt = 0
135 | while cnt < len(self):
136 | for ind in self.get_batch_idxs():
137 | cnt += 1
138 | yield ind
139 |
140 |
141 |
--------------------------------------------------------------------------------
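RandomIdSampler above yields P x K batches: ```batch_size // instances``` distinct classes with ```instances``` samples each (drawn with replacement when a class has too few images). A toy illustration with made-up labels (a sketch, assuming it is run from the repo root so ```data_engine``` is importable; the ```SimpleNamespace``` conf is a stand-in for the real config):

```python
# Toy illustration of RandomIdSampler's P x K batching (hypothetical labels).
from types import SimpleNamespace
from itertools import islice
from data_engine import RandomIdSampler

conf = SimpleNamespace(instances=3, batch_size=6)  # P = 2 classes, K = 3 instances
label_index_dict = {                               # label -> dataset indices
    'a': [0, 1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8], 'd': [9, 10, 11],
}
sampler = RandomIdSampler(conf, label_index_dict)
batch = list(islice(iter(sampler), conf.batch_size))
print(batch)  # e.g. [9, 11, 10, 7, 8, 7]: two classes, three instances each
```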
/datasets/split_train_test.py:
--------------------------------------------------------------------------------
1 | from scipy.io import loadmat
2 |
3 |
4 | # CUB200-2011
5 | with open('./CUB_200_2011/images.txt', 'r') as src:
6 | srclines = src.readlines()
7 |
8 | with open('./CUB_200_2011/cub_train.txt', 'w') as tf:
9 | for line in srclines:
10 | i, fname = line.strip().split()
11 | label = int(fname.split('.', 1)[0])
12 | if label <= 100:
13 | print('images/{},{}'.format(fname, label-1), file=tf)
14 |
15 | with open('./CUB_200_2011/cub_test.txt', 'w') as tf:
16 | for line in srclines:
17 | i, fname = line.strip().split()
18 | label = int(fname.split('.', 1)[0])
19 | if label > 100:
20 | print('images/{},{}'.format(fname, label-1), file=tf)
21 |
22 |
23 | # Cars196
24 | file = loadmat('./CARS196/cars_annos.mat')
25 | annos = file['annotations']
26 |
27 | with open('./CARS196/cars_train.txt', 'w') as tf:
28 | for i in range(16185):
29 | if annos[0,i][-2] <= 98:
30 | print('{},{}'.format(annos[0,i][0][0], annos[0,i][-2][0][0]-1), file=tf)
31 |
32 | with open('./CARS196/cars_test.txt', 'w') as tf:
33 | for i in range(16185):
34 | if annos[0,i][-2] > 98:
35 | print('{},{}'.format(annos[0,i][0][0], annos[0,i][-2][0][0]-1), file=tf)
36 |
37 |
38 | # SOP
39 | with open('./SOP/Stanford_Online_Products/Ebay_train.txt', 'r') as src:
40 | srclines = src.readlines()
41 |
42 | with open('./SOP/sop_train.txt', 'w') as tf:
43 | for i in range(1, len(srclines)):
44 | line = srclines[i]
45 | line_split = line.strip().split(' ')
46 | cls_id = str(int(line_split[1]) - 1)
47 | img_path = 'Stanford_Online_Products/'+line_split[3]
48 | print(img_path+','+cls_id, file=tf)
49 |
50 | with open('./SOP/Stanford_Online_Products/Ebay_test.txt', 'r') as src:
51 | srclines = src.readlines()
52 |
53 | with open('./SOP/sop_test.txt', 'w') as tf:
54 | for i in range(1, len(srclines)):
55 | line = srclines[i]
56 | line_split = line.strip().split(' ')
57 | cls_id = str(int(line_split[1]) - 1)
58 | img_path = 'Stanford_Online_Products/'+line_split[3]
59 | print(img_path+','+cls_id, file=tf)
60 |
61 |
62 | # In-Shop
63 | with open('./Inshop/list_eval_partition.txt', 'r') as file_to_read:
64 | lines = file_to_read.readlines()
65 |
66 | with open('./Inshop/inshop_train.txt', 'w') as tf:
67 | cls_name2idx = {}
68 | cls_num = 0
69 | for line in lines:
70 | words = line.strip().split()
71 | if len(words)==3:
72 | if words[-1]=='train':
73 | path = words[0]
74 | cls_name = words[1]
75 | if cls_name not in cls_name2idx.keys():
76 | cls_name2idx[cls_name] = cls_num
77 | cls_num += 1
78 | print('{},{}'.format(path, cls_name2idx[cls_name]), file=tf)
79 |
80 | with open('./Inshop/inshop_query.txt', 'w') as tf:
81 | test_cls_name2idx = {}
82 | cls_num = 0
83 | for line in lines:
84 | words = line.strip().split()
85 | if len(words)==3:
86 | if words[-1]=='query':
87 | path = words[0]
88 | cls_name = words[1]
89 | if cls_name not in test_cls_name2idx.keys():
90 | test_cls_name2idx[cls_name] = cls_num
91 | cls_num += 1
92 | print('{},{}'.format(path, test_cls_name2idx[cls_name]), file=tf)
93 |
94 | with open('./Inshop/inshop_gallery.txt', 'w') as tf:
95 | for line in lines:
96 | words = line.strip().split()
97 | if len(words)==3:
98 | if words[-1]=='gallery':
99 | path = words[0]
100 | cls_name = words[1]
101 | if cls_name not in test_cls_name2idx.keys():
102 |                     print('error: gallery class', cls_name, 'not seen in query classes!')
103 | break
104 | print('{},{}'.format(path, test_cls_name2idx[cls_name]), file=tf)
105 |
106 |
107 |
108 |
109 |
110 |
--------------------------------------------------------------------------------
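Each list written above is plain text with one ```relative_path,zero-based_label``` pair per line, which is exactly the format ```MSBaseDataSet._load_data``` in data_engine.py parses. A quick sanity check after running the script (a sketch, run from the ```datasets/``` directory):

```python
# Peek at the first lines of a generated list file.
with open('./CUB_200_2011/cub_train.txt') as f:
    for line in list(f)[:3]:
        path, label = line.strip().split(',')
        print(path, int(label))  # e.g. images/001.Black_footed_Albatross/..., 0
```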
/evaluation.py:
--------------------------------------------------------------------------------
1 | from sklearn.cluster import KMeans
2 | from sklearn.metrics.cluster import normalized_mutual_info_score
3 | import numpy as np
4 | import logging
5 | import torch
6 | import myutils
7 | import math
8 | from scipy.special import comb
9 |
10 | def NMI_F1(X, ground_truth, n_cluster):
11 | X = [x.cpu().numpy() for x in X]
12 | # list to numpy
13 | X = np.array(X)
14 |
15 | ground_truth = np.array(ground_truth)
16 |
17 | kmeans = KMeans(n_clusters=n_cluster, n_jobs=4, random_state=0).fit(X)
18 |
19 | logging.info('K-means done')
20 |     nmi, f1 = compute_clustering_metric(np.asarray(kmeans.labels_), ground_truth)
21 |
22 | return nmi, f1
23 |
24 | def normalize(x):
25 | norm = x.norm(dim=1, p=2, keepdim=True)
26 | x = x.div(norm.expand_as(x))
27 | return x
28 |
29 | def pairwise_similarity(x, y=None):
30 | if y is None:
31 | y = x
32 |
33 | y = normalize(y)
34 | x = normalize(x)
35 |
36 | similarity = torch.mm(x, y.t())
37 | return similarity
38 |
39 |
40 | def Recall_at_ks(sim_mat, data_name=None, query_ids=None, gallery_ids=None):
41 |     """
42 |     Compute Recall@K for the K values associated with each dataset
43 |     (K = 1, 2, 4, 8, 16, 32 for CUB/Cars; 1, 10, 100, 1000 for SOP;
44 |     1, 10, 20, 30, 40, 50 for Inshop).
45 |
46 |     :param sim_mat: query-by-gallery similarity matrix
47 |     :param query_ids: labels of the query images (defaults to gallery_ids)
48 |     :param gallery_ids: labels of the gallery images
49 |     :param data_name: one of 'CUB', 'Cars', 'SOP', 'Inshop'
50 |     """
51 |
52 | ks_dict = dict()
53 | ks_dict['CUB'] = [1, 2, 4, 8, 16, 32]
54 | ks_dict['Cars'] = [1, 2, 4, 8, 16, 32]
55 | ks_dict['SOP'] = [1, 10, 100, 1000]
56 | ks_dict['Inshop'] = [1, 10, 20, 30, 40, 50]
57 |
58 | assert data_name in ['CUB', 'Cars', 'SOP', 'Inshop']
59 | k_s = ks_dict[data_name]
60 |
61 | sim_mat = sim_mat.cpu().numpy()
62 | m, n = sim_mat.shape
63 |
64 |
65 | gallery_ids = np.asarray(gallery_ids)
66 | if query_ids is None:
67 | query_ids = gallery_ids
68 | else:
69 | query_ids = np.asarray(query_ids)
70 |
71 |
72 | num_valid = np.zeros(len(k_s))
73 | neg_nums = np.zeros(m)
74 | for i in range(m):
75 | x = sim_mat[i]
76 |
77 | pos_max = np.max(x[gallery_ids == query_ids[i]])
78 |
79 | neg_num = np.sum(x > pos_max)
80 | neg_nums[i] = neg_num
81 |
82 |     # for each k, count queries whose hardest positive ranks within the top k
83 |     for i, k in enumerate(k_s):
84 |         num_valid[i] = np.sum(neg_nums < k)
89 |
90 | return num_valid / float(m)
91 |
92 |
93 | def compute_clustering_metric(idx, item_ids):
94 |
95 | N = len(idx)
96 |
97 | # cluster centers
98 | centers = np.unique(idx)
99 | num_cluster = len(centers)
100 |
101 | # count the number of objects in each cluster
102 | count_cluster = np.zeros(num_cluster)
103 | for i in range(num_cluster):
104 | count_cluster[i] = len(np.where(idx == centers[i])[0])
105 |
106 | # build a mapping from item_id to item index
107 | keys = np.unique(item_ids)
108 | num_item = len(keys)
109 | values = range(num_item)
110 | item_map = dict()
111 | for i in range(len(keys)):
112 | item_map.update([(keys[i], values[i])])
113 |
114 | # count the number of objects of each item
115 | count_item = np.zeros(num_item)
116 | for i in range(N):
117 | index = item_map[item_ids[i]]
118 | count_item[index] = count_item[index] + 1
119 |
120 | # compute purity
121 | purity = 0
122 | for i in range(num_cluster):
123 | member = np.where(idx == centers[i])[0]
124 | member_ids = item_ids[member]
125 |
126 | count = np.zeros(num_item)
127 | for j in range(len(member)):
128 | index = item_map[member_ids[j]]
129 | count[index] = count[index] + 1
130 | purity = purity + max(count)
131 |
132 | # compute Normalized Mutual Information (NMI)
133 | count_cross = np.zeros((num_cluster, num_item))
134 | for i in range(N):
135 | index_cluster = np.where(idx[i] == centers)[0]
136 | index_item = item_map[item_ids[i]]
137 | count_cross[index_cluster, index_item] = count_cross[index_cluster, index_item] + 1
138 |
139 | # mutual information
140 | I = 0
141 | for k in range(num_cluster):
142 | for j in range(num_item):
143 | if count_cross[k, j] > 0:
144 | s = count_cross[k, j] / N * math.log(N * count_cross[k, j] / (count_cluster[k] * count_item[j]))
145 | I = I + s
146 |
147 | # entropy
148 | H_cluster = 0
149 | for k in range(num_cluster):
150 | s = -count_cluster[k] / N * math.log(count_cluster[k] / float(N))
151 | H_cluster = H_cluster + s
152 |
153 | H_item = 0
154 | for j in range(num_item):
155 | s = -count_item[j] / N * math.log(count_item[j] / float(N))
156 | H_item = H_item + s
157 |
158 | NMI = 2 * I / (H_cluster + H_item)
159 |
160 | # compute True Positive (TP) plus False Positive (FP)
161 | tp_fp = 0
162 | for k in range(num_cluster):
163 | if count_cluster[k] > 1:
164 | tp_fp = tp_fp + comb(count_cluster[k], 2)
165 |
166 | # compute True Positive (TP)
167 | tp = 0
168 | for k in range(num_cluster):
169 | member = np.where(idx == centers[k])[0]
170 | member_ids = item_ids[member]
171 |
172 | count = np.zeros(num_item)
173 | for j in range(len(member)):
174 | index = item_map[member_ids[j]]
175 | count[index] = count[index] + 1
176 |
177 | for i in range(num_item):
178 | if count[i] > 1:
179 | tp = tp + comb(count[i], 2)
180 |
181 | # False Positive (FP)
182 | fp = tp_fp - tp
183 |
184 | # compute False Negative (FN)
185 | count = 0
186 | for j in range(num_item):
187 | if count_item[j] > 1:
188 | count = count + comb(count_item[j], 2)
189 |
190 | fn = count - tp
191 |
192 | # compute F measure
193 | P = tp / (tp + fp)
194 | R = tp / (tp + fn)
195 | beta = 1
196 | F = (beta*beta + 1) * P * R / (beta*beta * P + R)
197 |
198 | return NMI, F
199 |
200 |
--------------------------------------------------------------------------------
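Recall_at_ks above counts, for each query, how many differently-labeled gallery items score strictly above the best same-label match; R@K is then the fraction of queries with fewer than K such distractors. A toy check (a sketch; note the diagonal must be masked out first, which learner.py does by subtracting the identity):

```python
# Toy check of Recall_at_ks on a 4-image gallery with 'CUB'-style K values.
import numpy as np
import torch
from evaluation import Recall_at_ks

labels = np.array([0, 0, 1, 1])
feas = torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.2, 0.9]])
feas = feas / feas.norm(dim=1, keepdim=True)  # L2-normalize rows
sim = feas @ feas.t()
sim = sim - torch.eye(4)  # mask self-similarity, as in learner.py
print(Recall_at_ks(sim, data_name='CUB', gallery_ids=labels))  # all 1.0 here
```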
/imgs/SEC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dyfine/SphericalEmbedding/f118c0ee05cfd3a0905a67cae2a5813a1e061647/imgs/SEC.png
--------------------------------------------------------------------------------
/imgs/experiment_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dyfine/SphericalEmbedding/f118c0ee05cfd3a0905a67cae2a5813a1e061647/imgs/experiment_results.png
--------------------------------------------------------------------------------
/learner.py:
--------------------------------------------------------------------------------
1 | import myutils
2 | import os
3 | import torch
4 | from loss import NpairLoss, TripletSemihardLoss, TripletLoss, MultiSimilarityLoss
5 | import logging
6 | import numpy as np
7 | from models.bninception import BNInception
8 |
9 | from torch.utils.data import DataLoader
10 | from torch import optim
11 | from tensorboardX import SummaryWriter
12 | from torch.utils.data.sampler import Sampler
13 | from datetime import datetime
14 | from evaluation import NMI_F1, pairwise_similarity, Recall_at_ks
15 | from data_engine import MSBaseDataSet, RandomIdSampler
16 |
17 | def get_time():
18 | return (str(datetime.now())[:-10]).replace(' ', '-').replace(':', '-')
19 |
20 | class metric_learner(object):
21 | def __init__(self, conf, inference=False):
22 |
23 | logging.info(f'metric learner use {conf}')
24 | self.model = torch.nn.DataParallel(BNInception()).cuda()
25 | logging.info(f'model generated')
26 |
27 | if not inference:
28 |
29 | if conf.use_dataset == 'CUB':
30 | self.dataset = MSBaseDataSet(conf, './datasets/CUB_200_2011/cub_train.txt',
31 | transform=conf.transform_dict['rand-crop'], mode='RGB')
32 | elif conf.use_dataset == 'Cars':
33 | self.dataset = MSBaseDataSet(conf, './datasets/CARS196/cars_train.txt',
34 | transform=conf.transform_dict['rand-crop'], mode='RGB')
35 | elif conf.use_dataset == 'SOP':
36 | self.dataset = MSBaseDataSet(conf, './datasets/SOP/sop_train.txt',
37 | transform=conf.transform_dict['rand-crop'], mode='RGB')
38 | elif conf.use_dataset == 'Inshop':
39 | self.dataset = MSBaseDataSet(conf, './datasets/Inshop/inshop_train.txt',
40 | transform=conf.transform_dict['rand-crop'], mode='RGB')
41 |
42 | self.loader = DataLoader(
43 | self.dataset, batch_size=conf.batch_size, num_workers=conf.num_workers,
44 | shuffle=False, sampler=RandomIdSampler(conf, self.dataset.label_index_dict), drop_last=True,
45 | pin_memory=True,
46 | )
47 |
48 | self.class_num = self.dataset.num_cls
49 | self.img_num = self.dataset.num_train
50 |
51 | myutils.mkdir_p(conf.log_path, delete=True)
52 | self.writer = SummaryWriter(str(conf.log_path))
53 | self.step = 0
54 |
55 | self.head_npair = NpairLoss().to(conf.device)
56 | self.head_semih_triplet = TripletSemihardLoss().to(conf.device)
57 | self.head_triplet = TripletLoss(instance=conf.instances).to(conf.device)
58 | self.head_multisimiloss = MultiSimilarityLoss().to(conf.device)
59 | logging.info('model heads generated')
60 |
61 | backbone_bn_para, backbone_wo_bn_para = [
62 | [p for k, p in self.model.named_parameters() if
63 |              ('bn' in k) == is_bn and 'head' not in k] for is_bn in [True, False]]
64 |
65 | head_bn_para, head_wo_bn_para = [
66 | [p for k, p in self.model.module.head.named_parameters() if
67 | ('bn' in k) == is_bn] for is_bn in [True, False]]
68 |
69 | self.optimizer = optim.Adam([
70 |             {'params': backbone_bn_para if not conf.freeze_bn else [], 'lr': conf.lr_p},
71 | {'params': backbone_wo_bn_para, 'weight_decay': conf.weight_decay, 'lr': conf.lr_p},
72 | {'params': head_bn_para, 'lr': conf.lr},
73 | {'params': head_wo_bn_para, 'weight_decay': conf.weight_decay, 'lr': conf.lr},
74 | ])
75 |
76 | logging.info(f'{self.optimizer}, optimizers generated')
77 |
78 | if conf.use_dataset=='CUB' or conf.use_dataset=='Cars':
79 | self.board_loss_every = 20
80 | self.evaluate_every = 100
81 | self.save_every = 1000
82 | elif conf.use_dataset=='Inshop':
83 | self.board_loss_every = 20
84 | self.evaluate_every = 200
85 | self.save_every = 2000
86 | else:
87 | self.board_loss_every = 20
88 | self.evaluate_every = 500
89 | self.save_every = 2000
90 |
91 |
92 | def train(self, conf):
93 | self.model.train()
94 | self.train_with_fixed_bn(conf)
95 |
96 | myutils.timer.since_last_check('start train')
97 | data_time = myutils.AverageMeter(20)
98 | loss_time = myutils.AverageMeter(20)
99 | loss_meter = myutils.AverageMeter(20)
100 |
101 | self.step = conf.start_step
102 |
103 | if self.step == 0 and conf.start_eval:
104 | nmi, f1, recall_ks = self.test(conf)
105 | self.writer.add_scalar('{}/test_nmi'.format(conf.use_dataset), nmi, self.step)
106 | self.writer.add_scalar('{}/test_f1'.format(conf.use_dataset), f1, self.step)
107 | self.writer.add_scalar('{}/test_recall_at_1'.format(conf.use_dataset), recall_ks[0], self.step)
108 | logging.info(f'test on {conf.use_dataset}: nmi is {nmi}, f1 is {f1}, recalls are {recall_ks[0]}, {recall_ks[1]}, {recall_ks[2]}, {recall_ks[3:]} ')
109 |
110 | nmi, f1, recall_ks = self.validate(conf)
111 | self.writer.add_scalar('{}/train_nmi'.format(conf.use_dataset), nmi, self.step)
112 | self.writer.add_scalar('{}/train_f1'.format(conf.use_dataset), f1, self.step)
113 | self.writer.add_scalar('{}/train_recall_at_1'.format(conf.use_dataset), recall_ks[0], self.step)
114 | logging.info(f'val on {conf.use_dataset}: nmi is {nmi}, f1 is {f1}, recall_at_1 is {recall_ks[0]} ')
115 |
116 | self.train_with_fixed_bn(conf)
117 |
118 |
119 | while self.step < conf.steps:
120 |
121 | loader_enum = enumerate(self.loader)
122 | while True:
123 | if self.step > conf.steps:
124 | break
125 | try:
126 | ind_data, data = loader_enum.__next__()
127 | except StopIteration as e:
128 | logging.info(f'one epoch finish {e} {ind_data}')
129 | break
130 | data_time.update(myutils.timer.since_last_check(verbose=False))
131 |
132 | if self.step == conf.step_milestones[0]:
133 | self.schedule_lr(conf)
134 | if self.step == conf.step_milestones[1]:
135 | self.schedule_lr(conf)
136 | if self.step == conf.step_milestones[2]:
137 | self.schedule_lr(conf)
138 |
139 | imgs = data['image'].to(conf.device)
140 | labels = data['label'].to(conf.device)
141 | index = data['index']
142 |
143 | self.optimizer.zero_grad()
144 |
145 | fea = self.model(imgs, normalized=False)
146 |
147 | fea_norm = fea.norm(p=2, dim=1)
148 | norm_mean = fea_norm.mean()
149 | norm_var = ((fea_norm - norm_mean) ** 2).mean()
150 |
151 |
152 | if self.step==0:
153 | self.record_norm_mean = norm_mean.detach()
154 | else:
155 | self.record_norm_mean = (1 - conf.norm_momentum) * self.record_norm_mean + \
156 | conf.norm_momentum * norm_mean.detach()
157 |
158 |
159 | if conf.use_loss == 'triplet':
160 | loss, avg_ap, avg_an = self.head_triplet(fea, labels, normalized=True)
161 | elif conf.use_loss == 'n-npair':
162 | loss, avg_ap, avg_an = self.head_npair(fea, labels, normalized=True)
163 | elif conf.use_loss == 'semihtriplet':
164 | loss, avg_ap, avg_an = self.head_semih_triplet(fea, labels, normalized=True)
165 | elif conf.use_loss == 'ms':
166 | loss, avg_ap, avg_an = self.head_multisimiloss(fea, labels)
167 |
168 |
169 |
170 | loss_sec = ((fea_norm - self.record_norm_mean) ** 2).mean()
171 | loss_l2reg = (fea_norm ** 2).mean()
172 |
173 | if conf.sec_wei != 0:
174 | loss = loss + conf.sec_wei * loss_sec
175 | if conf.l2reg_wei != 0:
176 | loss = loss + conf.l2reg_wei * loss_l2reg
177 |
178 | loss.backward()
179 |
180 | self.writer.add_scalar('info/norm_var', norm_var.detach().item(), self.step)
181 | self.writer.add_scalar('info/norm_mean', norm_mean.detach().item(), self.step)
182 | self.writer.add_scalar('info/loss_sec', loss_sec.item(), self.step)
183 | self.writer.add_scalar('info/loss_l2reg', loss_l2reg.item(), self.step)
184 | self.writer.add_scalar('info/avg_ap', avg_ap.item(), self.step)
185 | self.writer.add_scalar('info/avg_an', avg_an.item(), self.step)
186 | self.writer.add_scalar('info/record_norm_mean', self.record_norm_mean.item(), self.step)
187 | self.writer.add_scalar('info/lr', self.optimizer.param_groups[2]['lr'], self.step)
188 |
189 | loss_meter.update(loss.item())
190 |
191 | self.optimizer.step()
192 |
193 | if self.step % self.evaluate_every ==0 and self.step != 0:
194 | nmi, f1, recall_ks = self.test(conf)
195 | self.writer.add_scalar('{}/test_nmi'.format(conf.use_dataset), nmi, self.step)
196 | self.writer.add_scalar('{}/test_f1'.format(conf.use_dataset), f1, self.step)
197 | self.writer.add_scalar('{}/test_recall_at_1'.format(conf.use_dataset), recall_ks[0], self.step)
198 | logging.info(f'test on {conf.use_dataset}: nmi is {nmi}, f1 is {f1}, recalls are {recall_ks[0]}, {recall_ks[1]}, {recall_ks[2]}, {recall_ks[3:]} ')
199 |
200 | nmi, f1, recall_ks = self.validate(conf)
201 | self.writer.add_scalar('{}/train_nmi'.format(conf.use_dataset), nmi, self.step)
202 | self.writer.add_scalar('{}/train_f1'.format(conf.use_dataset), f1, self.step)
203 | self.writer.add_scalar('{}/train_recall_at_1'.format(conf.use_dataset), recall_ks[0], self.step)
204 | logging.info(f'val on {conf.use_dataset}: nmi is {nmi}, f1 is {f1}, recall_at_1 is {recall_ks[0]} ')
205 |
206 | self.train_with_fixed_bn(conf)
207 |
208 | if self.step % self.board_loss_every == 0 and self.step != 0:
209 |                 # record the running training loss
210 | self.writer.add_scalar('train_loss', loss_meter.avg, self.step)
211 |
212 | logging.info(f'step {self.step}: ' +
213 | f'loss: {loss_meter.avg:.3f} ' +
214 | f'data time: {data_time.avg:.2f} ' +
215 | f'loss time: {loss_time.avg:.2f} ' +
216 | f'speed: {conf.batch_size/(data_time.avg+loss_time.avg):.2f} imgs/s ' +
217 | f'norm_mean: {norm_mean.item():.2f} ' +
218 | f'norm_var: {norm_var.item():.2f}')
219 |
220 | if self.step % self.save_every == 0 and self.step != 0:
221 | self.save_state(conf)
222 |
223 | self.step += 1
224 |
225 | loss_time.update(myutils.timer.since_last_check(verbose=False))
226 |
227 | self.save_state(conf, to_save_folder=True)
228 |
229 | def train_with_fixed_bn(self, conf):
230 | def fix_bn(m):
231 | classname = m.__class__.__name__
232 | if classname.find('BatchNorm') != -1:
233 | m.eval()
234 | if conf.freeze_bn:
235 | self.model.apply(fix_bn)
236 | self.model.module.head.train()
237 | else:
238 | pass
239 |
240 | def validate(self, conf):
241 | logging.info('start eval')
242 | self.model.eval()
243 |
244 | if conf.use_dataset == 'CUB' or conf.use_dataset == 'Cars' or conf.use_dataset == 'SOP':
245 |
246 | loader = DataLoader(self.dataset, batch_size=conf.batch_size, num_workers=conf.num_workers,
247 | shuffle=False, pin_memory=True, drop_last=False)
248 |
249 | loader_enum = enumerate(loader)
250 | feas = torch.tensor([])
251 | labels = np.array([])
252 | with torch.no_grad():
253 | while True:
254 | try:
255 | ind_data, data = loader_enum.__next__()
256 | except StopIteration as e:
257 | break
258 |
259 | imgs = data['image']
260 | label = data['label']
261 |
262 | output1 = self.model(imgs, normalized=False)
263 | norm = output1.norm(dim=1, p=2, keepdim=True)
264 | output1 = output1.div(norm.expand_as(output1))
265 | feas = torch.cat((feas, output1.cpu()), 0)
266 | labels = np.append(labels, label.cpu().numpy())
267 |
268 | if conf.use_dataset == 'SOP':
269 | nmi = 0
270 | f1 = 0
271 | else:
272 | pids = np.unique(labels)
273 | nmi, f1 = NMI_F1(feas, labels, n_cluster=len(pids))
274 |
275 | sim_mat = pairwise_similarity(feas)
276 | sim_mat = sim_mat - torch.eye(sim_mat.size(0))
277 | recall_ks = Recall_at_ks(sim_mat, data_name=conf.use_dataset, gallery_ids=labels)
278 |
279 | elif conf.use_dataset=='Inshop':
280 | nmi = 0
281 | f1 = 0
282 | recall_ks = [0.0, 0.0]
283 |
284 | self.model.train()
285 | logging.info('eval end')
286 | return nmi, f1, recall_ks
287 |
288 | def test(self, conf):
289 | logging.info('start test')
290 | self.model.eval()
291 |
292 | if conf.use_dataset=='CUB' or conf.use_dataset=='Cars' or conf.use_dataset=='SOP':
293 |
294 | if conf.use_dataset == 'CUB':
295 | dataset = MSBaseDataSet(conf, './datasets/CUB_200_2011/cub_test.txt',
296 | transform=conf.transform_dict['center-crop'], mode='RGB')
297 | elif conf.use_dataset == 'Cars':
298 | dataset = MSBaseDataSet(conf, './datasets/CARS196/cars_test.txt',
299 | transform=conf.transform_dict['center-crop'], mode='RGB')
300 | elif conf.use_dataset == 'SOP':
301 | dataset = MSBaseDataSet(conf, './datasets/SOP/sop_test.txt',
302 | transform=conf.transform_dict['center-crop'], mode='RGB')
303 |
304 | loader = DataLoader(dataset, batch_size=conf.batch_size, num_workers=conf.num_workers,
305 | shuffle=False, pin_memory=True, drop_last=False)
306 |
307 | loader_enum = enumerate(loader)
308 | feas = torch.tensor([])
309 | labels = np.array([])
310 | with torch.no_grad():
311 | while True:
312 | try:
313 | ind_data, data = loader_enum.__next__()
314 | except StopIteration as e:
315 | break
316 |
317 | imgs = data['image']
318 | label = data['label']
319 |
320 | output1 = self.model(imgs, normalized=False)
321 | norm = output1.norm(dim=1, p=2, keepdim=True)
322 | output1 = output1.div(norm.expand_as(output1))
323 | feas = torch.cat((feas, output1.cpu()), 0)
324 | labels = np.append(labels, label.cpu().numpy())
325 |
326 | if conf.use_dataset == 'SOP':
327 | nmi = 0
328 | f1 = 0
329 | else:
330 | pids = np.unique(labels)
331 | nmi, f1 = NMI_F1(feas, labels, n_cluster=len(pids))
332 |
333 | sim_mat = pairwise_similarity(feas)
334 | sim_mat = sim_mat - torch.eye(sim_mat.size(0))
335 | recall_ks = Recall_at_ks(sim_mat, data_name=conf.use_dataset, gallery_ids=labels)
336 |
337 | elif conf.use_dataset=='Inshop':
338 | nmi = 0
339 | f1 = 0
340 | # query
341 | dataset_query = MSBaseDataSet(conf, './datasets/Inshop/inshop_query.txt',
342 | transform=conf.transform_dict['center-crop'], mode='RGB')
343 | loader_query = DataLoader(dataset_query, batch_size=conf.batch_size, num_workers=conf.num_workers,
344 | shuffle=False, pin_memory=True, drop_last=False)
345 | loader_query_enum = enumerate(loader_query)
346 | feas_query = torch.tensor([])
347 | labels_query = np.array([])
348 | with torch.no_grad():
349 | while True:
350 | try:
351 | ind_data, data = loader_query_enum.__next__()
352 | except StopIteration as e:
353 | break
354 |
355 | imgs = data['image']
356 | label = data['label']
357 |
358 | output1 = self.model(imgs, normalized=False)
359 | norm = output1.norm(dim=1, p=2, keepdim=True)
360 | output1 = output1.div(norm.expand_as(output1))
361 | feas_query = torch.cat((feas_query, output1.cpu()), 0)
362 | labels_query = np.append(labels_query, label.cpu().numpy())
363 | # gallery
364 | dataset_gallery = MSBaseDataSet(conf, './datasets/Inshop/inshop_gallery.txt',
365 | transform=conf.transform_dict['center-crop'], mode='RGB')
366 | loader_gallery = DataLoader(dataset_gallery, batch_size=conf.batch_size, num_workers=conf.num_workers,
367 | shuffle=False, pin_memory=True, drop_last=False)
368 | loader_gallery_enum = enumerate(loader_gallery)
369 | feas_gallery = torch.tensor([])
370 | labels_gallery = np.array([])
371 | with torch.no_grad():
372 | while True:
373 | try:
374 | ind_data, data = loader_gallery_enum.__next__()
375 | except StopIteration as e:
376 | break
377 |
378 | imgs = data['image']
379 | label = data['label']
380 |
381 | output1 = self.model(imgs, normalized=False)
382 | norm = output1.norm(dim=1, p=2, keepdim=True)
383 | output1 = output1.div(norm.expand_as(output1))
384 | feas_gallery = torch.cat((feas_gallery, output1.cpu()), 0)
385 | labels_gallery = np.append(labels_gallery, label.cpu().numpy())
386 | # test
387 | sim_mat = pairwise_similarity(feas_query, feas_gallery)
388 | recall_ks = Recall_at_ks(sim_mat, data_name=conf.use_dataset, query_ids=labels_query, gallery_ids=labels_gallery)
389 |
390 | self.model.train()
391 | logging.info('test end')
392 |
393 | return nmi, f1, recall_ks
394 |
395 | def test_sop_complete(self, conf):
396 | assert conf.use_dataset == 'SOP'
397 |
398 | logging.info('start complete sop test')
399 | self.model.eval()
400 |
401 | dataset = MSBaseDataSet(conf, './datasets/SOP/sop_test.txt',
402 | transform=conf.transform_dict['center-crop'], mode='RGB')
403 | loader = DataLoader(dataset, batch_size=conf.batch_size, num_workers=conf.num_workers,
404 | shuffle=False, pin_memory=True, drop_last=False)
405 |
406 | loader_enum = enumerate(loader)
407 | feas = torch.tensor([])
408 | labels = np.array([])
409 | with torch.no_grad():
410 | while True:
411 | try:
412 | ind_data, data = loader_enum.__next__()
413 | except StopIteration as e:
414 | break
415 |
416 | imgs = data['image']
417 | label = data['label']
418 |
419 | output1 = self.model(imgs, normalized=False)
420 | norm = output1.norm(dim=1, p=2, keepdim=True)
421 | output1 = output1.div(norm.expand_as(output1))
422 | feas = torch.cat((feas, output1.cpu()), 0)
423 | labels = np.append(labels, label.cpu().numpy())
424 |
425 | pids = np.unique(labels)
426 | nmi, f1 = NMI_F1(feas, labels, n_cluster=len(pids))
427 |
428 | print(f'nmi: {nmi}, f1: {f1}')
429 |
430 | sim_mat = pairwise_similarity(feas)
431 | sim_mat = sim_mat - torch.eye(sim_mat.size(0))
432 | recall_ks = Recall_at_ks(sim_mat, data_name=conf.use_dataset, gallery_ids=labels)
433 |
434 | self.model.train()
435 | logging.info('test end')
436 |
437 | return nmi, f1, recall_ks
438 |
439 | def load_bninception_pretrained(self, conf):
440 | model_dict = self.model.state_dict()
441 | my_dict = {'module.'+k: v for k, v in torch.load(conf.bninception_pretrained_model_path).items() if 'module.'+k in model_dict.keys()}
442 | print('################################## do not have pretrained:')
443 | for k in model_dict:
444 | if k not in my_dict.keys():
445 | print(k)
446 | print('##################################')
447 | model_dict.update(my_dict)
448 | self.model.load_state_dict(model_dict)
449 |
450 | def schedule_lr(self, conf):
451 | for params in self.optimizer.param_groups:
452 | params['lr'] = params['lr'] * conf.lr_gamma
453 | logging.info(f'{self.optimizer}')
454 |
455 | def save_state(self, conf, to_save_folder=False, model_only=False):
456 | if to_save_folder:
457 | save_path = conf.save_path
458 | else:
459 | save_path = conf.model_path
460 |
461 | myutils.mkdir_p(save_path, delete=False)
462 |
463 | torch.save(
464 | self.model.state_dict(),
465 | save_path /
466 | ('model_{}_step:{}.pth'.format(get_time(), self.step)))
467 | if not model_only:
468 | torch.save(
469 | self.optimizer.state_dict(),
470 | save_path /
471 | ('optimizer_{}_step:{}.pth'.format(get_time(), self.step)))
472 |
473 | def load_state(self, conf, resume_path, fixed_str=None, load_optimizer=False):
474 | from pathlib import Path
475 |
476 | save_path = Path(resume_path)
477 | modelp = save_path / 'model_{}'.format(fixed_str)
478 | if not os.path.exists(modelp):
479 | fixed_strs = [t.name for t in save_path.glob('model*_*.pth')]
480 | step = [fixed_str.split('_')[-1].split(':')[-1].split('.')[-2] for fixed_str in fixed_strs]
481 | step = np.asarray(step, dtype=int)
482 | step_ind = step.argmax()
483 | fixed_str = fixed_strs[step_ind].replace('model_', '')
484 | modelp = save_path / 'model_{}'.format(fixed_str)
485 |
486 | print(fixed_str)
487 |
488 | model_dict = self.model.state_dict()
489 | my_dict = {k: v for k, v in torch.load(modelp).items() if k in model_dict.keys()}
490 | print('################################## do not have pretrained:')
491 | for k in model_dict:
492 | if k not in my_dict.keys():
493 | print(k)
494 | print('##################################')
495 | model_dict.update(my_dict)
496 | self.model.load_state_dict(model_dict)
497 |
498 | if load_optimizer:
499 | self.optimizer.load_state_dict(torch.load(save_path / 'optimizer_{}'.format(fixed_str)))
500 | print(self.optimizer)
501 |
502 |
--------------------------------------------------------------------------------
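The training entry point ```mytrain.py``` is not included in this dump; based on config.py and learner.py it presumably reduces to something like the following (a hypothetical sketch, not the repo's actual file):

```python
# Hypothetical entry point; mytrain.py itself is not shown in this dump.
from config import get_config
from learner import metric_learner

if __name__ == '__main__':
    conf = get_config()
    learner = metric_learner(conf)
    learner.load_bninception_pretrained(conf)  # load the ImageNet BN-Inception weights
    learner.train(conf)
```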
/loss.py:
--------------------------------------------------------------------------------
1 | import myutils
2 | from torch.nn import Module, Parameter
3 | import torch.nn.functional as F
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 |
8 | class TripletLoss(Module):
9 | def __init__(self, instance, margin=1.0):
10 | super(TripletLoss, self).__init__()
11 | self.margin = margin
12 | self.instance = instance
13 |
14 | def forward(self, inputs, targets, normalized=True):
15 | norm_temp = inputs.norm(dim=1, p=2, keepdim=True)
16 | if normalized:
17 | inputs = inputs.div(norm_temp.expand_as(inputs))
18 |
19 | nB = inputs.size(0)
20 | idx_ = torch.arange(0, nB, dtype=torch.long)
21 |
22 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(nB, nB)
23 | dist = dist + dist.t()
24 |         # squared pairwise distances: ||a||^2 + ||b||^2 - 2 a.b (torch 1.1 addmm_ signature)
25 | dist.addmm_(1, -2, inputs, inputs.t()).clamp_(min=1e-12)
26 |
27 | adjacency = targets.expand(nB, nB).eq(targets.expand(nB, nB).t())
28 | adjacency_not = ~adjacency
29 | mask_ap = (adjacency.float() - torch.eye(nB).cuda()).long()
30 | mask_an = adjacency_not.long()
31 |
32 | dist_ap = (dist[mask_ap == 1]).view(-1, 1)
33 | dist_an = (dist[mask_an == 1]).view(nB, -1)
34 | dist_an = dist_an.repeat(1, self.instance - 1)
35 | dist_an = dist_an.view(nB * (self.instance - 1), nB - self.instance)
36 | num_loss = dist_an.size(0) * dist_an.size(1)
37 |
38 | triplet_loss = torch.sum(
39 | torch.max(torch.tensor(0, dtype=torch.float).cuda(), self.margin + dist_ap - dist_an)) / num_loss
40 | final_loss = triplet_loss * 1.0
41 |
42 | with torch.no_grad():
43 | assert normalized == True
44 | cos_theta = torch.mm(inputs, inputs.t())
45 | mask = targets.expand(nB, nB).eq(targets.expand(nB, nB).t())
46 | avg_ap = cos_theta[(mask.float() - torch.eye(nB).cuda()) == 1].mean()
47 | avg_an = cos_theta[mask.float() == 0].mean()
48 |
49 | return final_loss, avg_ap, avg_an
50 |
51 | class TripletSemihardLoss(Module):
52 | def __init__(self, margin=0.2):
53 | super(TripletSemihardLoss, self).__init__()
54 | self.margin = margin
55 |
56 | def forward(self, inputs, targets, normalized=True):
57 | norm_temp = inputs.norm(dim=1, p=2, keepdim=True)
58 | if normalized:
59 | inputs = inputs.div(norm_temp.expand_as(inputs))
60 |
61 | nB = inputs.size(0)
62 | idx_ = torch.arange(0, nB, dtype=torch.long)
63 |
64 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(nB, nB)
65 | dist = dist + dist.t()
66 |         # squared pairwise distances: ||a||^2 + ||b||^2 - 2 a.b (torch 1.1 addmm_ signature)
67 | dist.addmm_(1, -2, inputs, inputs.t()).clamp_(min=1e-12)
68 |
69 | temp_euclidean_score = dist * 1.0
70 |
71 | adjacency = targets.expand(nB, nB).eq(targets.expand(nB, nB).t())
72 | adjacency_not = ~ adjacency
73 |
74 | dist_tile = dist.repeat(nB, 1)
75 | mask = (adjacency_not.repeat(nB, 1)) * (dist_tile > (dist.transpose(0, 1).contiguous().view(-1, 1)))
76 | mask_final = (mask.float().sum(dim=1, keepdim=True) > 0).view(nB, nB).transpose(0, 1)
77 |
78 | # negatives_outside: smallest D_an where D_an > D_ap
79 |         temp1 = (dist_tile - dist_tile.max(dim=1, keepdim=True)[0]) * (mask.float())
80 |         negatives_outside = temp1.min(dim=1, keepdim=True)[0] + dist_tile.max(dim=1, keepdim=True)[0]
81 |         negatives_outside = negatives_outside.view(nB, nB).transpose(0, 1)
82 |
83 |         # negatives_inside: largest D_an
84 |         temp2 = (dist - dist.min(dim=1, keepdim=True)[0]) * (adjacency_not.float())
85 |         negatives_inside = temp2.max(dim=1, keepdim=True)[0] + dist.min(dim=1, keepdim=True)[0]
86 |         negatives_inside = negatives_inside.repeat(1, nB)
87 |
88 |         semi_hard_negatives = torch.where(mask_final, negatives_outside, negatives_inside)
89 |
90 |         loss_mat = self.margin + dist - semi_hard_negatives
91 |
92 | mask_positives = adjacency.float() - torch.eye(nB).cuda()
93 | mask_positives = mask_positives.detach()
94 | num_positives = torch.sum(mask_positives)
95 |
96 | triplet_loss = torch.sum(
97 | torch.max(torch.tensor(0, dtype=torch.float).cuda(), loss_mat * mask_positives)) / num_positives
98 | final_loss = triplet_loss * 1.0
99 |
100 | with torch.no_grad():
101 | assert normalized == True
102 | cos_theta = torch.mm(inputs, inputs.t())
103 | mask = targets.expand(nB, nB).eq(targets.expand(nB, nB).t())
104 | avg_ap = cos_theta[(mask.float() - torch.eye(nB).cuda()) == 1].mean()
105 | avg_an = cos_theta[mask.float() == 0].mean()
106 |
107 | return final_loss, avg_ap, avg_an
108 |
109 | def cross_entropy(logits, target, size_average=True):
110 | if size_average:
111 | return torch.mean(torch.sum(- target * F.log_softmax(logits, -1), -1))
112 | else:
113 | return torch.sum(torch.sum(- target * F.log_softmax(logits, -1), -1))
114 |
115 | class NpairLoss(Module):
116 | def __init__(self):
117 | super(NpairLoss, self).__init__()
118 |
119 | def forward(self, inputs, targets, normalized=False):
120 | nB = inputs.size(0)
121 |
122 | norm_temp = inputs.norm(p=2, dim=1, keepdim=True)
123 |
124 | inputs_n = inputs.div(norm_temp.expand_as(inputs))
125 | mm_logits = torch.mm(inputs_n, inputs_n.t()).detach()
126 | mask = targets.expand(nB, nB).eq(targets.expand(nB, nB).t())
127 |
128 | cos_ap = mm_logits[(mask.float() - torch.eye(nB).float().cuda()) == 1].view(nB, -1)
129 | cos_an = mm_logits[mask != 1].view(nB, -1)
130 |
131 | avg_ap = torch.mean(cos_ap)
132 | avg_an = torch.mean(cos_an)
133 |
134 | if normalized:
135 | inputs = inputs.div(norm_temp.expand_as(inputs))
136 | inputs = inputs * 5.0
137 |
138 | labels = targets.view(-1).cpu().numpy()
139 | pids = np.unique(labels)
140 |
141 | anchor_idx = []
142 | positive_idx = []
143 | for i in pids:
144 | ap_idx = np.where(labels == i)[0]
145 | anchor_idx.append(ap_idx[0])
146 | positive_idx.append(ap_idx[1])
147 |
148 | anchor = inputs[anchor_idx, :]
149 | positive = inputs[positive_idx, :]
150 |
151 | batch_size = anchor.size(0)
152 |
153 | target = torch.from_numpy(pids).cuda()
154 | target = target.view(target.size(0), 1)
155 |
156 | target = (target == torch.transpose(target, 0, 1)).float()
157 | target = target / torch.sum(target, dim=1, keepdim=True).float()
158 |
159 | logit = torch.matmul(anchor, torch.transpose(positive, 0, 1))
160 |
161 | loss_ce = cross_entropy(logit, target)
162 | loss = loss_ce * 1.0
163 |
164 | return loss, avg_ap, avg_an
165 |
166 | class MultiSimilarityLoss(Module):
167 | def __init__(self):
168 | super(MultiSimilarityLoss, self).__init__()
169 | self.thresh = 0.5
170 | self.margin = 0.1
171 | self.scale_pos = 2.0
172 | self.scale_neg = 40.0
173 |
174 | def forward(self, feats, labels):
175 |
176 | norm = feats.norm(dim=1, p=2, keepdim=True)
177 | feats = feats.div(norm.expand_as(feats))
178 |
179 | labels = labels.view(-1)
180 | assert feats.size(0) == labels.size(0), \
181 | f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}"
182 | batch_size = feats.size(0)
183 | sim_mat = torch.matmul(feats, torch.t(feats))
184 |
185 | epsilon = 1e-5
186 | loss = list()
187 |
188 | avg_aps = list()
189 | avg_ans = list()
190 |
191 | for i in range(batch_size):
192 | pos_pair_ = sim_mat[i][labels == labels[i]]
193 | pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon]
194 | neg_pair_ = sim_mat[i][labels != labels[i]]
195 |
196 | if len(neg_pair_) < 1 or len(pos_pair_) < 1:
197 | continue
198 |
199 | avg_aps.append(pos_pair_.mean())
200 | avg_ans.append(neg_pair_.mean())
201 |
202 | neg_pair = neg_pair_[neg_pair_ + self.margin > torch.min(pos_pair_)]
203 | pos_pair = pos_pair_[pos_pair_ - self.margin < torch.max(neg_pair_)]
204 |
205 | if len(neg_pair) < 1 or len(pos_pair) < 1:
206 | continue
207 |
208 | # weighting step
209 | pos_loss = 1.0 / self.scale_pos * torch.log(
210 | 1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))))
211 | neg_loss = 1.0 / self.scale_neg * torch.log(
212 | 1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))))
213 | loss.append(pos_loss + neg_loss)
214 |
215 | if len(loss) == 0:
216 | print('with ms loss = 0 !')
217 | loss = torch.zeros([], requires_grad=True).cuda()
218 | else:
219 | loss = sum(loss) / batch_size
220 | loss = loss.view(-1)
221 |
222 | avg_ap = sum(avg_aps) / batch_size
223 | avg_an = sum(avg_ans) / batch_size
224 |
225 | return loss, avg_ap, avg_an
226 |
227 |
--------------------------------------------------------------------------------
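The ```dist.addmm_(1, -2, inputs, inputs.t())``` calls in both triplet losses use the (deprecated) torch 1.1 positional signature to realize the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. A quick check with the modern keyword signature (a sketch):

```python
# Check of the squared pairwise-distance identity used in TripletLoss.
import torch

x = torch.randn(5, 8)
sq = x.pow(2).sum(dim=1, keepdim=True).expand(5, 5)   # ||x_i||^2 broadcast over rows
dist = (sq + sq.t()).addmm(x, x.t(), beta=1, alpha=-2).clamp(min=1e-12)
assert torch.allclose(dist, torch.cdist(x, x).pow(2), atol=1e-4)
```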
/models/bninception.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn.parameter import Parameter
6 |
7 | class Flatten(nn.Module):
8 | def forward(self, input):
9 | return input.view(input.size(0), -1)
10 |
11 |
12 | class BNInception(nn.Module):
13 |
14 | def __init__(self, need_bn = True):
15 | super(BNInception, self).__init__()
16 | inplace = True
17 | self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
18 | self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, affine=True)
19 | self.conv1_relu_7x7 = nn.ReLU(inplace)
20 | self.pool1_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
21 | self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
22 | self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
23 | self.conv2_relu_3x3_reduce = nn.ReLU(inplace)
24 | self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
25 | self.conv2_3x3_bn = nn.BatchNorm2d(192, affine=True)
26 | self.conv2_relu_3x3 = nn.ReLU(inplace)
27 | self.pool2_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
28 | self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
29 | self.inception_3a_1x1_bn = nn.BatchNorm2d(64, affine=True)
30 | self.inception_3a_relu_1x1 = nn.ReLU(inplace)
31 | self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
32 | self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
33 | self.inception_3a_relu_3x3_reduce = nn.ReLU(inplace)
34 | self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
35 | self.inception_3a_3x3_bn = nn.BatchNorm2d(64, affine=True)
36 | self.inception_3a_relu_3x3 = nn.ReLU(inplace)
37 | self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
38 | self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
39 | self.inception_3a_relu_double_3x3_reduce = nn.ReLU(inplace)
40 | self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
41 | self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
42 | self.inception_3a_relu_double_3x3_1 = nn.ReLU(inplace)
43 | self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
44 | self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
45 | self.inception_3a_relu_double_3x3_2 = nn.ReLU(inplace)
46 | self.inception_3a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
47 | self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
48 | self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, affine=True)
49 | self.inception_3a_relu_pool_proj = nn.ReLU(inplace)
50 | self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
51 | self.inception_3b_1x1_bn = nn.BatchNorm2d(64, affine=True)
52 | self.inception_3b_relu_1x1 = nn.ReLU(inplace)
53 | self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
54 | self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
55 | self.inception_3b_relu_3x3_reduce = nn.ReLU(inplace)
56 | self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
57 | self.inception_3b_3x3_bn = nn.BatchNorm2d(96, affine=True)
58 | self.inception_3b_relu_3x3 = nn.ReLU(inplace)
59 | self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
60 | self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
61 | self.inception_3b_relu_double_3x3_reduce = nn.ReLU(inplace)
62 | self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
63 | self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
64 | self.inception_3b_relu_double_3x3_1 = nn.ReLU(inplace)
65 | self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
66 | self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
67 | self.inception_3b_relu_double_3x3_2 = nn.ReLU(inplace)
68 | self.inception_3b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
69 | self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
70 | self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, affine=True)
71 | self.inception_3b_relu_pool_proj = nn.ReLU(inplace)
72 | self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1))
73 | self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
74 | self.inception_3c_relu_3x3_reduce = nn.ReLU(inplace)
75 | self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
76 | self.inception_3c_3x3_bn = nn.BatchNorm2d(160, affine=True)
77 | self.inception_3c_relu_3x3 = nn.ReLU(inplace)
78 | self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1))
79 | self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
80 | self.inception_3c_relu_double_3x3_reduce = nn.ReLU(inplace)
81 | self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
82 | self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
83 | self.inception_3c_relu_double_3x3_1 = nn.ReLU(inplace)
84 | self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
85 | self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
86 | self.inception_3c_relu_double_3x3_2 = nn.ReLU(inplace)
87 | self.inception_3c_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
88 | self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1))
89 | self.inception_4a_1x1_bn = nn.BatchNorm2d(224, affine=True)
90 | self.inception_4a_relu_1x1 = nn.ReLU(inplace)
91 | self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1))
92 | self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
93 | self.inception_4a_relu_3x3_reduce = nn.ReLU(inplace)
94 | self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
95 | self.inception_4a_3x3_bn = nn.BatchNorm2d(96, affine=True)
96 | self.inception_4a_relu_3x3 = nn.ReLU(inplace)
97 | self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
98 | self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
99 | self.inception_4a_relu_double_3x3_reduce = nn.ReLU(inplace)
100 | self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
101 | self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
102 | self.inception_4a_relu_double_3x3_1 = nn.ReLU(inplace)
103 | self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
104 | self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
105 | self.inception_4a_relu_double_3x3_2 = nn.ReLU(inplace)
106 | self.inception_4a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
107 | self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
108 | self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
109 | self.inception_4a_relu_pool_proj = nn.ReLU(inplace)
110 | self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1))
111 | self.inception_4b_1x1_bn = nn.BatchNorm2d(192, affine=True)
112 | self.inception_4b_relu_1x1 = nn.ReLU(inplace)
113 | self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
114 | self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
115 | self.inception_4b_relu_3x3_reduce = nn.ReLU(inplace)
116 | self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
117 | self.inception_4b_3x3_bn = nn.BatchNorm2d(128, affine=True)
118 | self.inception_4b_relu_3x3 = nn.ReLU(inplace)
119 | self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
120 | self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
121 | self.inception_4b_relu_double_3x3_reduce = nn.ReLU(inplace)
122 | self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
123 | self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
124 | self.inception_4b_relu_double_3x3_1 = nn.ReLU(inplace)
125 | self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
126 | self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
127 | self.inception_4b_relu_double_3x3_2 = nn.ReLU(inplace)
128 | self.inception_4b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
129 | self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
130 | self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
131 | self.inception_4b_relu_pool_proj = nn.ReLU(inplace)
132 | self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1))
133 | self.inception_4c_1x1_bn = nn.BatchNorm2d(160, affine=True)
134 | self.inception_4c_relu_1x1 = nn.ReLU(inplace)
135 | self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
136 | self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
137 | self.inception_4c_relu_3x3_reduce = nn.ReLU(inplace)
138 | self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
139 | self.inception_4c_3x3_bn = nn.BatchNorm2d(160, affine=True)
140 | self.inception_4c_relu_3x3 = nn.ReLU(inplace)
141 | self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
142 | self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
143 | self.inception_4c_relu_double_3x3_reduce = nn.ReLU(inplace)
144 | self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
145 | self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, affine=True)
146 | self.inception_4c_relu_double_3x3_1 = nn.ReLU(inplace)
147 | self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
148 | self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, affine=True)
149 | self.inception_4c_relu_double_3x3_2 = nn.ReLU(inplace)
150 | self.inception_4c_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
151 | self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
152 | self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
153 | self.inception_4c_relu_pool_proj = nn.ReLU(inplace)
154 | self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1))
155 | self.inception_4d_1x1_bn = nn.BatchNorm2d(96, affine=True)
156 | self.inception_4d_relu_1x1 = nn.ReLU(inplace)
157 | self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
158 | self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
159 | self.inception_4d_relu_3x3_reduce = nn.ReLU(inplace)
160 | self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
161 | self.inception_4d_3x3_bn = nn.BatchNorm2d(192, affine=True)
162 | self.inception_4d_relu_3x3 = nn.ReLU(inplace)
163 | self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1))
164 | self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
165 | self.inception_4d_relu_double_3x3_reduce = nn.ReLU(inplace)
166 | self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
167 | self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, affine=True)
168 | self.inception_4d_relu_double_3x3_1 = nn.ReLU(inplace)
169 | self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
170 | self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, affine=True)
171 | self.inception_4d_relu_double_3x3_2 = nn.ReLU(inplace)
172 | self.inception_4d_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
173 | self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
174 | self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
175 | self.inception_4d_relu_pool_proj = nn.ReLU(inplace)
176 | self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
177 | self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
178 | self.inception_4e_relu_3x3_reduce = nn.ReLU(inplace)
179 | self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
180 | self.inception_4e_3x3_bn = nn.BatchNorm2d(192, affine=True)
181 | self.inception_4e_relu_3x3 = nn.ReLU(inplace)
182 | self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1))
183 | self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
184 | self.inception_4e_relu_double_3x3_reduce = nn.ReLU(inplace)
185 | self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
186 | self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, affine=True)
187 | self.inception_4e_relu_double_3x3_1 = nn.ReLU(inplace)
188 | self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
189 | self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, affine=True)
190 | self.inception_4e_relu_double_3x3_2 = nn.ReLU(inplace)
191 | self.inception_4e_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
192 | self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1))
193 | self.inception_5a_1x1_bn = nn.BatchNorm2d(352, affine=True)
194 | self.inception_5a_relu_1x1 = nn.ReLU(inplace)
195 | self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1))
196 | self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
197 | self.inception_5a_relu_3x3_reduce = nn.ReLU(inplace)
198 | self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
199 | self.inception_5a_3x3_bn = nn.BatchNorm2d(320, affine=True)
200 | self.inception_5a_relu_3x3 = nn.ReLU(inplace)
201 | self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1))
202 | self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
203 | self.inception_5a_relu_double_3x3_reduce = nn.ReLU(inplace)
204 | self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
205 | self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
206 | self.inception_5a_relu_double_3x3_1 = nn.ReLU(inplace)
207 | self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
208 | self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
209 | self.inception_5a_relu_double_3x3_2 = nn.ReLU(inplace)
210 | self.inception_5a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
211 | self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1))
212 | self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
213 | self.inception_5a_relu_pool_proj = nn.ReLU(inplace)
214 | self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1))
215 | self.inception_5b_1x1_bn = nn.BatchNorm2d(352, affine=True)
216 | self.inception_5b_relu_1x1 = nn.ReLU(inplace)
217 | self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
218 | self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
219 | self.inception_5b_relu_3x3_reduce = nn.ReLU(inplace)
220 | self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
221 | self.inception_5b_3x3_bn = nn.BatchNorm2d(320, affine=True)
222 | self.inception_5b_relu_3x3 = nn.ReLU(inplace)
223 | self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
224 | self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
225 | self.inception_5b_relu_double_3x3_reduce = nn.ReLU(inplace)
226 | self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
227 | self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
228 | self.inception_5b_relu_double_3x3_1 = nn.ReLU(inplace)
229 | self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
230 | self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
231 | self.inception_5b_relu_double_3x3_2 = nn.ReLU(inplace)
232 | self.inception_5b_pool = nn.MaxPool2d((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True)
233 | self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1))
234 | self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
235 | self.inception_5b_relu_pool_proj = nn.ReLU(inplace)
236 |
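    |         # embedding head: global average pooling, optional BatchNorm,
    |         # flatten, then a 1024 -> 512 linear projection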
237 | if need_bn:
238 | self.head = nn.Sequential(OrderedDict([
239 | ('avgpool', nn.AdaptiveAvgPool2d(1)),
240 | ('bn', nn.BatchNorm2d(1024, eps=1e-5)),
241 | ('flatten', Flatten()),
242 | ('fc', nn.Linear(in_features=1024, out_features=512)),
243 | ]))
244 | else:
245 | self.head = nn.Sequential(OrderedDict([
246 | ('avgpool', nn.AdaptiveAvgPool2d(1)),
247 | ('flatten', Flatten()),
248 | ('fc', nn.Linear(in_features=1024, out_features=512)),
249 | ]))
250 |
251 | def features(self, input):
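    |         """Unrolled forward pass of the backbone: every branch is
    |         conv -> BN -> ReLU, and each inception block concatenates its
    |         branch outputs along the channel dimension."""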
252 | conv1_7x7_s2_out = self.conv1_7x7_s2(input)
253 | conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out)
254 | conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out)
255 | pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_relu_7x7_out)
256 | conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out)
257 | conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out)
258 | conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out)
259 | conv2_3x3_out = self.conv2_3x3(conv2_relu_3x3_reduce_out)
260 | conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out)
261 | conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out)
262 | pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_relu_3x3_out)
263 |
264 | inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out)
265 | inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out)
266 | inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out)
267 | inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out)
268 | inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out)
269 | inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out)
270 | inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_relu_3x3_reduce_out)
271 | inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out)
272 | inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out)
273 | inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out)
274 | inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn(
275 | inception_3a_double_3x3_reduce_out)
276 | inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce(
277 | inception_3a_double_3x3_reduce_bn_out)
278 | inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_relu_double_3x3_reduce_out)
279 | inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out)
280 | inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out)
281 | inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_relu_double_3x3_1_out)
282 | inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out)
283 | inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out)
284 | inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out)
285 | inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out)
286 | inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out)
287 | inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out)
288 | inception_3a_output_out = torch.cat(
289 | [inception_3a_relu_1x1_out, inception_3a_relu_3x3_out, inception_3a_relu_double_3x3_2_out,
290 | inception_3a_relu_pool_proj_out], 1)
291 |
292 | inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out)
293 | inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out)
294 | inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out)
295 | inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out)
296 | inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out)
297 | inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out)
298 | inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_relu_3x3_reduce_out)
299 | inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out)
300 | inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out)
301 | inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out)
302 | inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn(
303 | inception_3b_double_3x3_reduce_out)
304 | inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce(
305 | inception_3b_double_3x3_reduce_bn_out)
306 | inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_relu_double_3x3_reduce_out)
307 | inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out)
308 | inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out)
309 | inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_relu_double_3x3_1_out)
310 | inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out)
311 | inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out)
312 | inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out)
313 | inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out)
314 | inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out)
315 | inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out)
316 | inception_3b_output_out = torch.cat(
317 | [inception_3b_relu_1x1_out, inception_3b_relu_3x3_out, inception_3b_relu_double_3x3_2_out,
318 | inception_3b_relu_pool_proj_out], 1)
319 |
320 | inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out)
321 | inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out)
322 | inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out)
323 | inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_relu_3x3_reduce_out)
324 | inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out)
325 | inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out)
326 | inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out)
327 | inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn(
328 | inception_3c_double_3x3_reduce_out)
329 | inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce(
330 | inception_3c_double_3x3_reduce_bn_out)
331 | inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_relu_double_3x3_reduce_out)
332 | inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out)
333 | inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out)
334 | inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_relu_double_3x3_1_out)
335 | inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out)
336 | inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out)
337 | inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out)
338 | inception_3c_output_out = torch.cat(
339 | [inception_3c_relu_3x3_out, inception_3c_relu_double_3x3_2_out, inception_3c_pool_out], 1)
340 |
341 | inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out)
342 | inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out)
343 | inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out)
344 | inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out)
345 | inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out)
346 | inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out)
347 | inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_relu_3x3_reduce_out)
348 | inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out)
349 | inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out)
350 | inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out)
351 | inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn(
352 | inception_4a_double_3x3_reduce_out)
353 | inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce(
354 | inception_4a_double_3x3_reduce_bn_out)
355 | inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_relu_double_3x3_reduce_out)
356 | inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out)
357 | inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out)
358 | inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_relu_double_3x3_1_out)
359 | inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out)
360 | inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out)
361 | inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out)
362 | inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out)
363 | inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out)
364 | inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out)
365 | inception_4a_output_out = torch.cat(
366 | [inception_4a_relu_1x1_out, inception_4a_relu_3x3_out, inception_4a_relu_double_3x3_2_out,
367 | inception_4a_relu_pool_proj_out], 1)
368 |
369 | inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out)
370 | inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out)
371 | inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out)
372 | inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out)
373 | inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out)
374 | inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out)
375 | inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_relu_3x3_reduce_out)
376 | inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out)
377 | inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out)
378 | inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out)
379 | inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn(
380 | inception_4b_double_3x3_reduce_out)
381 | inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce(
382 | inception_4b_double_3x3_reduce_bn_out)
383 | inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_relu_double_3x3_reduce_out)
384 | inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out)
385 | inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out)
386 | inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_relu_double_3x3_1_out)
387 | inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out)
388 | inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out)
389 | inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out)
390 | inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out)
391 | inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out)
392 | inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out)
393 | inception_4b_output_out = torch.cat(
394 | [inception_4b_relu_1x1_out, inception_4b_relu_3x3_out, inception_4b_relu_double_3x3_2_out,
395 | inception_4b_relu_pool_proj_out], 1)
396 |
397 | inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out)
398 | inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out)
399 | inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out)
400 | inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out)
401 | inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out)
402 | inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out)
403 | inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_relu_3x3_reduce_out)
404 | inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out)
405 | inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out)
406 | inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out)
407 | inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn(
408 | inception_4c_double_3x3_reduce_out)
409 | inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce(
410 | inception_4c_double_3x3_reduce_bn_out)
411 | inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_relu_double_3x3_reduce_out)
412 | inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out)
413 | inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out)
414 | inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_relu_double_3x3_1_out)
415 | inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out)
416 | inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out)
417 | inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out)
418 | inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out)
419 | inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out)
420 | inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out)
421 | inception_4c_output_out = torch.cat(
422 | [inception_4c_relu_1x1_out, inception_4c_relu_3x3_out, inception_4c_relu_double_3x3_2_out,
423 | inception_4c_relu_pool_proj_out], 1)
424 |
425 | inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out)
426 | inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out)
427 | inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out)
428 | inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out)
429 | inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out)
430 | inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out)
431 | inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_relu_3x3_reduce_out)
432 | inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out)
433 | inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out)
434 | inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out)
435 | inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn(
436 | inception_4d_double_3x3_reduce_out)
437 | inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce(
438 | inception_4d_double_3x3_reduce_bn_out)
439 | inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_relu_double_3x3_reduce_out)
440 | inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out)
441 | inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out)
442 | inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_relu_double_3x3_1_out)
443 | inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out)
444 | inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out)
445 | inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out)
446 | inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out)
447 | inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out)
448 | inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out)
449 | inception_4d_output_out = torch.cat(
450 | [inception_4d_relu_1x1_out, inception_4d_relu_3x3_out, inception_4d_relu_double_3x3_2_out,
451 | inception_4d_relu_pool_proj_out], 1)
452 |
453 | inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out)
454 | inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out)
455 | inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out)
456 | inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_relu_3x3_reduce_out)
457 | inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out)
458 | inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out)
459 | inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out)
460 | inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn(
461 | inception_4e_double_3x3_reduce_out)
462 | inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce(
463 | inception_4e_double_3x3_reduce_bn_out)
464 | inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_relu_double_3x3_reduce_out)
465 | inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out)
466 | inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out)
467 | inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_relu_double_3x3_1_out)
468 | inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out)
469 | inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out)
470 | inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out)
471 | inception_4e_output_out = torch.cat(
472 | [inception_4e_relu_3x3_out, inception_4e_relu_double_3x3_2_out, inception_4e_pool_out], 1)
473 |
474 | inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out)
475 | inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out)
476 | inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out)
477 | inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out)
478 | inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out)
479 | inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out)
480 | inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_relu_3x3_reduce_out)
481 | inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out)
482 | inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out)
483 | inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out)
484 | inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn(
485 | inception_5a_double_3x3_reduce_out)
486 | inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce(
487 | inception_5a_double_3x3_reduce_bn_out)
488 | inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_relu_double_3x3_reduce_out)
489 | inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out)
490 | inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out)
491 | inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_relu_double_3x3_1_out)
492 | inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out)
493 | inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out)
494 | inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out)
495 | inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out)
496 | inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out)
497 | inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out)
498 | inception_5a_output_out = torch.cat(
499 | [inception_5a_relu_1x1_out, inception_5a_relu_3x3_out, inception_5a_relu_double_3x3_2_out,
500 | inception_5a_relu_pool_proj_out], 1)
501 |
502 | inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out)
503 | inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out)
504 | inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out)
505 | inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out)
506 | inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out)
507 | inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out)
508 | inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_relu_3x3_reduce_out)
509 | inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out)
510 | inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out)
511 | inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out)
512 | inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn(
513 | inception_5b_double_3x3_reduce_out)
514 | inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce(
515 | inception_5b_double_3x3_reduce_bn_out)
516 | inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_relu_double_3x3_reduce_out)
517 | inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out)
518 | inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out)
519 | inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_relu_double_3x3_1_out)
520 | inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out)
521 | inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out)
522 | inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out)
523 | inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out)
524 | inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out)
525 | inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out)
526 | inception_5b_output_out = torch.cat(
527 | [inception_5b_relu_1x1_out, inception_5b_relu_3x3_out, inception_5b_relu_double_3x3_2_out,
528 | inception_5b_relu_pool_proj_out], 1)
529 |
530 | return inception_5b_output_out
531 |
532 | def forward(self, input, normalized=False):
533 | bbout = self.features(input)
534 | x = self.head(bbout)
535 |
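    |         # optionally project the embedding onto the unit hypersphere,
    |         # so downstream losses can work with cosine similarity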
536 | if normalized:
537 | norm = x.norm(dim=1, p=2, keepdim=True)
538 | x = x.div(norm.expand_as(x))
539 |
540 | return x
541 |
--------------------------------------------------------------------------------
/mytrain.py:
--------------------------------------------------------------------------------
1 | import myutils
2 | from config import get_config
3 | from learner import metric_learner
8 |
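    | # Entry point: build the config, construct the learner, initialise the
    | # backbone from pretrained BN-Inception weights, then run training.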
9 | if __name__ == '__main__':
10 |
11 | conf = get_config()
12 |
13 | learner = metric_learner(conf)
14 |
15 | learner.load_bninception_pretrained(conf)
16 |
17 | learner.train(conf)
18 |
19 |
--------------------------------------------------------------------------------
/mytrain.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | export PYTHONPATH=.
3 |
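    | # Example: train on CUB with triplet loss and 3 instances per class in
    | # each batch. lr and lr_p appear to set separate learning rates for the
    | # new and pretrained parameters (see config.py), lr_gamma is the decay
    | # factor, and sec_wei is presumably the weight of the SEC term.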
4 | python mytrain.py \
5 | --use_dataset CUB \
6 | --instances 3 \
7 | --use_loss triplet \
8 | --lr 0.5e-5 \
9 | --lr_p 0.25e-5 \
10 | --lr_gamma 0.1 \
11 | --sec_wei 1.0
12 |
--------------------------------------------------------------------------------
/myutils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import subprocess, glob, re
4 | import numpy as np
5 | import logging
6 | import collections, copy, datetime
7 | from os import path as osp
8 | import time
9 |
10 | root_path = osp.abspath(osp.dirname(__file__)) + '/'
11 | sys.path.insert(0, root_path)
12 |
13 | def set_stream_logger(log_level=logging.DEBUG):
14 | import colorlog
15 | sh = colorlog.StreamHandler()
16 | sh.setLevel(log_level)
17 | sh.setFormatter(
18 | colorlog.ColoredFormatter(
19 | ' %(asctime)s %(filename)s [line:%(lineno)d] %(log_color)s%(levelname)s%(reset)s %(message)s'))
20 | logging.root.addHandler(sh)
21 |
22 | def set_file_logger(work_dir=None, log_level=logging.DEBUG):
23 | work_dir = work_dir or root_path
24 | fh = logging.FileHandler(os.path.join(work_dir, 'log-ing'))
25 | fh.setLevel(log_level)
26 | fh.setFormatter(
27 | logging.Formatter('%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'))
28 | logging.root.addHandler(fh)
29 |
30 | logging.root.setLevel(logging.INFO)
31 | set_stream_logger(logging.DEBUG)
32 |
33 | def shell(cmd, block=True, return_msg=True, verbose=True, timeout=None):
34 | import os
35 | my_env = os.environ.copy()
36 | home = os.path.expanduser('~')
37 | my_env['http_proxy'] = ''
38 | my_env['https_proxy'] = ''
39 | if verbose:
40 | logging.info('cmd is ' + cmd)
41 | if block:
42 |
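    |         # preexec_fn=os.setsid starts the child in its own session so the
    |         # whole process group can be signalled together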
43 | task = subprocess.Popen(cmd,
44 | shell=True,
45 | stdout=subprocess.PIPE,
46 | stderr=subprocess.PIPE,
47 | env=my_env,
48 | preexec_fn=os.setsid
49 | )
50 | if return_msg:
51 |             msg = task.communicate(timeout=timeout)
52 | msg = [msg_.decode('utf-8') for msg_ in msg]
53 | if msg[0] != '' and verbose:
54 | logging.info('stdout {}'.format(msg[0]))
55 | if msg[1] != '' and verbose:
56 | logging.error('stderr {}'.format(msg[1]))
57 | return msg
58 | else:
59 | return task
60 | else:
61 | logging.debug('Non-block!')
62 | task = subprocess.Popen(cmd,
63 | shell=True,
64 | stdout=subprocess.PIPE,
65 | stderr=subprocess.PIPE,
66 | env=my_env,
67 | preexec_fn=os.setsid
68 | )
69 | return task
70 |
71 | def rm(path, block=True):
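    |     # "soft" remove: prefer `trash` when installed, otherwise rename the
    |     # target to the next unused {path}.bakN backup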
72 | path = osp.abspath(path)
73 |     if not osp.exists(path):
74 |         logging.info(f'no need rm {path}')
    |         return
75 | stdout, _ = shell('which trash', verbose=False)
76 | if 'trash' not in stdout:
77 | dst = glob.glob('{}.bak*'.format(path))
78 | parsr = re.compile(r'{}.bak(\d+?)'.format(path))
79 | used = [0, ]
80 | for d in dst:
81 | m = re.match(parsr, d)
82 | if not m:
83 | used.append(0)
84 | elif m.groups()[0] == '':
85 | used.append(0)
86 | else:
87 | used.append(int(m.groups()[0]))
88 | dst_path = '{}.bak{}'.format(path, max(used) + 1)
89 | cmd = 'mv {} {} '.format(path, dst_path)
90 | return shell(cmd, block=block)
91 | else:
92 | return shell(f'trash -r {path}', block=block)
93 |
94 | def mkdir_p(path, delete=True):
95 | path = str(path)
96 | if path == '':
97 | return
98 | if delete and osp.exists(path):
99 | rm(path)
100 | if not osp.exists(path):
101 | shell('mkdir -p ' + path)
102 |
103 |
104 | class Logger(object):
105 | def __init__(self, fpath=None, console=sys.stdout):
106 | self.console = console
107 | self.file = None
108 | if fpath is not None:
109 | mkdir_p(os.path.dirname(fpath), delete=False)
110 |
111 | self.file = open(fpath, 'a')
112 |
113 | def __del__(self):
114 | self.close()
115 |
116 | def __enter__(self):
117 |         return self
118 |
119 | def __exit__(self, *args):
120 | self.close()
121 |
122 | def write(self, msg):
123 | self.console.write(msg)
124 | if self.file is not None:
125 | self.file.write(msg)
126 |
127 | def flush(self):
128 | self.console.flush()
129 | if self.file is not None:
130 | self.file.flush()
131 | os.fsync(self.file.fileno())
132 |
133 | def close(self):
134 |         if self.console is not sys.stdout:
    |             self.console.close()
135 | if self.file is not None:
136 | self.file.close()
137 |
138 |
139 | class Timer(object):
140 | """A flexible Timer class.
141 |
142 | :Example:
143 |
144 | >>> import time
145 |     >>> from myutils import Timer
146 |     >>> with Timer():
147 | >>> # simulate a code block that will run for 1s
148 | >>> time.sleep(1)
149 | 1.000
150 |     >>> with Timer(print_tmpl='hey it takes {:.1f} seconds'):
151 |     >>> # simulate a code block that will run for 1s
152 |     >>> time.sleep(1)
153 |     hey it takes 1.0 seconds
154 |     >>> timer = Timer()
155 | >>> time.sleep(0.5)
156 | >>> print(timer.since_start())
157 | 0.500
158 | >>> time.sleep(0.5)
159 | >>> print(timer.since_last_check())
160 | 0.500
161 | >>> print(timer.since_start())
162 | 1.000
163 |
164 | """
165 |
166 | def __init__(self, start=True, print_tmpl=None):
167 | self._is_running = False
168 | self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
169 | if start:
170 | self.start()
171 |
172 | @property
173 | def is_running(self):
174 | """bool: indicate whether the timer is running"""
175 | return self._is_running
176 |
177 | def __enter__(self):
178 | self.start()
179 | return self
180 |
181 | def __exit__(self, type, value, traceback):
182 | print(self.print_tmpl.format(self.since_last_check()))
183 | self._is_running = False
184 |
185 | def start(self):
186 | """Start the timer."""
187 | if not self._is_running:
188 | self._t_start = time.time()
189 | self._is_running = True
190 | self._t_last = time.time()
191 |
192 | def since_start(self, aux=''):
193 |         """Total time since the timer was started.
194 |
195 | Returns(float): the time in seconds
196 | """
197 | if not self._is_running:
198 | raise ValueError('timer is not running')
199 | self._t_last = time.time()
200 | logging.info(f'{aux} time {self.print_tmpl.format(self._t_last - self._t_start)}')
201 | return self._t_last - self._t_start
202 |
203 | def since_last_check(self, aux='', verbose=True):
204 |         """Time since the last check.
205 | 
206 |         Either :func:`since_start` or :func:`since_last_check` counts as a check.
207 |
208 | Returns(float): the time in seconds
209 | """
210 | if not self._is_running:
211 | raise ValueError('timer is not running')
212 | dur = time.time() - self._t_last
213 | self._t_last = time.time()
214 | if verbose:
215 | logging.info(f'{aux} time {self.print_tmpl.format(dur)}')
216 | return dur
217 |
218 |
219 | class AverageMeter(object):
220 |     """Stores the current value and a sliding-window average of the last `maxlen` values"""
221 |
222 | def __init__(self, maxlen=100):
223 |
224 | self.val = 0
225 | self.avg = 0
226 | self.sum = 0
227 | self.count = 0
228 | self.mem = collections.deque(maxlen=maxlen)
229 |
230 | def reset(self):
231 | self.val = 0
232 | self.avg = 0
233 | self.sum = 0
234 | self.count = 0
235 |
236 | def update(self, val, n=1):
237 |         self.val = val = float(val)
238 | self.mem.append(val)
239 | self.avg = np.mean(list(self.mem))
240 |
241 |
242 | timer = Timer()
243 | logging.info('import myutils')
244 |
--------------------------------------------------------------------------------
/test_sop.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from config import get_config
3 | from learner import metric_learner
4 |
5 | if __name__ == '__main__':
6 | conf = get_config()
7 |
8 | test_path = 'work_space/' + conf.test_sop_model + '/models'
9 | logging.info(test_path)
10 | test_name = 'SOP'
11 |
12 | conf.use_dataset = test_name
13 | learner = metric_learner(conf, inference=True)
14 | learner.load_state(conf, resume_path=test_path)
15 |
16 | nmi, f1, recall_ks = learner.test_sop_complete(conf)
17 |
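    |     # Recall@K evaluation points per dataset; only the 'SOP' entry is used here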
18 | ks_dict = dict()
19 | ks_dict['CUB'] = [1, 2, 4, 8, 16, 32]
20 | ks_dict['Cars'] = [1, 2, 4, 8, 16, 32]
21 | ks_dict['SOP'] = [1, 10, 100, 1000, 10000]
22 | ks_dict['Inshop'] = [1, 10, 20, 30, 40, 50]
23 | k_s = ks_dict[test_name]
24 |
25 | logging.info(f'nmi: {nmi}')
26 | logging.info(f'f1: {f1}')
27 | for i in range(len(recall_ks)):
28 | logging.info(f'R{k_s[i]} {recall_ks[i]}')
29 |
--------------------------------------------------------------------------------
/test_sop.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | export PYTHONPATH=.
3 |
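    | # test_sop_model names the checkpoint directory under work_space/;
    | # SOP_0000_0000 is a placeholder for a real run name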
4 | python test_sop.py \
5 | --use_dataset SOP \
6 | --test_sop_model SOP_0000_0000
7 |
--------------------------------------------------------------------------------