├── .gitignore ├── DataSet ├── CUB200.py ├── Car196.py ├── In_shop_clothes.py ├── Products.py ├── __init__.py └── transforms.py ├── LICENSE ├── Model2Feature.py ├── README.md ├── __init__.py ├── data_list_example.txt ├── evaluations ├── NMI.py ├── __init__.py ├── cnn.py ├── extract_featrure.py ├── recall_at_k.py └── top_k.py ├── losses ├── Binomial.py ├── Contrastive.py ├── HardMining.py ├── LiftedStructure.py ├── NCA.py ├── SemiHard.py ├── __init__.py └── triplet.py ├── models ├── BN_Inception.py └── __init__.py ├── run_train_00.sh ├── test.py ├── train.py ├── trainer.py └── utils ├── Batch_generator.py ├── HyperparamterDisplay.py ├── __init__.py ├── cluster.py ├── logging.py ├── map.py ├── meters.py ├── numpy_tozero.py ├── orthogonal_regularizaton.py ├── osutils.py ├── sampler.py ├── serialization.py └── str2nums.py /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | 3 | # temporary files which can be created if a process still has a handle open of a deleted file 4 | .fuse_hidden* 5 | 6 | # KDE directory preferences 7 | .directory 8 | 9 | # Linux trash folder which might appear on any partition or disk 10 | .Trash-* 11 | 12 | # .nfs files are created when an open file is removed but is still being accessed 13 | .nfs* 14 | *.npy 15 | 16 | *.DS_Store 17 | .AppleDouble 18 | .LSOverride 19 | 20 | # Icon must end with two \r 21 | Icon 22 | 23 | 24 | # Thumbnails 25 | ._* 26 | 27 | # Files that might appear in the root of a volume 28 | .DocumentRevisions-V100 29 | .fseventsd 30 | .Spotlight-V100 31 | .TemporaryItems 32 | .Trashes 33 | .VolumeIcon.icns 34 | .com.apple.timemachine.donotpresent 35 | 36 | # Directories potentially created on remote AFP share 37 | .AppleDB 38 | .AppleDesktop 39 | Network Trash Folder 40 | Temporary Items 41 | .apdisk 42 | 43 | 44 | # swap 45 | [._]*.s[a-v][a-z] 46 | [._]*.sw[a-p] 47 | [._]s[a-v][a-z] 48 | [._]sw[a-p] 49 | # session 50 | Session.vim 51 | # temporary 52 | .netrwhist 53 | *~ 54 | # auto-generated tag files 55 | tags 56 | 57 | 58 | # cache files for sublime text 59 | *.tmlanguage.cache 60 | *.tmPreferences.cache 61 | *.stTheme.cache 62 | 63 | # workspace files are user-specific 64 | *.sublime-workspace 65 | 66 | # project files should be checked into the repository, unless a significant 67 | # proportion of contributors will probably not be using SublimeText 68 | # *.sublime-project 69 | 70 | # sftp configuration file 71 | sftp-config.json 72 | 73 | # Package control specific files 74 | Package Control.last-run 75 | Package Control.ca-list 76 | Package Control.ca-bundle 77 | Package Control.system-ca-bundle 78 | Package Control.cache/ 79 | Package Control.ca-certs/ 80 | Package Control.merged-ca-bundle 81 | Package Control.user-ca-bundle 82 | oscrypto-ca-bundle.crt 83 | bh_unicode_properties.cache 84 | 85 | # Sublime-github package stores a github token in this file 86 | # https://packagecontrol.io/packages/sublime-github 87 | GitHub.sublime-settings 88 | 89 | 90 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 91 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 92 | 93 | # User-specific stuff: 94 | .idea 95 | .idea/**/workspace.xml 96 | .idea/**/tasks.xml 97 | 98 | # Sensitive or high-churn files: 99 | .idea/**/dataSources/ 100 | .idea/**/dataSources.ids 101 | .idea/**/dataSources.xml 102 | .idea/**/dataSources.local.xml 103 | .idea/**/sqlDataSources.xml 104 | .idea/**/dynamic.xml 105 | 
.idea/**/uiDesigner.xml 106 | 107 | # Gradle: 108 | .idea/**/gradle.xml 109 | .idea/**/libraries 110 | 111 | # Mongo Explorer plugin: 112 | .idea/**/mongoSettings.xml 113 | 114 | ## File-based project format: 115 | *.iws 116 | 117 | ## Plugin-specific files: 118 | 119 | # IntelliJ 120 | /out/ 121 | 122 | # mpeltonen/sbt-idea plugin 123 | .idea_modules/ 124 | 125 | # JIRA plugin 126 | atlassian-ide-plugin.xml 127 | 128 | # Crashlytics plugin (for Android Studio and IntelliJ) 129 | com_crashlytics_export_strings.xml 130 | crashlytics.properties 131 | crashlytics-build.properties 132 | fabric.properties 133 | 134 | 135 | # Byte-compiled / optimized / DLL files 136 | __pycache__/ 137 | *.py[cod] 138 | *$py.class 139 | 140 | # C extensions 141 | *.so 142 | 143 | # Distribution / packaging 144 | .Python 145 | env/ 146 | build/ 147 | develop-eggs/ 148 | dist/ 149 | downloads/ 150 | eggs/ 151 | .eggs/ 152 | lib/ 153 | lib64/ 154 | parts/ 155 | sdist/ 156 | var/ 157 | *.egg-info/ 158 | .installed.cfg 159 | *.egg 160 | 161 | # PyInstaller 162 | # Usually these files are written by a python script from a template 163 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 164 | *.manifest 165 | *.spec 166 | 167 | # Installer logs 168 | pip-log.txt 169 | pip-delete-this-directory.txt 170 | 171 | # Unit test / coverage reports 172 | htmlcov/ 173 | .tox/ 174 | .coverage 175 | .coverage.* 176 | .cache 177 | nosetests.xml 178 | coverage.xml 179 | *,cover 180 | .hypothesis/ 181 | 182 | # Translations 183 | *.mo 184 | *.pot 185 | 186 | # Django stuff: 187 | *.log 188 | local_settings.py 189 | 190 | # Flask stuff: 191 | instance/ 192 | .webassets-cache 193 | 194 | # Scrapy stuff: 195 | .scrapy 196 | 197 | # Sphinx documentation 198 | docs/_build/ 199 | 200 | # PyBuilder 201 | target/ 202 | 203 | # IPython Notebook 204 | .ipynb_checkpoints 205 | 206 | # pyenv 207 | .python-version 208 | .npy 209 | 210 | # celery beat schedule file 211 | celerybeat-schedule 212 | 213 | # dotenv 214 | .env 215 | 216 | # virtualenv 217 | venv/ 218 | ENV/ 219 | 220 | # Spyder project settings 221 | .spyderproject 222 | 223 | # Rope project settings 224 | .ropeproject 225 | 226 | 227 | # Project specific 228 | DataSet/Car196/ 229 | DataSet/CUB_200_2011/ 230 | DataSet/Products/ 231 | pretrained* 232 | checkpoint* 233 | *.npy 234 | 235 | # No loss will update before my paper 236 | #losses/* 237 | -------------------------------------------------------------------------------- /DataSet/CUB200.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | """ 3 | CUB-200-2011 data-set for Pytorch 4 | """ 5 | import torch 6 | import torch.utils.data as data 7 | from PIL import Image 8 | 9 | import os 10 | import sys 11 | from DataSet import transforms 12 | from collections import defaultdict 13 | 14 | 15 | def default_loader(path): 16 | return Image.open(path).convert('RGB') 17 | 18 | def Generate_transform_Dict(origin_width=256, width=227, ratio=0.16): 19 | 20 | std_value = 1.0 / 255.0 21 | normalize = transforms.Normalize(mean=[104 / 255.0, 117 / 255.0, 128 / 255.0], 22 | std= [1.0/255, 1.0/255, 1.0/255]) 23 | 24 | transform_dict = {} 25 | 26 | transform_dict['rand-crop'] = \ 27 | transforms.Compose([ 28 | transforms.CovertBGR(), 29 | transforms.Resize((origin_width)), 30 | transforms.RandomResizedCrop(scale=(ratio, 1), size=width), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ToTensor(), 33 | normalize, 34 | ]) 35 
| 36 | transform_dict['center-crop'] = \ 37 | transforms.Compose([ 38 | transforms.CovertBGR(), 39 | transforms.Resize((origin_width)), 40 | transforms.CenterCrop(width), 41 | transforms.ToTensor(), 42 | normalize, 43 | ]) 44 | 45 | transform_dict['resize'] = \ 46 | transforms.Compose([ 47 | transforms.CovertBGR(), 48 | transforms.Resize((width)), 49 | transforms.ToTensor(), 50 | normalize, 51 | ]) 52 | return transform_dict 53 | 54 | 55 | class MyData(data.Dataset): 56 | def __init__(self, root=None, label_txt=None, 57 | transform=None, loader=default_loader): 58 | 59 | # Initialization data path and train(gallery or query) txt path 60 | if root is None: 61 | self.root = "data/cub/" 62 | self.root = root 63 | 64 | if label_txt is None: 65 | label_txt = os.path.join(root, 'train.txt') 66 | 67 | if transform is None: 68 | transform_dict = Generate_transform_Dict()['rand-crop'] 69 | 70 | file = open(label_txt) 71 | images_anon = file.readlines() 72 | 73 | images = [] 74 | labels = [] 75 | 76 | for img_anon in images_anon: 77 | 78 | [img, label] = img_anon.split(' ') 79 | images.append(img) 80 | labels.append(int(label)) 81 | 82 | classes = list(set(labels)) 83 | 84 | # Generate Index Dictionary for every class 85 | Index = defaultdict(list) 86 | for i, label in enumerate(labels): 87 | Index[label].append(i) 88 | 89 | # Initialization Done 90 | self.root = root 91 | self.images = images 92 | self.labels = labels 93 | self.classes = classes 94 | self.transform = transform 95 | self.Index = Index 96 | self.loader = loader 97 | 98 | def __getitem__(self, index): 99 | fn, label = self.images[index], self.labels[index] 100 | fn = os.path.join(self.root, fn) 101 | img = self.loader(fn) 102 | if self.transform is not None: 103 | img = self.transform(img) 104 | return img, label 105 | 106 | def __len__(self): 107 | return len(self.images) 108 | 109 | 110 | class CUB_200_2011: 111 | def __init__(self, width=227, origin_width=256, ratio=0.16, root=None, transform=None): 112 | print('width: \t {}'.format(width)) 113 | transform_Dict = Generate_transform_Dict(origin_width=origin_width, width=width, ratio=ratio) 114 | if root is None: 115 | root = "data/CUB_200_2011/" 116 | 117 | train_txt = os.path.join(root, 'train.txt') 118 | test_txt = os.path.join(root, 'test.txt') 119 | 120 | self.train = MyData(root, label_txt=train_txt, transform=transform_Dict['rand-crop']) 121 | self.gallery = MyData(root, label_txt=test_txt, transform=transform_Dict['center-crop']) 122 | 123 | 124 | def testCUB_200_2011(): 125 | print(CUB_200_2011.__name__) 126 | data = CUB_200_2011() 127 | print(len(data.gallery)) 128 | print(len(data.train)) 129 | print(data.train[1]) 130 | 131 | 132 | if __name__ == "__main__": 133 | testCUB_200_2011() 134 | 135 | 136 | -------------------------------------------------------------------------------- /DataSet/Car196.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | """ 3 | CUB-200-2011 data-set for Pytorch 4 | """ 5 | import torch 6 | import torch.utils.data as data 7 | from PIL import Image 8 | 9 | import os 10 | from torchvision import transforms 11 | from collections import defaultdict 12 | 13 | from DataSet.CUB200 import MyData, default_loader, Generate_transform_Dict 14 | 15 | 16 | class Cars196: 17 | def __init__(self, root=None, origin_width=256, width=227, ratio=0.16, transform=None): 18 | if transform is None: 19 | transform_Dict = Generate_transform_Dict(origin_width=origin_width, 
width=width, ratio=ratio) 20 | if root is None: 21 | root = 'data/Cars196/' 22 | 23 | train_txt = os.path.join(root, 'train.txt') 24 | test_txt = os.path.join(root, 'test.txt') 25 | self.train = MyData(root, label_txt=train_txt, transform=transform_Dict['rand-crop']) 26 | self.gallery = MyData(root, label_txt=test_txt, transform=transform_Dict['center-crop']) 27 | 28 | 29 | def testCar196(): 30 | data = Cars196() 31 | print(len(data.gallery)) 32 | print(len(data.train)) 33 | print(data.train[1]) 34 | 35 | 36 | if __name__ == "__main__": 37 | testCar196() 38 | 39 | 40 | -------------------------------------------------------------------------------- /DataSet/In_shop_clothes.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | """ 3 | In-shop-clothes data-set for Pytorch 4 | """ 5 | import torch 6 | import torch.utils.data as data 7 | from PIL import Image 8 | 9 | import os 10 | from torchvision import transforms 11 | from collections import defaultdict 12 | 13 | from DataSet.CUB200 import default_loader, Generate_transform_Dict 14 | 15 | 16 | class MyData(data.Dataset): 17 | def __init__(self, root=None, label_txt=None, 18 | transform=None, loader=default_loader): 19 | 20 | # Initialization data path and train(gallery or query) txt path 21 | 22 | if root is None: 23 | root = "/home/xunwang" 24 | label_txt = os.path.join(root, 'train.txt') 25 | 26 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 27 | std=[0.229, 0.224, 0.225]) 28 | if transform is None: 29 | transform = transforms.Compose([ 30 | # transforms.CovertBGR(), 31 | transforms.Resize(256), 32 | transforms.RandomResizedCrop(scale=(0.16, 1), size=224), 33 | transforms.RandomHorizontalFlip(), 34 | transforms.ToTensor(), 35 | normalize, 36 | ]) 37 | 38 | # read txt get image path and labels 39 | file = open(label_txt) 40 | images_anon = file.readlines() 41 | 42 | images = [] 43 | labels = [] 44 | 45 | for img_anon in images_anon: 46 | img_anon = img_anon.replace(' ', '\t') 47 | 48 | [img, label] = (img_anon.split('\t'))[:2] 49 | images.append(img) 50 | labels.append(int(label)) 51 | 52 | classes = list(set(labels)) 53 | 54 | # Generate Index Dictionary for every class 55 | Index = defaultdict(list) 56 | for i, label in enumerate(labels): 57 | Index[label].append(i) 58 | 59 | # Initialization Done 60 | self.root = root 61 | self.images = images 62 | self.labels = labels 63 | self.classes = classes 64 | self.transform = transform 65 | self.Index = Index 66 | self.loader = loader 67 | 68 | def __getitem__(self, index): 69 | fn, label = self.images[index], self.labels[index] 70 | # print(os.path.join(self.root, fn)) 71 | img = self.loader(os.path.join(self.root, fn)) 72 | if self.transform is not None: 73 | img = self.transform(img) 74 | return img, label 75 | 76 | def __len__(self): 77 | return len(self.images) 78 | 79 | 80 | class InShopClothes: 81 | def __init__(self, root=None, crop=False, origin_width=256, width=224, ratio=0.16): 82 | # Data loading code 83 | transform_Dict = Generate_transform_Dict(origin_width=origin_width, width=width, ratio=ratio) 84 | 85 | if root is None: 86 | root = 'data/In_shop_clothes' 87 | 88 | train_txt = os.path.join(root, 'train.txt') 89 | gallery_txt = os.path.join(root, 'gallery.txt') 90 | query_txt = os.path.join(root, 'query.txt') 91 | 92 | self.train = MyData(root, label_txt=train_txt, transform=transform_Dict['rand-crop']) 93 | self.gallery = MyData(root, label_txt=gallery_txt, 
transform=transform_Dict['center-crop']) 94 | self.query = MyData(root, label_txt=query_txt, transform=transform_Dict['center-crop']) 95 | 96 | 97 | 98 | def testIn_Shop_Clothes(): 99 | data = InShopClothes() 100 | print(len(data.gallery), len(data.train)) 101 | print(len(data.query)) 102 | print(data.train[1][0][0][0][1]) 103 | 104 | 105 | if __name__ == "__main__": 106 | testIn_Shop_Clothes() -------------------------------------------------------------------------------- /DataSet/Products.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import os.path as osp 5 | import torchvision.transforms as transforms 6 | import torchvision.datasets as datasets 7 | from DataSet.CUB200 import default_loader, Generate_transform_Dict, MyData 8 | 9 | 10 | class Products: 11 | def __init__(self, width=224, origin_width=256, ratio=0.16, root=None, transform=None): 12 | transform_Dict = Generate_transform_Dict(origin_width=origin_width, width=width, ratio=ratio) 13 | if root is None: 14 | root = '../data/Products' 15 | 16 | train_txt = osp.join(root, 'train.txt') 17 | test_txt = osp.join(root, 'test.txt') 18 | 19 | self.train = MyData(root, label_txt=train_txt, transform=transform_Dict['rand-crop']) 20 | self.gallery = MyData(root, label_txt=test_txt, transform=transform_Dict['center-crop']) 21 | 22 | def test(): 23 | data = Products() 24 | print(data.train[1][0][0][0]) 25 | print(len(data.gallery), len(data.train)) 26 | 27 | 28 | 29 | if __name__=='__main__': 30 | test() 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /DataSet/__init__.py: -------------------------------------------------------------------------------- 1 | from .CUB200 import CUB_200_2011 2 | from .Car196 import Cars196 3 | from .Products import Products 4 | from .In_shop_clothes import InShopClothes 5 | # from .transforms import * 6 | import os 7 | 8 | __factory = { 9 | 'cub': CUB_200_2011, 10 | 'car': Cars196, 11 | 'product': Products, 12 | 'shop': InShopClothes, 13 | } 14 | 15 | 16 | def names(): 17 | return sorted(__factory.keys()) 18 | 19 | def get_full_name(name): 20 | if name not in __factory: 21 | raise KeyError("Unknown dataset:", name) 22 | return __factory[name].__name__ 23 | 24 | def create(name, root=None, *args, **kwargs): 25 | """ 26 | Create a dataset instance. 27 | """ 28 | if root is not None: 29 | root = os.path.join(root, get_full_name(name)) 30 | 31 | if name not in __factory: 32 | raise KeyError("Unknown dataset:", name) 33 | return __factory[name](root=root, *args, **kwargs) 34 | -------------------------------------------------------------------------------- /DataSet/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | 6 | 7 | class CovertBGR(object): 8 | def __init__(self): 9 | pass 10 | 11 | def __call__(self, img): 12 | r, g, b = img.split() 13 | img = Image.merge("RGB", (b, g, r)) 14 | return img 15 | 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Model2Feature.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import, print_function 3 | 4 | import torch 5 | from torch.backends import cudnn 6 | from evaluations import extract_features 7 | import models 8 | import DataSet 9 | from utils.serialization import load_checkpoint 10 | cudnn.benchmark = True 11 | 12 | 13 | def Model2Feature(data, net, checkpoint, dim=512, width=224, root=None, nThreads=16, batch_size=100, pool_feature=False, **kargs): 14 | dataset_name = data 15 | model = models.create(net, dim=dim, pretrained=False) 16 | # resume = load_checkpoint(ckp_path) 17 | resume = checkpoint 18 | model.load_state_dict(resume['state_dict']) 19 | model = torch.nn.DataParallel(model).cuda() 20 | data = DataSet.create(data, width=width, root=root) 21 | 22 | if dataset_name in ['shop', 'jd_test']: 23 | gallery_loader = torch.utils.data.DataLoader( 24 | data.gallery, batch_size=batch_size, shuffle=False, 25 | drop_last=False, pin_memory=True, num_workers=nThreads) 26 | 27 | query_loader = torch.utils.data.DataLoader( 28 | data.query, batch_size=batch_size, 29 | shuffle=False, drop_last=False, 30 | pin_memory=True, num_workers=nThreads) 31 | 32 | gallery_feature, gallery_labels = extract_features(model, gallery_loader, print_freq=1e5, metric=None, pool_feature=pool_feature) 33 | query_feature, query_labels = extract_features(model, query_loader, print_freq=1e5, metric=None, pool_feature=pool_feature) 34 | 35 | else: 36 | data_loader = torch.utils.data.DataLoader( 37 | data.gallery, batch_size=batch_size, 38 | shuffle=False, drop_last=False, pin_memory=True, 39 | num_workers=nThreads) 40 | features, labels = extract_features(model, data_loader, print_freq=1e5, metric=None, pool_feature=pool_feature) 41 | gallery_feature, gallery_labels = query_feature, query_labels = features, labels 42 | return gallery_feature, gallery_labels, query_feature, query_labels 43 | 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Deep Metric Learning in PyTorch 2 | 3 | Learn deep metric for image retrieval or other information retrieval. 
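In practice that means training a CNN to embed images and then ranking a gallery by similarity to each query. Below is a rough sketch of the evaluation path using the helpers in this repository (`Model2Feature`, `pairwise_similarity`, `Recall_at_ks`); the checkpoint path and the `net` key passed to the model factory are placeholders, not values taken from this repo, and the diagonal masking assumes gallery and query are the same split (as on CUB).

```python
# Rough evaluation sketch (hypothetical checkpoint path and model key, not the
# official test script): embed the test split with a trained model, rank the
# gallery by cosine similarity, and report Recall@K with this repo's helpers.
import torch

from Model2Feature import Model2Feature
from evaluations import pairwise_similarity, Recall_at_ks
from utils.serialization import load_checkpoint

checkpoint = load_checkpoint('checkpoints/cub_model.pth.tar')    # placeholder path
gallery_f, gallery_l, query_f, query_l = Model2Feature(
    data='cub', net='BN-Inception', checkpoint=checkpoint,       # the 'net' key is assumed
    dim=512, width=227, batch_size=100)

sim_mat = pairwise_similarity(query_f, gallery_f)
sim_mat = sim_mat - torch.eye(sim_mat.size(0))   # ignore self-matches when gallery == query
recall_ks = Recall_at_ks(sim_mat, query_ids=query_l, gallery_ids=gallery_l, data='cub')
print(recall_ks)   # [R@1, R@2, R@4, R@8, R@16, R@32] on CUB-200-2011
```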
4 | 5 | 6 | #### Our XBM was nominated for the Best Paper Award at CVPR 2020. 7 | 8 | 9 | #### A blog post on XBM (in Chinese, on Zhihu) 10 | 11 | I wrote a Zhihu article that gives a quick, accessible explanation of the idea and motivation behind XBM: 12 | 13 | [跨越时空的难样本挖掘 (Hard-example mining across time and space)](https://zhuanlan.zhihu.com/p/136522363) 14 | 15 | Comments and feedback are very welcome! 16 | 17 | 18 | 19 | I also recommend a recently released, excellent DML paper that is not written by me: 20 | 21 | #### [A Metric Learning Reality Check](https://arxiv.org/pdf/2003.08505.pdf) 22 | 23 | from Cornell Tech and Facebook AI 24 | 25 | Abstract: Deep metric learning papers from the past four years have consistently claimed great advances in accuracy, often more than doubling the performance of decade-old methods. In this paper, we take a closer look at the field to see if this is actually true. We find flaws in the experimental setup of these papers, and propose a new way to evaluate metric learning algorithms. Finally, we present experimental results that show that the improvements over time have been marginal at best. 26 | 27 |   28 | 29 | ### XBM: a new state-of-the-art method for DML, accepted to CVPR 2020 as an Oral and nominated for Best Paper: 30 | 31 | ### [Cross-Batch Memory for Embedding Learning](https://arxiv.org/pdf/1912.06798.pdf) 32 | 33 | - #### Great improvement: XBM improves R@1 by 12~25% on three large-scale datasets 34 | 35 | - #### Easy to implement: only a few lines of code 36 | 37 | - #### Memory efficient: less than 1GB of extra memory even for large-scale datasets 38 | 39 | - #### The code has already been released: [xbm](https://github.com/MalongTech/research-ms-loss/blob/master/ret_benchmark/modeling/xbm.py) 40 | 41 | #### Other implementations: 42 | [pytorch-metric-learning](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#crossbatchmemory) (a great work by Kevin Musgrave) 43 | 44 |   45 |   46 | 47 | ### MS Loss based on GPW: accepted to CVPR 2019 as a Poster 48 | ### [Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning](https://arxiv.org/pdf/1904.06627.pdf) 49 | - #### [released code](https://github.com/MalongTech/research-ms-loss) 50 | - #### [New version of the paper](https://arxiv.org/pdf/1904.06627.pdf): to make the idea easier to understand, I recently rewrote the major part of the paper. (2020-03-24) 51 | 52 | 53 | 54 |   55 | ### Deep metric learning methods implemented in this repository (a minimal usage sketch follows the list): 56 | 57 | - Contrastive Loss [1] 58 | 59 | - Semi-Hard Mining Strategy [2] 60 | 61 | - Lifted Structure Loss* [3] (modified version, because the original performs weakly) 62 | 63 | - Binomial Deviance Loss [4] 64 | 65 | - NCA Loss [6] 66 | 67 | - Multi-Similarity Loss [7] 68 |
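All of the losses above share one interface: they take a batch of embeddings with their integer labels and return `(loss, prec, mean_pos_sim, mean_neg_sim)`, and they are built through the factory in `losses/__init__.py`. A minimal sketch with made-up data (batch size, embedding dimension and margin are illustrative only, not the values used in the experiments):

```python
# Minimal sketch (not the training loop): build a loss from the factory and
# run it on a toy batch of L2-normalised embeddings.
import torch
import torch.nn.functional as F
import losses

criterion = losses.create('Contrastive', margin=0.5)

embeddings = F.normalize(torch.randn(32, 512), dim=1)  # 32 embeddings, 512-d
labels = torch.arange(4).repeat(8)                     # 4 classes, 8 samples each

loss, prec, mean_pos_sim, mean_neg_sim = criterion(embeddings, labels)
print(loss.item(), mean_pos_sim, mean_neg_sim)
```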
69 | ### Dataset 70 | - [Car-196](http://ai.stanford.edu/~jkrause/cars/car_devkit.tgz) 71 | 72 | the first 98 classes are used for training and the remaining 98 classes for testing 73 | 74 | - [CUB-200-2011](http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz) 75 | 76 | the first 100 classes are used for training and the remaining 100 classes for testing 77 | 78 | - [Stanford-Online-Products](ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip) 79 | 80 | for the experiments, we use 59,551 images of 11,318 classes for training and 60,502 images of 11,316 classes for testing 81 | 82 | - [In-Shop-clothes-Retrieval](ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip) 83 | 84 | for the In-Shop Clothes Retrieval dataset, 3,997 classes with 25,882 images are used for training; 85 | the test set is partitioned into a query set of 3,985 classes (14,218 images) and a gallery set of 3,985 classes (12,612 images). 86 | 87 | - [Processed CUB and Cars196](https://pan.baidu.com/s/1LPHi72JPupkvUy_1OIn6yA) 88 | 89 | Extract code: inmj 90 | 91 | To make the results easy to reproduce, I provide the processed CUB and Cars-196 datasets. 92 | 93 | 94 | ### Requirements 95 | * Python >= 3.5 96 | * PyTorch = 1.0 97 | 98 | ### Comparison with the state of the art on CUB-200 (left columns) and Cars-196 (right columns) 99 | 100 | |Recall@K | 1 | 2 | 4 | 8 | 16 | 32 | 1 | 2 | 4 | 8 | 16 | 32| 101 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:| 102 | |HDC | 53.6 | 65.7 | 77.0 | 85.6 | 91.5 | 95.5 | 73.7 | 83.2 | 89.5 | 93.8 | 96.7 | 98.4| 103 | |Clustering | 48.2 | 61.4 | 71.8 | 81.9 | - | - | 58.1 | 70.6 | 80.3 | 87.8 | - | -| 104 | |ProxyNCA | 49.2 | 61.9 | 67.9 | 72.4 | - | - | 73.2 | 82.4 | 86.4 | 87.8 | - | -| 105 | |Smart Mining | 49.8 | 62.3 | 74.1 | 83.3 | - | - | 64.7 | 76.2 | 84.2 | 90.2 | - | -| 106 | |Margin [5] | 63.6| 74.4| 83.1| 90.0| 94.2 | - | 79.6| 86.5| 91.9| 95.1| 97.3 | - | 107 | |HTL | 57.1| 68.8| 78.7| 86.5| 92.5| 95.5 | 81.4| 88.0| 92.7| 95.7| 97.4| 99.0 | 108 | |ABIER |57.5 |68.7 |78.3 |86.2 |91.9 |95.5 |82.0 |89.0 |93.2 |96.1 |97.8 |98.7| 109 | 110 | 111 | ### Comparison with the state of the art on SOP (left columns) and In-Shop (right columns) 112 | 113 | |Recall@K | 1 | 10 | 100 | 1000 | 1 | 10 | 20 | 30 | 40 | 50| 114 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:| 115 | |Clustering | 67.0 | 83.7 | 93.2 | - | -| -| -| -| - | -| 116 | |HDC | 69.5 | 84.4 | 92.8 | 97.7 | 62.1 | 84.9 | 89.0 | 91.2 | 92.3 | 93.1| 117 | |Margin [5] | 72.7 | 86.2 | 93.8 | 98.0 | -| -| - | -| -| -| 118 | |Proxy-NCA | 73.7 | - | - | - | -| -| - | - | -| -| 119 | |ABIER | 74.2 | 86.9 | 94.0 | 97.8 | 83.1 | 95.1 | 96.9 | 97.5 | 97.8 | 98.0| 120 | |HTL | 74.8| 88.3| 94.8| 98.4 | 80.9| 94.3| 95.8| 97.2| 97.4| 97.8 | 121 | 122 | #### See more details in our CVPR 2019 paper [Multi-Similarity Loss](https://arxiv.org/pdf/1904.06627.pdf) 123 | 124 | ##### Reproducing the Car-196 (or CUB-200-2011) experiments 125 | ***weight:*** 126 | 127 | ```bash 128 | sh run_train_00.sh 129 | ``` 130 | ### Other implementations: 131 |

[Tensorflow] (by geonm) 132 | 133 | ### References 134 | 135 | [1] R. Hadsell, S. Chopra, and Y. LeCun. Dimensionality reduction 136 | by learning an invariant mapping. In CVPR, 2006. 137 | 138 | [2] F. Schroff, D. Kalenichenko, and J. Philbin. FaceNet: A unified 139 | embedding for face recognition and clustering. In CVPR, 140 | 2015. 141 | 142 | [3] H. Oh Song, Y. Xiang, S. Jegelka, and S. Savarese. Deep 143 | metric learning via lifted structured feature embedding. In 144 | CVPR, 2016. 145 | 146 | [4] D. Yi, Z. Lei, and S. Z. Li. Deep metric learning for practical 147 | person re-identification. 148 | 149 | [5] C. Wu, R. Manmatha, A. J. Smola, and P. Krähenbühl. Sampling 150 | matters in deep embedding learning. In ICCV, 2017. 151 | 152 | [6] R. Salakhutdinov and G. Hinton. Learning a nonlinear embedding 153 | by preserving class neighbourhood structure. In 154 | AISTATS, 2007. 155 | 156 | [7] X. Wang, X. Han, W. Huang, D. Dong, and M. R. Scott. Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning. In CVPR, 2019. 157 | ### Citation 158 | 159 | If you use this method or this code in your research, please cite it as: 160 | 161 | @inproceedings{wang2019multi, 162 | title={Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning}, 163 | author={Wang, Xun and Han, Xintong and Huang, Weilin and Dong, Dengke and Scott, Matthew R}, 164 | booktitle={CVPR}, 165 | year={2019} 166 | } 167 | 168 | @inproceedings{wang2020xbm, 169 | title={Cross-Batch Memory for Embedding Learning}, 170 | author={Wang, Xun and Zhang, Haozhi and Huang, Weilin and Scott, Matthew R}, 171 | booktitle={CVPR}, 172 | year={2020} 173 | } 174 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import losses 4 | from . import models 5 | from . import utils 6 | from . import evaluations 7 | from . 
import DataSet 8 | -------------------------------------------------------------------------------- /data_list_example.txt: -------------------------------------------------------------------------------- 1 | train/3136/161757894477_0.JPG 1 2 | train/3136/161757894477_3.JPG 1 3 | train/3136/161757894477_5.JPG 1 4 | train/3136/161757894477_1.JPG 1 5 | train/3136/161757894477_2.JPG 1 6 | train/3136/161757894477_4.JPG 1 7 | train/4140/111476369192_1.JPG 2 8 | train/4140/111476369192_0.JPG 2 9 | train/4140/111476369192_2.JPG 2 10 | train/3393/191615104477_1.JPG 3 11 | train/3393/191615104477_5.JPG 3 12 | train/3393/191615104477_3.JPG 3 13 | train/3393/191615104477_2.JPG 3 14 | train/3393/191615104477_4.JPG 3 15 | train/3393/191615104477_0.JPG 3 16 | train/8316/111163187109_0.JPG 4 17 | train/8316/111163187109_2.JPG 4 18 | train/8316/111163187109_1.JPG 4 19 | train/11002/171886261942_4.JPG 5 20 | train/11002/171886261942_1.JPG 5 21 | train/11002/171886261942_2.JPG 5 22 | train/11002/171886261942_6.JPG 5 23 | train/11002/171886261942_3.JPG 5 24 | train/11002/171886261942_0.JPG 5 25 | train/11002/171886261942_5.JPG 5 26 | -------------------------------------------------------------------------------- /evaluations/NMI.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | from sklearn.metrics.cluster import normalized_mutual_info_score 3 | import numpy as np 4 | from utils import to_numpy 5 | 6 | 7 | def NMI(X, ground_truth, n_cluster=3): 8 | X = [to_numpy(x) for x in X] 9 | # list to numpy 10 | X = np.array(X) 11 | ground_truth = np.array(ground_truth) 12 | # print('x_type:', type(X)) 13 | # print('label_type:', type(ground_truth)) 14 | kmeans = KMeans(n_clusters=n_cluster, n_jobs=-1, random_state=0).fit(X) 15 | 16 | print('K-means done') 17 | nmi = normalized_mutual_info_score(ground_truth, kmeans.labels_) 18 | return nmi 19 | 20 | 21 | def main(): 22 | label = [1, 2, 3]*2 23 | 24 | X = np.array([[1, 2], [1, 4], [1, 0], 25 | [4, 2], [4, 4], [4, 0]]) 26 | 27 | print(NMI(X, label)) 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /evaluations/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import utils 3 | 4 | from .cnn import extract_cnn_feature 5 | from .extract_featrure import extract_features, pairwise_distance, pairwise_similarity 6 | from .recall_at_k import Recall_at_ks 7 | from .NMI import NMI 8 | from .top_k import Compute_top_k 9 | # from utils import to_torch 10 | -------------------------------------------------------------------------------- /evaluations/cnn.py: -------------------------------------------------------------------------------- 1 | # from collections import OrderedDict 2 | 3 | # from torch.autograd import Variable 4 | from utils import to_torch 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | # def extract_cnn_feature(model, inputs, modules=None): 9 | # model.eval() 10 | # inputs = to_torch(inputs) 11 | # with torch.no_grad(): 12 | # inputs = inputs.cuda() 13 | # if modules is None: 14 | # outputs = model(inputs) 15 | # outputs = outputs.data 16 | # return outputs 17 | 18 | # # Register forward hook for each module 19 | # outputs = OrderedDict() 20 | # handles = [] 21 | # for m in modules: 22 | # outputs[id(m)] = None 23 | # def func(m, i, o): outputs[id(m)] = o.data 24 | # 
handles.append(m.register_forward_hook(func)) 25 | # model(inputs) 26 | # for h in handles: 27 | # h.remove() 28 | # return list(outputs.values()) 29 | 30 | 31 | def extract_cnn_feature(model, inputs, pool_feature=False): 32 | model.eval() 33 | with torch.no_grad(): 34 | inputs = to_torch(inputs) 35 | inputs = Variable(inputs).cuda() 36 | if pool_feature is False: 37 | outputs = model(inputs) 38 | return outputs 39 | else: 40 | # Register forward hook for each module 41 | outputs = {} 42 | 43 | 44 | def func(m, i, o): outputs['pool_feature'] = o.data.view(n, -1) 45 | hook = model.module._modules.get('features').register_forward_hook(func) 46 | model(inputs) 47 | hook.remove() 48 | # print(outputs['pool_feature'].shape) 49 | return outputs['pool_feature'] 50 | 51 | 52 | -------------------------------------------------------------------------------- /evaluations/extract_featrure.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | 5 | import torch 6 | from utils import to_numpy 7 | import numpy as np 8 | 9 | from utils.meters import AverageMeter 10 | from evaluations.cnn import extract_cnn_feature 11 | 12 | 13 | def normalize(x): 14 | norm = x.norm(dim=1, p=2, keepdim=True) 15 | x = x.div(norm.expand_as(x)) 16 | return x 17 | 18 | 19 | def extract_features(model, data_loader, print_freq=1, metric=None, pool_feature=False): 20 | # model.eval() 21 | batch_time = AverageMeter() 22 | data_time = AverageMeter() 23 | 24 | feature_cpu = torch.FloatTensor() 25 | feature_gpu = torch.FloatTensor().cuda() 26 | 27 | trans_inter = 1e4 28 | labels = list() 29 | end = time.time() 30 | 31 | for i, (imgs, pids) in enumerate(data_loader): 32 | imgs = imgs 33 | outputs = extract_cnn_feature(model, imgs, pool_feature=pool_feature) 34 | feature_gpu = torch.cat((feature_gpu, outputs.data), 0) 35 | labels.extend(pids) 36 | count = feature_gpu.size(0) 37 | if count > trans_inter or i == len(data_loader)-1: 38 | # print(feature_gpu.size()) 39 | data_time.update(time.time() - end) 40 | end = time.time() 41 | # print('transfer to cpu {} / {}'.format(i+1, len(data_loader))) 42 | feature_cpu = torch.cat((feature_cpu, feature_gpu.cpu()), 0) 43 | feature_gpu = torch.FloatTensor().cuda() 44 | batch_time.update(time.time() - end) 45 | print('Extract Features: [{}/{}]\t' 46 | 'Time {:.3f} ({:.3f})\t' 47 | 'Data {:.3f} ({:.3f})\t' 48 | .format(i + 1, len(data_loader), 49 | batch_time.val, batch_time.avg, 50 | data_time.val, data_time.avg)) 51 | 52 | end = time.time() 53 | del outputs 54 | 55 | return feature_cpu, labels 56 | 57 | 58 | def pairwise_distance(features, metric=None): 59 | n = features.size(0) 60 | # normalize feature before test 61 | x = normalize(features) 62 | # print(4*'\n', x.size()) 63 | if metric is not None: 64 | x = metric.transform(x) 65 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True) 66 | # print(dist.size()) 67 | dist = dist.expand(n, n) 68 | dist = dist + dist.t() 69 | dist = dist - 2 * torch.mm(x, x.t()) 70 | dist = torch.sqrt(dist) 71 | return dist 72 | 73 | 74 | def pairwise_similarity(x, y=None): 75 | if y is None: 76 | y = x 77 | # normalization 78 | y = normalize(y) 79 | x = normalize(x) 80 | # similarity 81 | similarity = torch.mm(x, y.t()) 82 | return similarity 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /evaluations/recall_at_k.py: 
-------------------------------------------------------------------------------- 1 | # coding : utf-8 2 | from __future__ import absolute_import 3 | import numpy as np 4 | import torch 5 | from utils import to_numpy 6 | import time 7 | import random 8 | 9 | 10 | def Recall_at_ks(sim_mat, data='cub', query_ids=None, gallery_ids=None): 11 | # start_time = time.time() 12 | # print(start_time) 13 | """ 14 | :param sim_mat: 15 | :param query_ids 16 | :param gallery_ids 17 | :param data 18 | 19 | Compute [R@1, R@2, R@4, R@8] 20 | """ 21 | 22 | ks_dict = dict() 23 | ks_dict['cub'] = [1, 2, 4, 8, 16, 32] 24 | ks_dict['car'] = [1, 2, 4, 8, 16, 32] 25 | ks_dict['jd'] = [1, 2, 4, 8] 26 | ks_dict['product'] = [1, 10, 100, 1000] 27 | ks_dict['shop'] = [1, 10, 20, 30, 40, 50] 28 | 29 | if data is None: 30 | data = 'cub' 31 | k_s = ks_dict[data] 32 | 33 | sim_mat = to_numpy(sim_mat) 34 | m, n = sim_mat.shape 35 | gallery_ids = np.asarray(gallery_ids) 36 | if query_ids is None: 37 | query_ids = gallery_ids 38 | else: 39 | query_ids = np.asarray(query_ids) 40 | 41 | num_max = int(1e6) 42 | 43 | if m > num_max: 44 | samples = list(range(m)) 45 | random.shuffle(samples) 46 | samples = samples[:num_max] 47 | sim_mat = sim_mat[samples, :] 48 | query_ids = [query_ids[k] for k in samples] 49 | m = num_max 50 | 51 | # Hope to be much faster yes!! 52 | num_valid = np.zeros(len(k_s)) 53 | neg_nums = np.zeros(m) 54 | for i in range(m): 55 | x = sim_mat[i] 56 | 57 | pos_max = np.max(x[gallery_ids == query_ids[i]]) 58 | neg_num = np.sum(x > pos_max) 59 | neg_nums[i] = neg_num 60 | 61 | for i, k in enumerate(k_s): 62 | if i == 0: 63 | temp = np.sum(neg_nums < k) 64 | num_valid[i:] += temp 65 | else: 66 | temp = np.sum(neg_nums < k) 67 | num_valid[i:] += temp - num_valid[i-1] 68 | # t = time.time() - start_time 69 | # print(t) 70 | return num_valid / float(m) 71 | 72 | 73 | def test(): 74 | sim_mat = torch.rand(int(7e2), int(14e2)) 75 | sim_mat = to_numpy(sim_mat) 76 | query_ids = int(1e2)*list(range(7)) 77 | gallery_ids = int(2e2)*list(range(7)) 78 | gallery_ids = np.asarray(gallery_ids) 79 | query_ids = np.asarray(query_ids) 80 | print(Recall_at_ks(sim_mat, query_ids=query_ids, gallery_ids=gallery_ids, data='shop')) 81 | 82 | if __name__ == '__main__': 83 | test() 84 | -------------------------------------------------------------------------------- /evaluations/top_k.py: -------------------------------------------------------------------------------- 1 | # coding : utf-8 2 | from __future__ import absolute_import 3 | import numpy as np 4 | # from utils import to_numpy 5 | import time 6 | # import bottleneck 7 | import random 8 | import torch 9 | import heapq 10 | 11 | 12 | def to_numpy(tensor): 13 | if torch.is_tensor(tensor): 14 | return tensor.cpu().numpy() 15 | elif type(tensor).__module__ != 'numpy': 16 | raise ValueError("Cannot convert {} to numpy array" 17 | .format(type(tensor))) 18 | return tensor 19 | 20 | 21 | def Compute_top_k(sim_mat, k=10): 22 | # start_time = time.time() 23 | # print(start_time) 24 | """ 25 | :param sim_mat: 26 | 27 | Compute 28 | top-k in gallery for each query 29 | """ 30 | 31 | sim_mat = to_numpy(sim_mat) 32 | m, n = sim_mat.shape 33 | print('query number is %d' % m) 34 | print('gallery number is %d' % n) 35 | 36 | top_k = np.zeros([m, k]) 37 | 38 | for i in range(m): 39 | sim_i = sim_mat[i] 40 | idx = heapq.nlargest(k, range(len(sim_i)), sim_i.take) 41 | top_k[i] = idx 42 | return top_k 43 | 44 | 45 | def test(): 46 | import torch 47 | sim_mat = torch.rand(int(1e1), int(1e2)) 48 | 
sim_mat = to_numpy(sim_mat) 49 | print(Compute_top_k(sim_mat, k=3)) 50 | 51 | if __name__ == '__main__': 52 | test() 53 | -------------------------------------------------------------------------------- /losses/Binomial.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | 9 | class BinomialLoss(nn.Module): 10 | def __init__(self, alpha=40, beta=0, margin=0.5, hard_mining=None, **kwargs): 11 | super(BinomialLoss, self).__init__() 12 | self.margin = margin 13 | self.alpha = alpha 14 | self.beta = beta 15 | self.hard_mining = hard_mining 16 | 17 | def forward(self, inputs, targets): 18 | n = inputs.size(0) 19 | sim_mat = torch.matmul(inputs, inputs.t()) 20 | targets = targets 21 | 22 | base = 0.5 23 | loss = list() 24 | c = 0 25 | 26 | for i in range(n): 27 | pos_pair_ = torch.masked_select(sim_mat[i], targets==targets[i]) 28 | 29 | # move itself 30 | pos_pair_ = torch.masked_select(pos_pair_, pos_pair_ < 1) 31 | neg_pair_ = torch.masked_select(sim_mat[i], targets!=targets[i]) 32 | 33 | pos_pair_ = torch.sort(pos_pair_)[0] 34 | neg_pair_ = torch.sort(neg_pair_)[0] 35 | 36 | if self.hard_mining is not None: 37 | 38 | neg_pair = torch.masked_select(neg_pair_, neg_pair_ + 0.1 > pos_pair_[0]) 39 | pos_pair = torch.masked_select(pos_pair_, pos_pair_ - 0.1 < neg_pair_[-1]) 40 | 41 | if len(neg_pair) < 1 or len(pos_pair) < 1: 42 | c += 1 43 | continue 44 | 45 | pos_loss = 2.0/self.beta * torch.mean(torch.log(1 + torch.exp(-self.beta*(pos_pair - 0.5)))) 46 | neg_loss = 2.0/self.alpha * torch.mean(torch.log(1 + torch.exp(self.alpha*(neg_pair - 0.5)))) 47 | 48 | else: 49 | pos_pair = pos_pair_ 50 | neg_pair = neg_pair_ 51 | 52 | pos_loss = torch.mean(torch.log(1 + torch.exp(-2*(pos_pair - self.margin)))) 53 | neg_loss = torch.mean(torch.log(1 + torch.exp(self.alpha*(neg_pair - self.margin)))) 54 | 55 | if len(neg_pair) == 0: 56 | c += 1 57 | continue 58 | 59 | loss.append(pos_loss + neg_loss) 60 | 61 | loss = sum(loss)/n 62 | prec = float(c)/n 63 | mean_neg_sim = torch.mean(neg_pair_).item() 64 | mean_pos_sim = torch.mean(pos_pair_).item() 65 | return loss, prec, mean_pos_sim, mean_neg_sim 66 | 67 | def main(): 68 | data_size = 32 69 | input_dim = 3 70 | output_dim = 2 71 | num_class = 4 72 | # margin = 0.5 73 | x = Variable(torch.rand(data_size, input_dim), requires_grad=False) 74 | # print(x) 75 | w = Variable(torch.rand(input_dim, output_dim), requires_grad=True) 76 | inputs = x.mm(w) 77 | y_ = 8*list(range(num_class)) 78 | targets = Variable(torch.IntTensor(y_)) 79 | 80 | print(BinomialLoss()(inputs, targets)) 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | print('Congratulations to you!') 86 | 87 | 88 | -------------------------------------------------------------------------------- /losses/Contrastive.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | 9 | class ContrastiveLoss(nn.Module): 10 | def __init__(self, margin=0.5, **kwargs): 11 | super(ContrastiveLoss, self).__init__() 12 | self.margin = margin 13 | 14 | def forward(self, inputs, targets): 15 | n = inputs.size(0) 16 | # Compute similarity matrix 17 | sim_mat = torch.matmul(inputs, inputs.t()) 18 | targets = targets 19 | loss = list() 20 | c = 0 21 
| 22 | for i in range(n): 23 | pos_pair_ = torch.masked_select(sim_mat[i], targets==targets[i]) 24 | 25 | # move itself 26 | pos_pair_ = torch.masked_select(pos_pair_, pos_pair_ < 1) 27 | neg_pair_ = torch.masked_select(sim_mat[i], targets!=targets[i]) 28 | 29 | pos_pair_ = torch.sort(pos_pair_)[0] 30 | neg_pair_ = torch.sort(neg_pair_)[0] 31 | 32 | neg_pair = torch.masked_select(neg_pair_, neg_pair_ > self.margin) 33 | 34 | neg_loss = 0 35 | 36 | pos_loss = torch.sum(-pos_pair_+1) 37 | if len(neg_pair) > 0: 38 | neg_loss = torch.sum(neg_pair) 39 | loss.append(pos_loss + neg_loss) 40 | 41 | loss = sum(loss)/n 42 | prec = float(c)/n 43 | mean_neg_sim = torch.mean(neg_pair_).item() 44 | mean_pos_sim = torch.mean(pos_pair_).item() 45 | return loss, prec, mean_pos_sim, mean_neg_sim 46 | 47 | 48 | 49 | def main(): 50 | data_size = 32 51 | input_dim = 3 52 | output_dim = 2 53 | num_class = 4 54 | # margin = 0.5 55 | x = Variable(torch.rand(data_size, input_dim), requires_grad=False) 56 | # print(x) 57 | w = Variable(torch.rand(input_dim, output_dim), requires_grad=True) 58 | inputs = x.mm(w) 59 | y_ = 8*list(range(num_class)) 60 | targets = Variable(torch.IntTensor(y_)) 61 | 62 | print(ContrastiveLoss()(inputs, targets)) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | print('Congratulations to you!') 68 | 69 | 70 | -------------------------------------------------------------------------------- /losses/HardMining.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | 9 | def similarity(inputs_): 10 | # Compute similarity mat of deep feature 11 | # n = inputs_.size(0) 12 | sim = torch.matmul(inputs_, inputs_.t()) 13 | return sim 14 | 15 | 16 | class HardMiningLoss(nn.Module): 17 | def __init__(self, beta=None, margin=0, **kwargs): 18 | super(HardMiningLoss, self).__init__() 19 | self.beta = beta 20 | self.margin = 0.1 21 | 22 | def forward(self, inputs, targets): 23 | n = inputs.size(0) 24 | sim_mat = torch.matmul(inputs, inputs.t()) 25 | targets = targets 26 | 27 | base = 0.5 28 | loss = list() 29 | c = 0 30 | 31 | for i in range(n): 32 | pos_pair_ = torch.masked_select(sim_mat[i], targets==targets[i]) 33 | 34 | # move itself 35 | pos_pair_ = torch.masked_select(pos_pair_, pos_pair_ < 1) 36 | neg_pair_ = torch.masked_select(sim_mat[i], targets!=targets[i]) 37 | 38 | pos_pair_ = torch.sort(pos_pair_)[0] 39 | neg_pair_ = torch.sort(neg_pair_)[0] 40 | 41 | neg_pair = torch.masked_select(neg_pair_, neg_pair_ > pos_pair_[0] - self.margin) 42 | pos_pair = torch.masked_select(pos_pair_, pos_pair_ < neg_pair_[-1] + self.margin) 43 | # pos_pair = pos_pair[1:] 44 | if len(neg_pair) < 1: 45 | c += 1 46 | continue 47 | 48 | pos_loss = torch.mean(1 - pos_pair) 49 | neg_loss = torch.mean(neg_pair) 50 | # pos_loss = torch.mean(torch.log(1 + torch.exp(-2*(pos_pair - self.margin)))) 51 | # neg_loss = 0.04*torch.mean(torch.log(1 + torch.exp(50*(neg_pair - self.margin)))) 52 | loss.append(pos_loss + neg_loss) 53 | 54 | loss = sum(loss)/n 55 | prec = float(c)/n 56 | mean_neg_sim = torch.mean(neg_pair_).item() 57 | mean_pos_sim = torch.mean(pos_pair_).item() 58 | return loss, prec, mean_pos_sim, mean_neg_sim 59 | 60 | 61 | def main(): 62 | data_size = 32 63 | input_dim = 3 64 | output_dim = 2 65 | num_class = 4 66 | # margin = 0.5 67 | x = Variable(torch.rand(data_size, input_dim), requires_grad=False) 68 | # print(x) 
69 | w = Variable(torch.rand(input_dim, output_dim), requires_grad=True) 70 | inputs = x.mm(w) 71 | y_ = 8*list(range(num_class)) 72 | targets = Variable(torch.IntTensor(y_)) 73 | 74 | print(HardMiningLoss()(inputs, targets)) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | print('Congratulations to you!') 80 | 81 | 82 | -------------------------------------------------------------------------------- /losses/LiftedStructure.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | 9 | class LiftedStructureLoss(nn.Module): 10 | def __init__(self, alpha=40, beta=2, margin=0.5, hard_mining=None, **kwargs): 11 | super(LiftedStructureLoss, self).__init__() 12 | self.margin = margin 13 | self.alpha = alpha 14 | self.beta = beta 15 | self.hard_mining = hard_mining 16 | 17 | def forward(self, inputs, targets): 18 | n = inputs.size(0) 19 | sim_mat = torch.matmul(inputs, inputs.t()) 20 | targets = targets 21 | loss = list() 22 | c = 0 23 | 24 | for i in range(n): 25 | pos_pair_ = torch.masked_select(sim_mat[i], targets==targets[i]) 26 | 27 | # move itself 28 | pos_pair_ = torch.masked_select(pos_pair_, pos_pair_ < 1) 29 | neg_pair_ = torch.masked_select(sim_mat[i], targets!=targets[i]) 30 | 31 | pos_pair_ = torch.sort(pos_pair_)[0] 32 | neg_pair_ = torch.sort(neg_pair_)[0] 33 | 34 | if self.hard_mining is not None: 35 | 36 | neg_pair = torch.masked_select(neg_pair_, neg_pair_ + 0.1 > pos_pair_[0]) 37 | pos_pair = torch.masked_select(pos_pair_, pos_pair_ - 0.1 < neg_pair_[-1]) 38 | 39 | if len(neg_pair) < 1 or len(pos_pair) < 1: 40 | c += 1 41 | continue 42 | 43 | pos_loss = 2.0/self.beta * torch.log(torch.sum(torch.exp(-self.beta*pos_pair))) 44 | neg_loss = 2.0/self.alpha * torch.log(torch.sum(torch.exp(self.alpha*neg_pair))) 45 | 46 | else: 47 | pos_pair = pos_pair_ 48 | neg_pair = neg_pair_ 49 | 50 | pos_loss = 2.0/self.beta * torch.log(torch.sum(torch.exp(-self.beta*pos_pair))) 51 | neg_loss = 2.0/self.alpha * torch.log(torch.sum(torch.exp(self.alpha*neg_pair))) 52 | 53 | if len(neg_pair) == 0: 54 | c += 1 55 | continue 56 | 57 | loss.append(pos_loss + neg_loss) 58 | loss = sum(loss)/n 59 | prec = float(c)/n 60 | mean_neg_sim = torch.mean(neg_pair_).item() 61 | mean_pos_sim = torch.mean(pos_pair_).item() 62 | return loss, prec, mean_pos_sim, mean_neg_sim 63 | 64 | 65 | def main(): 66 | data_size = 32 67 | input_dim = 3 68 | output_dim = 2 69 | num_class = 4 70 | x = Variable(torch.rand(data_size, input_dim), requires_grad=False) 71 | w = Variable(torch.rand(input_dim, output_dim), requires_grad=True) 72 | inputs = x.mm(w) 73 | y_ = 8*list(range(num_class)) 74 | targets = Variable(torch.IntTensor(y_)) 75 | 76 | print(LiftedStructureLoss()(inputs, targets)) 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | print('Congratulations to you!') 82 | 83 | -------------------------------------------------------------------------------- /losses/NCA.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | 10 | class NCALoss(nn.Module): 11 | def __init__(self, alpha=16, k=32, **kwargs): 12 | super(NCALoss, self).__init__() 13 | self.alpha = alpha 14 | self.K = k 15 | 16 | def forward(self, inputs, 
targets): 17 | n = inputs.size(0) 18 | sim_mat = torch.matmul(inputs, inputs.t()) 19 | targets = targets 20 | 21 | base = 0.5 22 | loss = list() 23 | c = 0 24 | 25 | for i in range(n): 26 | pos_pair_ = torch.masked_select(sim_mat[i], targets==targets[i]) 27 | 28 | # remove the anchor itself (its self-similarity is 1) 29 | pos_pair_ = torch.masked_select(pos_pair_, pos_pair_ < 1) 30 | neg_pair_ = torch.masked_select(sim_mat[i], targets!=targets[i]) 31 | 32 | pos_pair = torch.sort(pos_pair_)[0] 33 | neg_pair = torch.sort(neg_pair_)[0] 34 | 35 | # cut-off taken from the (K+1)-th entry of the anchor's sorted pair similarities 36 | pair = torch.cat([pos_pair, neg_pair]) 37 | threshold = torch.sort(pair)[0][self.K] 38 | 39 | # keep the positive and negative pairs whose similarity falls below the cut-off 40 | pos_neig = torch.masked_select(pos_pair, pos_pair < threshold) 41 | neg_neig = torch.masked_select(neg_pair, neg_pair < threshold) 42 | 43 | # if no positive pair survives the cut-off, fall back to the first sorted positive 44 | if len(pos_neig) == 0: 45 | pos_neig = pos_pair[0] 46 | 47 | base = torch.mean(sim_mat[i]).item() 48 | # compute the logits; subtracting base keeps the exponentials from overflowing 49 | pos_logit = torch.sum(torch.exp(self.alpha*(base - pos_neig))) 50 | neg_logit = torch.sum(torch.exp(self.alpha*(base - neg_neig))) 51 | loss_ = -torch.log(pos_logit/(pos_logit + neg_logit)) 52 | 53 | if loss_.item() < 0.6: 54 | c += 1 55 | loss.append(loss_) 56 | loss = sum(loss)/n 57 | prec = float(c)/n 58 | mean_neg_sim = torch.mean(neg_pair_).item() 59 | mean_pos_sim = torch.mean(pos_pair_).item() 60 | return loss, prec, mean_pos_sim, mean_neg_sim 61 | 62 | 63 | 64 | def main(): 65 | data_size = 32 66 | input_dim = 3 67 | output_dim = 2 68 | num_class = 4 69 | # margin = 0.5 70 | x = Variable(torch.rand(data_size, input_dim), requires_grad=False) 71 | w = Variable(torch.rand(input_dim, output_dim), requires_grad=True) 72 | inputs = x.mm(w) 73 | y_ = 8*list(range(num_class)) 74 | targets = Variable(torch.IntTensor(y_)) 75 | 76 | print(NCALoss(alpha=30)(inputs, targets)) 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | print('Congratulations to you!') 82 | -------------------------------------------------------------------------------- /losses/SemiHard.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | 9 | def similarity(inputs_): 10 | # Compute similarity mat of deep feature 11 | # n = inputs_.size(0) 12 | sim = torch.matmul(inputs_, inputs_.t()) 13 | return sim 14 | 15 | 16 | class SemiHardLoss(nn.Module): 17 | def __init__(self, alpha=0, beta=None, margin=0, **kwargs): 18 | super(SemiHardLoss, self).__init__() 19 | self.beta = beta 20 | self.margin = margin 21 | self.alpha = alpha 22 | 23 | def forward(self, inputs, targets): 24 | n = inputs.size(0) 25 | # Compute similarity matrix 26 | sim_mat = similarity(inputs) 27 | # print(sim_mat) 28 | targets = targets.cuda() 29 | # split the positive and negative pairs 30 | eyes_ = Variable(torch.eye(n, n)).cuda() 31 | # eyes_ = Variable(torch.eye(n, n)) 32 | pos_mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 33 | neg_mask = eyes_.eq(eyes_) - pos_mask 34 | pos_mask = pos_mask - eyes_.eq(1) 35 | 36 | pos_sim = torch.masked_select(sim_mat, pos_mask) 37 | neg_sim = torch.masked_select(sim_mat, neg_mask) 38 | 39 | num_instances = len(pos_sim)//n + 1 40 | num_neg_instances = n - num_instances 41 | 42 | pos_sim = pos_sim.resize(len(pos_sim)//(num_instances-1), num_instances-1) 43 | neg_sim = neg_sim.resize( 44 | len(neg_sim) // num_neg_instances, num_neg_instances) 45 | 46 | # clear way
to compute the loss first 47 | loss = list() 48 | c = 0 49 | base = 0.5 50 | for i, pos_pair_ in enumerate(pos_sim): 51 | # print(i) 52 | pos_pair_ = torch.sort(pos_pair_)[0] 53 | neg_pair_ = torch.sort(neg_sim[i])[0] 54 | 55 | neg_pair = torch.masked_select(neg_pair_, neg_pair_ < pos_pair_[-1]) 56 | pos_pair = torch.masked_select(pos_pair_, pos_pair_ > neg_pair_[0]) 57 | 58 | if len(pos_pair) > 0: 59 | pos_loss = 2.0/self.beta * torch.log(1 + torch.sum(torch.exp(-self.beta * (pos_pair - base)))) 60 | else: 61 | pos_loss = 0*torch.mean(1 - pos_pair_) 62 | 63 | if len(neg_pair)>0: 64 | # neg_loss = torch.mean(neg_pair) 65 | neg_loss = 2.0/self.alpha * torch.log(1 + torch.sum(torch.exp(self.alpha * (neg_pair - base)))) 66 | else: 67 | neg_loss = 0*torch.mean(neg_pair_) 68 | loss.append(pos_loss + neg_loss) 69 | 70 | loss = sum(loss)/n 71 | prec = float(c)/n 72 | mean_neg_sim = torch.mean(neg_pair_).item() 73 | mean_pos_sim = torch.mean(pos_pair_).item() 74 | return loss, prec, mean_pos_sim, mean_neg_sim 75 | 76 | def main(): 77 | data_size = 32 78 | input_dim = 3 79 | output_dim = 2 80 | num_class = 4 81 | # margin = 0.5 82 | x = Variable(torch.rand(data_size, input_dim), requires_grad=False) 83 | # print(x) 84 | w = Variable(torch.rand(input_dim, output_dim), requires_grad=True) 85 | inputs = x.mm(w) 86 | y_ = 8*list(range(num_class)) 87 | targets = Variable(torch.IntTensor(y_)) 88 | 89 | print(SemiHardLoss()(inputs, targets)) 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 | print('Congratulations to you!') 95 | 96 | 97 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | from .NCA import NCALoss 4 | from .Contrastive import ContrastiveLoss 5 | from .Binomial import BinomialLoss 6 | from .LiftedStructure import LiftedStructureLoss 7 | # from .Weight import WeightLoss 8 | from .HardMining import HardMiningLoss 9 | 10 | __factory = { 11 | 'NCA': NCALoss, 12 | 'Contrastive': ContrastiveLoss, 13 | 'Binomial': BinomialLoss, 14 | 'LiftedStructure': LiftedStructureLoss, 15 | # 'Weight': WeightLoss, 16 | 'HardMining': HardMiningLoss, 17 | } 18 | 19 | 20 | def names(): 21 | return sorted(__factory.keys()) 22 | 23 | 24 | def create(name, *args, **kwargs): 25 | """ 26 | Create a loss instance. 27 | 28 | Parameters 29 | ---------- 30 | name : str 31 | the name of loss function 32 | """ 33 | if name not in __factory: 34 | raise KeyError("Unknown loss:", name) 35 | return __factory[name](*args, **kwargs) 36 | 37 | 38 | 39 | def names(): 40 | return sorted(__factory.keys()) 41 | 42 | 43 | def create(name, *args, **kwargs): 44 | """ 45 | Create a loss instance.
46 | 47 | Parameters 48 | ---------- 49 | name : str 50 | the name of loss function 51 | """ 52 | if name not in __factory: 53 | raise KeyError("Unknown loss:", name) 54 | return __factory[name]( *args, **kwargs) 55 | -------------------------------------------------------------------------------- /losses/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | 9 | def similarity(inputs_): 10 | # Compute similarity mat of deep feature 11 | # n = inputs_.size(0) 12 | sim = torch.matmul(inputs_, inputs_.t()) 13 | return sim 14 | 15 | 16 | class HardMiningLoss(nn.Module): 17 | def __init__(self, beta=None, margin=0, **kwargs): 18 | super(HardMiningLoss, self).__init__() 19 | self.beta = beta 20 | self.margin = margin 21 | 22 | def forward(self, inputs, targets): 23 | n = inputs.size(0) 24 | # Compute similarity matrix 25 | sim_mat = similarity(inputs) 26 | # print(sim_mat) 27 | targets = targets.cuda() 28 | # split the positive and negative pairs 29 | eyes_ = Variable(torch.eye(n, n)).cuda() 30 | # eyes_ = Variable(torch.eye(n, n)) 31 | pos_mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 32 | neg_mask = eyes_.eq(eyes_) - pos_mask 33 | pos_mask = pos_mask - eyes_.eq(1) 34 | 35 | pos_sim = torch.masked_select(sim_mat, pos_mask) 36 | neg_sim = torch.masked_select(sim_mat, neg_mask) 37 | 38 | num_instances = len(pos_sim)//n + 1 39 | num_neg_instances = n - num_instances 40 | 41 | pos_sim = pos_sim.resize(len(pos_sim)//(num_instances-1), num_instances-1) 42 | neg_sim = neg_sim.resize( 43 | len(neg_sim) // num_neg_instances, num_neg_instances) 44 | 45 | # clear way to compute the loss first 46 | loss = list() 47 | c = 0 48 | 49 | for i, pos_pair_ in enumerate(pos_sim): 50 | # print(i) 51 | pos_pair_ = torch.sort(pos_pair_)[0] 52 | neg_pair_ = torch.sort(neg_sim[i])[0] 53 | 54 | neg_pair = torch.masked_select(neg_pair_, neg_pair_ > pos_pair_[0] - self.margin) 55 | pos_pair = torch.masked_select(pos_pair_, pos_pair_ < neg_pair_[-1] + self.margin) 56 | # pos_pair = pos_pair[1:] 57 | if len(neg_pair) < 1: 58 | c += 1 59 | continue 60 | 61 | pos_loss = torch.mean(1 - pos_pair) 62 | neg_loss = torch.mean(neg_pair) 63 | loss.append(pos_loss + neg_loss) 64 | 65 | loss = sum(loss)/n 66 | prec = float(c)/n 67 | mean_neg_sim = torch.mean(neg_pair_).item() 68 | mean_pos_sim = torch.mean(pos_pair_).item() 69 | return loss, prec, mean_pos_sim, mean_neg_sim 70 | 71 | 72 | def main(): 73 | data_size = 32 74 | input_dim = 3 75 | output_dim = 2 76 | num_class = 4 77 | # margin = 0.5 78 | x = Variable(torch.rand(data_size, input_dim), requires_grad=False) 79 | # print(x) 80 | w = Variable(torch.rand(input_dim, output_dim), requires_grad=True) 81 | inputs = x.mm(w) 82 | y_ = 8*list(range(num_class)) 83 | targets = Variable(torch.IntTensor(y_)) 84 | 85 | print(HardMiningLoss()(inputs, targets)) 86 | 87 | 88 | if __name__ == '__main__': 89 | main() 90 | print('Congratulations to you!') 91 | 92 | 93 | -------------------------------------------------------------------------------- /models/BN_Inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | class Embedding(nn.Module): 7 | def __init__(self, in_dim, out_dim, dropout=None, normalized=True): 8 | 
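        # Embedding head placed on top of the pooled BN-Inception features: optional dropout,
        # a Linear projection from in_dim to out_dim, and (when normalized=True) L2 normalization
        # of every row, so the dot product of two embeddings equals their cosine similarity.
        # Rough usage sketch: Embedding(1024, 512)(torch.randn(8, 1024)) gives an (8, 512) tensor
        # with unit-norm rows. The BatchNorm layer created below is currently unused because its
        # call is commented out in forward().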
super(Embedding, self).__init__() 9 | self.bn = nn.BatchNorm2d(in_dim, eps=1e-5) 10 | self.linear = nn.Linear(in_features=in_dim, out_features=out_dim) 11 | self.dropout = dropout 12 | self.normalized = normalized 13 | 14 | def forward(self, x): 15 | # x = self.bn(x) 16 | # x = F.relu(x, inplace=True) 17 | if self.dropout is not None: 18 | x = nn.Dropout(p=self.dropout)(x, inplace=True) 19 | x = self.linear(x) 20 | if self.normalized: 21 | norm = x.norm(dim=1, p=2, keepdim=True) 22 | x = x.div(norm.expand_as(x)) 23 | return x 24 | 25 | class BNInception(nn.Module): 26 | def __init__(self, dim=512): 27 | super(BNInception, self).__init__() 28 | self.dim = dim 29 | 30 | inplace = True 31 | self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)) 32 | self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 33 | self.conv1_relu_7x7 = nn.ReLU (inplace) 34 | self.pool1_3x3_s2 = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 35 | self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) 36 | self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 37 | self.conv2_relu_3x3_reduce = nn.ReLU (inplace) 38 | self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 39 | self.conv2_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 40 | self.conv2_relu_3x3 = nn.ReLU (inplace) 41 | self.pool2_3x3_s2 = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 42 | self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 43 | self.inception_3a_1x1_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 44 | self.inception_3a_relu_1x1 = nn.ReLU (inplace) 45 | self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 46 | self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 47 | self.inception_3a_relu_3x3_reduce = nn.ReLU (inplace) 48 | self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 49 | self.inception_3a_3x3_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 50 | self.inception_3a_relu_3x3 = nn.ReLU (inplace) 51 | self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 52 | self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 53 | self.inception_3a_relu_double_3x3_reduce = nn.ReLU (inplace) 54 | self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 55 | self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 56 | self.inception_3a_relu_double_3x3_1 = nn.ReLU (inplace) 57 | self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 58 | self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 59 | self.inception_3a_relu_double_3x3_2 = nn.ReLU (inplace) 60 | self.inception_3a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 61 | self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) 62 | self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, eps=1e-05, momentum=0.9, affine=True) 63 | self.inception_3a_relu_pool_proj = nn.ReLU (inplace) 64 | self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 65 | self.inception_3b_1x1_bn = nn.BatchNorm2d(64, eps=1e-05, 
momentum=0.9, affine=True) 66 | self.inception_3b_relu_1x1 = nn.ReLU (inplace) 67 | self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 68 | self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 69 | self.inception_3b_relu_3x3_reduce = nn.ReLU (inplace) 70 | self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 71 | self.inception_3b_3x3_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 72 | self.inception_3b_relu_3x3 = nn.ReLU (inplace) 73 | self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 74 | self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 75 | self.inception_3b_relu_double_3x3_reduce = nn.ReLU (inplace) 76 | self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 77 | self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 78 | self.inception_3b_relu_double_3x3_1 = nn.ReLU (inplace) 79 | self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 80 | self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 81 | self.inception_3b_relu_double_3x3_2 = nn.ReLU (inplace) 82 | self.inception_3b_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 83 | self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 84 | self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 85 | self.inception_3b_relu_pool_proj = nn.ReLU (inplace) 86 | self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1)) 87 | self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 88 | self.inception_3c_relu_3x3_reduce = nn.ReLU (inplace) 89 | self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 90 | self.inception_3c_3x3_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 91 | self.inception_3c_relu_3x3 = nn.ReLU (inplace) 92 | self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1)) 93 | self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 94 | self.inception_3c_relu_double_3x3_reduce = nn.ReLU (inplace) 95 | self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 96 | self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 97 | self.inception_3c_relu_double_3x3_1 = nn.ReLU (inplace) 98 | self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 99 | self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 100 | self.inception_3c_relu_double_3x3_2 = nn.ReLU (inplace) 101 | self.inception_3c_pool = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 102 | self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1)) 103 | self.inception_4a_1x1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) 104 | self.inception_4a_relu_1x1 = nn.ReLU (inplace) 105 | self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1)) 106 | self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True) 107 | 
self.inception_4a_relu_3x3_reduce = nn.ReLU (inplace) 108 | self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 109 | self.inception_4a_3x3_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 110 | self.inception_4a_relu_3x3 = nn.ReLU (inplace) 111 | self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 112 | self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 113 | self.inception_4a_relu_double_3x3_reduce = nn.ReLU (inplace) 114 | self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 115 | self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 116 | self.inception_4a_relu_double_3x3_1 = nn.ReLU (inplace) 117 | self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 118 | self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 119 | self.inception_4a_relu_double_3x3_2 = nn.ReLU (inplace) 120 | self.inception_4a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 121 | self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 122 | self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 123 | self.inception_4a_relu_pool_proj = nn.ReLU (inplace) 124 | self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1)) 125 | self.inception_4b_1x1_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 126 | self.inception_4b_relu_1x1 = nn.ReLU (inplace) 127 | self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 128 | self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 129 | self.inception_4b_relu_3x3_reduce = nn.ReLU (inplace) 130 | self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 131 | self.inception_4b_3x3_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 132 | self.inception_4b_relu_3x3 = nn.ReLU (inplace) 133 | self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 134 | self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 135 | self.inception_4b_relu_double_3x3_reduce = nn.ReLU (inplace) 136 | self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 137 | self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 138 | self.inception_4b_relu_double_3x3_1 = nn.ReLU (inplace) 139 | self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 140 | self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 141 | self.inception_4b_relu_double_3x3_2 = nn.ReLU (inplace) 142 | self.inception_4b_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 143 | self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 144 | self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 145 | self.inception_4b_relu_pool_proj = nn.ReLU (inplace) 146 | self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1)) 147 | self.inception_4c_1x1_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 148 | 
self.inception_4c_relu_1x1 = nn.ReLU (inplace) 149 | self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 150 | self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 151 | self.inception_4c_relu_3x3_reduce = nn.ReLU (inplace) 152 | self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 153 | self.inception_4c_3x3_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 154 | self.inception_4c_relu_3x3 = nn.ReLU (inplace) 155 | self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 156 | self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 157 | self.inception_4c_relu_double_3x3_reduce = nn.ReLU (inplace) 158 | self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 159 | self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 160 | self.inception_4c_relu_double_3x3_1 = nn.ReLU (inplace) 161 | self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 162 | self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 163 | self.inception_4c_relu_double_3x3_2 = nn.ReLU (inplace) 164 | self.inception_4c_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 165 | self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 166 | self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 167 | self.inception_4c_relu_pool_proj = nn.ReLU (inplace) 168 | self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1)) 169 | self.inception_4d_1x1_bn = nn.BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True) 170 | self.inception_4d_relu_1x1 = nn.ReLU (inplace) 171 | self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 172 | self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 173 | self.inception_4d_relu_3x3_reduce = nn.ReLU (inplace) 174 | self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 175 | self.inception_4d_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 176 | self.inception_4d_relu_3x3 = nn.ReLU (inplace) 177 | self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1)) 178 | self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 179 | self.inception_4d_relu_double_3x3_reduce = nn.ReLU (inplace) 180 | self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 181 | self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 182 | self.inception_4d_relu_double_3x3_1 = nn.ReLU (inplace) 183 | self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 184 | self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 185 | self.inception_4d_relu_double_3x3_2 = nn.ReLU (inplace) 186 | self.inception_4d_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 187 | self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 188 | self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 189 | 
self.inception_4d_relu_pool_proj = nn.ReLU (inplace) 190 | self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 191 | self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 192 | self.inception_4e_relu_3x3_reduce = nn.ReLU (inplace) 193 | self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 194 | self.inception_4e_3x3_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 195 | self.inception_4e_relu_3x3 = nn.ReLU (inplace) 196 | self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1)) 197 | self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 198 | self.inception_4e_relu_double_3x3_reduce = nn.ReLU (inplace) 199 | self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 200 | self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True) 201 | self.inception_4e_relu_double_3x3_1 = nn.ReLU (inplace) 202 | self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 203 | self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True) 204 | self.inception_4e_relu_double_3x3_2 = nn.ReLU (inplace) 205 | self.inception_4e_pool = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 206 | self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1)) 207 | self.inception_5a_1x1_bn = nn.BatchNorm2d(352, eps=1e-05, momentum=0.9, affine=True) 208 | self.inception_5a_relu_1x1 = nn.ReLU (inplace) 209 | self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1)) 210 | self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 211 | self.inception_5a_relu_3x3_reduce = nn.ReLU (inplace) 212 | self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 213 | self.inception_5a_3x3_bn = nn.BatchNorm2d(320, eps=1e-05, momentum=0.9, affine=True) 214 | self.inception_5a_relu_3x3 = nn.ReLU (inplace) 215 | self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1)) 216 | self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, eps=1e-05, momentum=0.9, affine=True) 217 | self.inception_5a_relu_double_3x3_reduce = nn.ReLU (inplace) 218 | self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 219 | self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) 220 | self.inception_5a_relu_double_3x3_1 = nn.ReLU (inplace) 221 | self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 222 | self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) 223 | self.inception_5a_relu_double_3x3_2 = nn.ReLU (inplace) 224 | self.inception_5a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 225 | self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1)) 226 | self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 227 | self.inception_5a_relu_pool_proj = nn.ReLU(inplace) 228 | self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1)) 229 | self.inception_5b_1x1_bn = nn.BatchNorm2d(352, eps=1e-05, momentum=0.9, affine=True) 230 | 
self.inception_5b_relu_1x1 = nn.ReLU(inplace) 231 | self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) 232 | self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 233 | self.inception_5b_relu_3x3_reduce = nn.ReLU (inplace) 234 | self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 235 | self.inception_5b_3x3_bn = nn.BatchNorm2d(320, eps=1e-05, momentum=0.9, affine=True) 236 | self.inception_5b_relu_3x3 = nn.ReLU (inplace) 237 | self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) 238 | self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, eps=1e-05, momentum=0.9, affine=True) 239 | self.inception_5b_relu_double_3x3_reduce = nn.ReLU(inplace) 240 | self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 241 | self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) 242 | self.inception_5b_relu_double_3x3_1 = nn.ReLU(inplace) 243 | self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 244 | self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, eps=1e-05, momentum=0.9, affine=True) 245 | self.inception_5b_relu_double_3x3_2 = nn.ReLU (inplace) 246 | self.inception_5b_pool = nn.MaxPool2d ((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True) 247 | self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1)) 248 | self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True) 249 | self.inception_5b_relu_pool_proj = nn.ReLU(inplace) 250 | if self.dim == 0: 251 | pass 252 | else: 253 | self.classifier = Embedding(1024, self.dim, normalized=True) 254 | 255 | # Official init from torch repo. 
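        # Initialization summary: Kaiming-normal for every Conv2d weight, weight=1 and bias=0 for
        # every BatchNorm2d, and bias=0 for Linear layers (Linear weights keep PyTorch's default
        # init). When BN_Inception() below is built with pretrained=True, the matching backbone
        # parameters are afterwards overwritten by the ImageNet checkpoint, so this init mainly
        # matters for the newly added Embedding classifier.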
256 | for m in self.modules(): 257 | if isinstance(m, nn.Conv2d): 258 | nn.init.kaiming_normal_(m.weight) 259 | elif isinstance(m, nn.BatchNorm2d): 260 | nn.init.constant_(m.weight, 1) 261 | nn.init.constant_(m.bias, 0) 262 | elif isinstance(m, nn.Linear): 263 | nn.init.constant_(m.bias, 0) 264 | 265 | def features(self, input): 266 | conv1_7x7_s2_out = self.conv1_7x7_s2(input) 267 | conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out) 268 | conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out) 269 | pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_7x7_s2_bn_out) 270 | conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out) 271 | conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out) 272 | conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out) 273 | conv2_3x3_out = self.conv2_3x3(conv2_3x3_reduce_bn_out) 274 | conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out) 275 | conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out) 276 | pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_3x3_bn_out) 277 | inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out) 278 | inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out) 279 | inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out) 280 | inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out) 281 | inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out) 282 | inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out) 283 | inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_3x3_reduce_bn_out) 284 | inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out) 285 | inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out) 286 | inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out) 287 | inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn(inception_3a_double_3x3_reduce_out) 288 | inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce(inception_3a_double_3x3_reduce_bn_out) 289 | inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_double_3x3_reduce_bn_out) 290 | inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out) 291 | inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out) 292 | inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_double_3x3_1_bn_out) 293 | inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out) 294 | inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out) 295 | inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out) 296 | inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out) 297 | inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out) 298 | inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out) 299 | inception_3a_output_out = torch.cat([inception_3a_1x1_bn_out,inception_3a_3x3_bn_out,inception_3a_double_3x3_2_bn_out,inception_3a_pool_proj_bn_out], 1) 300 | inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out) 301 | inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out) 302 | inception_3b_relu_1x1_out = 
self.inception_3b_relu_1x1(inception_3b_1x1_bn_out) 303 | inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out) 304 | inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out) 305 | inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out) 306 | inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_3x3_reduce_bn_out) 307 | inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out) 308 | inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out) 309 | inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out) 310 | inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn(inception_3b_double_3x3_reduce_out) 311 | inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce(inception_3b_double_3x3_reduce_bn_out) 312 | inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_double_3x3_reduce_bn_out) 313 | inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out) 314 | inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out) 315 | inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_double_3x3_1_bn_out) 316 | inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out) 317 | inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out) 318 | inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out) 319 | inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out) 320 | inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out) 321 | inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out) 322 | inception_3b_output_out = torch.cat([inception_3b_1x1_bn_out,inception_3b_3x3_bn_out,inception_3b_double_3x3_2_bn_out,inception_3b_pool_proj_bn_out], 1) 323 | inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out) 324 | inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out) 325 | inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out) 326 | inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_3x3_reduce_bn_out) 327 | inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out) 328 | inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out) 329 | inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out) 330 | inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn(inception_3c_double_3x3_reduce_out) 331 | inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce(inception_3c_double_3x3_reduce_bn_out) 332 | inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_double_3x3_reduce_bn_out) 333 | inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out) 334 | inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out) 335 | inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_double_3x3_1_bn_out) 336 | inception_3c_double_3x3_2_bn_out = 
self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out) 337 | inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out) 338 | inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out) 339 | inception_3c_output_out = torch.cat([inception_3c_3x3_bn_out,inception_3c_double_3x3_2_bn_out,inception_3c_pool_out], 1) 340 | inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out) 341 | inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out) 342 | inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out) 343 | inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out) 344 | inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out) 345 | inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out) 346 | inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_3x3_reduce_bn_out) 347 | inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out) 348 | inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out) 349 | inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out) 350 | inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn(inception_4a_double_3x3_reduce_out) 351 | inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce(inception_4a_double_3x3_reduce_bn_out) 352 | inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_double_3x3_reduce_bn_out) 353 | inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out) 354 | inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out) 355 | inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_double_3x3_1_bn_out) 356 | inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out) 357 | inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out) 358 | inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out) 359 | inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out) 360 | inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out) 361 | inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out) 362 | inception_4a_output_out = torch.cat([inception_4a_1x1_bn_out,inception_4a_3x3_bn_out,inception_4a_double_3x3_2_bn_out,inception_4a_pool_proj_bn_out], 1) 363 | inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out) 364 | inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out) 365 | inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out) 366 | inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out) 367 | inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out) 368 | inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out) 369 | inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_3x3_reduce_bn_out) 370 | inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out) 371 | inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out) 372 | inception_4b_double_3x3_reduce_out = 
self.inception_4b_double_3x3_reduce(inception_4a_output_out) 373 | inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn(inception_4b_double_3x3_reduce_out) 374 | inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce(inception_4b_double_3x3_reduce_bn_out) 375 | inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_double_3x3_reduce_bn_out) 376 | inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out) 377 | inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out) 378 | inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_double_3x3_1_bn_out) 379 | inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out) 380 | inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out) 381 | inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out) 382 | inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out) 383 | inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out) 384 | inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out) 385 | inception_4b_output_out = torch.cat([inception_4b_1x1_bn_out,inception_4b_3x3_bn_out,inception_4b_double_3x3_2_bn_out,inception_4b_pool_proj_bn_out], 1) 386 | inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out) 387 | inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out) 388 | inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out) 389 | inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out) 390 | inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out) 391 | inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out) 392 | inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_3x3_reduce_bn_out) 393 | inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out) 394 | inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out) 395 | inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out) 396 | inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn(inception_4c_double_3x3_reduce_out) 397 | inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce(inception_4c_double_3x3_reduce_bn_out) 398 | inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_double_3x3_reduce_bn_out) 399 | inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out) 400 | inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out) 401 | inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_double_3x3_1_bn_out) 402 | inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out) 403 | inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out) 404 | inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out) 405 | inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out) 406 | inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out) 407 | 
inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out) 408 | inception_4c_output_out = torch.cat([inception_4c_1x1_bn_out,inception_4c_3x3_bn_out,inception_4c_double_3x3_2_bn_out,inception_4c_pool_proj_bn_out], 1) 409 | inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out) 410 | inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out) 411 | inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out) 412 | inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out) 413 | inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out) 414 | inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out) 415 | inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_3x3_reduce_bn_out) 416 | inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out) 417 | inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out) 418 | inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out) 419 | inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn(inception_4d_double_3x3_reduce_out) 420 | inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce(inception_4d_double_3x3_reduce_bn_out) 421 | inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_double_3x3_reduce_bn_out) 422 | inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out) 423 | inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out) 424 | inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_double_3x3_1_bn_out) 425 | inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out) 426 | inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out) 427 | inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out) 428 | inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out) 429 | inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out) 430 | inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out) 431 | inception_4d_output_out = torch.cat([inception_4d_1x1_bn_out,inception_4d_3x3_bn_out,inception_4d_double_3x3_2_bn_out,inception_4d_pool_proj_bn_out], 1) 432 | inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out) 433 | inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out) 434 | inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out) 435 | inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_3x3_reduce_bn_out) 436 | inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out) 437 | inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out) 438 | inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out) 439 | inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn(inception_4e_double_3x3_reduce_out) 440 | inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce(inception_4e_double_3x3_reduce_bn_out) 441 | inception_4e_double_3x3_1_out = 
self.inception_4e_double_3x3_1(inception_4e_double_3x3_reduce_bn_out) 442 | inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out) 443 | inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out) 444 | inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_double_3x3_1_bn_out) 445 | inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out) 446 | inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out) 447 | inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out) 448 | inception_4e_output_out = torch.cat([inception_4e_3x3_bn_out,inception_4e_double_3x3_2_bn_out,inception_4e_pool_out], 1) 449 | inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out) 450 | inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out) 451 | inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out) 452 | inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out) 453 | inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out) 454 | inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out) 455 | inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_3x3_reduce_bn_out) 456 | inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out) 457 | inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out) 458 | inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out) 459 | inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn(inception_5a_double_3x3_reduce_out) 460 | inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce(inception_5a_double_3x3_reduce_bn_out) 461 | inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_double_3x3_reduce_bn_out) 462 | inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out) 463 | inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out) 464 | inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_double_3x3_1_bn_out) 465 | inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out) 466 | inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out) 467 | inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out) 468 | inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out) 469 | inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out) 470 | inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out) 471 | inception_5a_output_out = torch.cat([inception_5a_1x1_bn_out,inception_5a_3x3_bn_out,inception_5a_double_3x3_2_bn_out,inception_5a_pool_proj_bn_out], 1) 472 | inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out) 473 | inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out) 474 | inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out) 475 | inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out) 476 | inception_5b_3x3_reduce_bn_out = 
self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out) 477 | inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out) 478 | inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_3x3_reduce_bn_out) 479 | inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out) 480 | inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out) 481 | inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out) 482 | inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn(inception_5b_double_3x3_reduce_out) 483 | inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce(inception_5b_double_3x3_reduce_bn_out) 484 | inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_double_3x3_reduce_bn_out) 485 | inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out) 486 | inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out) 487 | inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_double_3x3_1_bn_out) 488 | inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out) 489 | 490 | inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out) 491 | inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out) 492 | inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out) 493 | inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out) 494 | inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out) 495 | inception_5b_output_out = torch.cat([inception_5b_1x1_bn_out, 496 | inception_5b_3x3_bn_out, 497 | inception_5b_double_3x3_2_bn_out, 498 | inception_5b_pool_proj_bn_out], 1) 499 | return inception_5b_output_out 500 | 501 | def forward(self, x): 502 | x = self.features(x) 503 | x = F.adaptive_max_pool2d(x, output_size=1) 504 | x = x.view(x.size(0), -1) 505 | if self.dim == 0: 506 | return x 507 | x = self.classifier(x) 508 | return x 509 | 510 | 511 | 512 | def BN_Inception(dim=512, pretrained=True, model_path=None): 513 | model = BNInception(dim=512) 514 | if model_path is None: 515 | model_path = '/home/xunwang/.torch/models/bn_inception-52deb4733.pth' 516 | if pretrained is True: 517 | model_dict = model.state_dict() 518 | pretrained_dict = torch.load(model_path) 519 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 520 | model_dict.update(pretrained_dict) 521 | model.load_state_dict(model_dict) 522 | return model 523 | 524 | def main(): 525 | model = BN_Inception(dim=512, pretrained=True) 526 | # print(model) 527 | images = Variable(torch.ones(8, 3, 227, 227)) 528 | out_ = model(images) 529 | print(out_.data.shape) 530 | 531 | if __name__ == '__main__': 532 | main() 533 | 534 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .BN_Inception import BN_Inception 2 | 3 | __factory = { 4 | 'BN-Inception': BN_Inception, 5 | } 6 | 7 | def names(): 8 | return sorted(__factory.keys()) 9 | 10 | 11 | def create(name, *args, **kwargs): 12 | """ 13 | Create a loss instance. 
14 | 15 | Parameters 16 | ---------- 17 | name : str 18 | the name of loss function 19 | """ 20 | if name not in __factory: 21 | raise KeyError("Unknown network:", name) 22 | return __factory[name](*args, **kwargs) 23 | -------------------------------------------------------------------------------- /run_train_00.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | DATA=cub 3 | DATA_ROOT=data 4 | Gallery_eq_Query=True 5 | LOSS=Weight 6 | CHECKPOINTS=ckps 7 | R=.pth.tar 8 | 9 | if_exist_mkdir () 10 | { 11 | dirname=$1 12 | if [ ! -d "$dirname" ]; then 13 | mkdir $dirname 14 | fi 15 | } 16 | 17 | if_exist_mkdir ${CHECKPOINTS} 18 | if_exist_mkdir ${CHECKPOINTS}/${LOSS} 19 | if_exist_mkdir ${CHECKPOINTS}/${LOSS}/${DATA} 20 | 21 | if_exist_mkdir result 22 | if_exist_mkdir result/${LOSS} 23 | if_exist_mkdir result/${LOSS}/${DATA} 24 | 25 | NET=BN-Inception 26 | DIM=512 27 | ALPHA=40 28 | LR=1e-5 29 | BatchSize=80 30 | RATIO=0.16 31 | 32 | SAVE_DIR=${CHECKPOINTS}/${LOSS}/${DATA}/${NET}-DIM-${DIM}-lr${LR}-ratio-${RATIO}-BatchSize-${BatchSize} 33 | if_exist_mkdir ${SAVE_DIR} 34 | 35 | 36 | # if [ ! -n "$1" ] ;then 37 | echo "Begin Training!" 38 | CUDA_VISIBLE_DEVICES=0 python train.py --net ${NET} \ 39 | --data $DATA \ 40 | --data_root ${DATA_ROOT} \ 41 | --init random \ 42 | --lr $LR \ 43 | --dim $DIM \ 44 | --alpha $ALPHA \ 45 | --num_instances 5 \ 46 | --batch_size ${BatchSize} \ 47 | --epoch 600 \ 48 | --loss $LOSS \ 49 | --width 227 \ 50 | --save_dir ${SAVE_DIR} \ 51 | --save_step 50 \ 52 | --ratio ${RATIO} 53 | 54 | echo "Begin Testing!" 55 | 56 | Model_LIST=`seq 50 50 600` 57 | for i in $Model_LIST; do 58 | CUDA_VISIBLE_DEVICES=0 python test.py --net ${NET} \ 59 | --data $DATA \ 60 | --data_root ${DATA_ROOT} \ 61 | --batch_size 100 \ 62 | -g_eq_q ${Gallery_eq_Query} \ 63 | --width 227 \ 64 | -r ${SAVE_DIR}/ckp_ep$i$R \ 65 | --pool_feature ${POOL_FEATURE:-'False'} \ 66 | | tee -a result/$LOSS/$DATA/${NET}-DIM-$DIM-Batchsize-${BatchSize}-ratio-${RATIO}-lr-$LR${POOL_FEATURE:+'-pool_feature'}.txt 67 | 68 | done 69 | 70 | 71 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import, print_function 3 | import argparse 4 | from Model2Feature import Model2Feature 5 | from evaluations import Recall_at_ks, pairwise_similarity 6 | from utils.serialization import load_checkpoint 7 | import torch 8 | import ast 9 | 10 | parser = argparse.ArgumentParser(description='PyTorch Testing') 11 | 12 | parser.add_argument('--data', type=str, default='cub') 13 | parser.add_argument('--data_root', type=str, default=None) 14 | parser.add_argument('--gallery_eq_query', '-g_eq_q', type=ast.literal_eval, default=False, 15 | help='Is gallery identical with query') 16 | parser.add_argument('--net', type=str, default='VGG16-BN') 17 | parser.add_argument('--resume', '-r', type=str, default='model.pkl', metavar='PATH') 18 | 19 | parser.add_argument('--dim', '-d', type=int, default=512, 20 | help='Dimension of Embedding Feather') 21 | parser.add_argument('--width', type=int, default=224, 22 | help='width of input image') 23 | 24 | parser.add_argument('--batch_size', type=int, default=100) 25 | parser.add_argument('--nThreads', '-j', default=16, type=int, metavar='N', 26 | help='number of data loading threads (default: 2)') 27 | parser.add_argument('--pool_feature', type=ast.literal_eval, 
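                    # Typical invocation, mirroring run_train_00.sh above (checkpoint path assumed):
                    # CUDA_VISIBLE_DEVICES=0 python test.py --net BN-Inception --data cub --data_root data
                    #     --batch_size 100 -g_eq_q True --width 227 -r ckps/.../ckp_ep600.pth.tar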
default=False, required=False, 28 | help='if True extract feature from the last pool layer') 29 | 30 | args = parser.parse_args() 31 | 32 | checkpoint = load_checkpoint(args.resume) 33 | print(args.pool_feature) 34 | epoch = checkpoint['epoch'] 35 | 36 | gallery_feature, gallery_labels, query_feature, query_labels = \ 37 | Model2Feature(data=args.data, root=args.data_root, width=args.width, net=args.net, checkpoint=checkpoint, 38 | dim=args.dim, batch_size=args.batch_size, nThreads=args.nThreads, pool_feature=args.pool_feature) 39 | 40 | sim_mat = pairwise_similarity(query_feature, gallery_feature) 41 | if args.gallery_eq_query is True: 42 | sim_mat = sim_mat - torch.eye(sim_mat.size(0)) 43 | 44 | recall_ks = Recall_at_ks(sim_mat, query_ids=query_labels, gallery_ids=gallery_labels, data=args.data) 45 | 46 | result = ' '.join(['%.4f' % k for k in recall_ks]) 47 | print('Epoch-%d' % epoch, result) 48 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import, print_function 3 | import time 4 | import argparse 5 | import os 6 | import sys 7 | import torch.utils.data 8 | from torch.backends import cudnn 9 | from torch.autograd import Variable 10 | import models 11 | import losses 12 | from utils import FastRandomIdentitySampler, mkdir_if_missing, logging, display 13 | from utils.serialization import save_checkpoint, load_checkpoint 14 | from trainer import train 15 | from utils import orth_reg 16 | 17 | import DataSet 18 | import numpy as np 19 | import os.path as osp 20 | cudnn.benchmark = True 21 | 22 | use_gpu = True 23 | 24 | # Batch Norm Freezer : bring 2% improvement on CUB 25 | def set_bn_eval(m): 26 | classname = m.__class__.__name__ 27 | if classname.find('BatchNorm') != -1: 28 | m.eval() 29 | 30 | 31 | def main(args): 32 | # s_ = time.time() 33 | 34 | save_dir = args.save_dir 35 | mkdir_if_missing(save_dir) 36 | 37 | sys.stdout = logging.Logger(os.path.join(save_dir, 'log.txt')) 38 | display(args) 39 | start = 0 40 | 41 | model = models.create(args.net, pretrained=True, dim=args.dim) 42 | 43 | # for vgg and densenet 44 | if args.resume is None: 45 | model_dict = model.state_dict() 46 | 47 | else: 48 | # resume model 49 | print('load model from {}'.format(args.resume)) 50 | chk_pt = load_checkpoint(args.resume) 51 | weight = chk_pt['state_dict'] 52 | start = chk_pt['epoch'] 53 | model.load_state_dict(weight) 54 | 55 | model = torch.nn.DataParallel(model) 56 | model = model.cuda() 57 | 58 | # freeze BN 59 | if args.freeze_BN is True: 60 | print(40 * '#', '\n BatchNorm frozen') 61 | model.apply(set_bn_eval) 62 | else: 63 | print(40*'#', 'BatchNorm NOT frozen') 64 | 65 | # Fine-tune the model: the learning rate for pre-trained parameter is 1/10 66 | new_param_ids = set(map(id, model.module.classifier.parameters())) 67 | 68 | new_params = [p for p in model.module.parameters() if 69 | id(p) in new_param_ids] 70 | 71 | base_params = [p for p in model.module.parameters() if 72 | id(p) not in new_param_ids] 73 | 74 | param_groups = [ 75 | {'params': base_params, 'lr_mult': 0.0}, 76 | {'params': new_params, 'lr_mult': 1.0}] 77 | 78 | print('initial model is save at %s' % save_dir) 79 | 80 | optimizer = torch.optim.Adam(param_groups, lr=args.lr, 81 | weight_decay=args.weight_decay) 82 | 83 | criterion = losses.create(args.loss, margin=args.margin, alpha=args.alpha, base=args.loss_base).cuda() 84 | 85 | # Decor_loss = 
losses.create('decor').cuda() 86 | data = DataSet.create(args.data, ratio=args.ratio, width=args.width, origin_width=args.origin_width, root=args.data_root) 87 | 88 | train_loader = torch.utils.data.DataLoader( 89 | data.train, batch_size=args.batch_size, 90 | sampler=FastRandomIdentitySampler(data.train, num_instances=args.num_instances), 91 | drop_last=True, pin_memory=True, num_workers=args.nThreads) 92 | 93 | # save the train information 94 | 95 | for epoch in range(start, args.epochs): 96 | 97 | train(epoch=epoch, model=model, criterion=criterion, 98 | optimizer=optimizer, train_loader=train_loader, args=args) 99 | 100 | if epoch == 1: 101 | optimizer.param_groups[0]['lr_mul'] = 0.1 102 | 103 | if (epoch+1) % args.save_step == 0 or epoch==0: 104 | if use_gpu: 105 | state_dict = model.module.state_dict() 106 | else: 107 | state_dict = model.state_dict() 108 | 109 | save_checkpoint({ 110 | 'state_dict': state_dict, 111 | 'epoch': (epoch+1), 112 | }, is_best=False, fpath=osp.join(args.save_dir, 'ckp_ep' + str(epoch + 1) + '.pth.tar')) 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser(description='Deep Metric Learning') 116 | 117 | # hype-parameters 118 | parser.add_argument('--lr', type=float, default=1e-5, help="learning rate of new parameters") 119 | parser.add_argument('--batch_size', '-b', default=128, type=int, metavar='N', 120 | help='mini-batch size (1 = pure stochastic) Default: 256') 121 | parser.add_argument('--num_instances', default=8, type=int, metavar='n', 122 | help=' number of samples from one class in mini-batch') 123 | parser.add_argument('--dim', default=512, type=int, metavar='n', 124 | help='dimension of embedding space') 125 | parser.add_argument('--width', default=224, type=int, 126 | help='width of input image') 127 | parser.add_argument('--origin_width', default=256, type=int, 128 | help='size of origin image') 129 | parser.add_argument('--ratio', default=0.16, type=float, 130 | help='random crop ratio for train data') 131 | 132 | parser.add_argument('--alpha', default=30, type=int, metavar='n', 133 | help='hyper parameter in NCA and its variants') 134 | parser.add_argument('--beta', default=0.1, type=float, metavar='n', 135 | help='hyper parameter in some deep metric loss functions') 136 | parser.add_argument('--orth_reg', default=0, type=float, 137 | help='hyper parameter coefficient for orth-reg loss') 138 | parser.add_argument('-k', default=16, type=int, metavar='n', 139 | help='number of neighbour points in KNN') 140 | parser.add_argument('--margin', default=0.5, type=float, 141 | help='margin in loss function') 142 | parser.add_argument('--init', default='random', 143 | help='the initialization way of FC layer') 144 | 145 | # network 146 | parser.add_argument('--freeze_BN', default=True, type=bool, required=False, metavar='N', 147 | help='Freeze BN if True') 148 | parser.add_argument('--data', default='cub', required=True, 149 | help='name of Data Set') 150 | parser.add_argument('--data_root', type=str, default=None, 151 | help='path to Data Set') 152 | 153 | parser.add_argument('--net', default='VGG16-BN') 154 | parser.add_argument('--loss', default='branch', required=True, 155 | help='loss for training network') 156 | parser.add_argument('--epochs', default=600, type=int, metavar='N', 157 | help='epochs for training process') 158 | parser.add_argument('--save_step', default=50, type=int, metavar='N', 159 | help='number of epochs to save model') 160 | 161 | # Resume from checkpoint 162 | parser.add_argument('--resume', '-r', 
default=None, 163 | help='the path of the pre-trained model') 164 | 165 | # train 166 | parser.add_argument('--print_freq', default=20, type=int, 167 | help='display frequency of training') 168 | 169 | # basic parameter 170 | # parser.add_argument('--checkpoints', default='/opt/intern/users/xunwang', 171 | # help='where the trained models save') 172 | parser.add_argument('--save_dir', default=None, 173 | help='where the trained models save') 174 | parser.add_argument('--nThreads', '-j', default=16, type=int, metavar='N', 175 | help='number of data loading threads (default: 2)') 176 | parser.add_argument('--momentum', type=float, default=0.9) 177 | parser.add_argument('--weight-decay', type=float, default=5e-4) 178 | 179 | parser.add_argument('--loss_base', type=float, default=0.75) 180 | 181 | 182 | 183 | 184 | main(parser.parse_args()) 185 | 186 | 187 | 188 | 189 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import print_function, absolute_import 3 | import time 4 | from utils import AverageMeter, orth_reg 5 | import torch 6 | from torch.autograd import Variable 7 | from torch.backends import cudnn 8 | 9 | cudnn.benchmark = True 10 | 11 | 12 | def train(epoch, model, criterion, optimizer, train_loader, args): 13 | 14 | losses = AverageMeter() 15 | batch_time = AverageMeter() 16 | accuracy = AverageMeter() 17 | pos_sims = AverageMeter() 18 | neg_sims = AverageMeter() 19 | 20 | end = time.time() 21 | 22 | freq = min(args.print_freq, len(train_loader)) 23 | 24 | for i, data_ in enumerate(train_loader, 0): 25 | 26 | inputs, labels = data_ 27 | 28 | # wrap them in Variable 29 | inputs = Variable(inputs).cuda() 30 | labels = Variable(labels).cuda() 31 | 32 | optimizer.zero_grad() 33 | 34 | embed_feat = model(inputs) 35 | 36 | loss, inter_, dist_ap, dist_an = criterion(embed_feat, labels) 37 | 38 | if args.orth_reg != 0: 39 | loss = orth_reg(net=model, loss=loss, cof=args.orth_reg) 40 | 41 | loss.backward() 42 | optimizer.step() 43 | 44 | # measure elapsed time 45 | batch_time.update(time.time() - end) 46 | end = time.time() 47 | 48 | losses.update(loss.item()) 49 | accuracy.update(inter_) 50 | pos_sims.update(dist_ap) 51 | neg_sims.update(dist_an) 52 | 53 | if (i + 1) % freq == 0 or (i+1) == len(train_loader): 54 | print('Epoch: [{0:03d}][{1}/{2}]\t' 55 | 'Time {batch_time.avg:.3f}\t' 56 | 'Loss {loss.avg:.4f} \t' 57 | 'Accuracy {accuracy.avg:.4f} \t' 58 | 'Pos {pos.avg:.4f}\t' 59 | 'Neg {neg.avg:.4f} \t'.format 60 | (epoch + 1, i + 1, len(train_loader), batch_time=batch_time, 61 | loss=losses, accuracy=accuracy, pos=pos_sims, neg=neg_sims)) 62 | 63 | if epoch == 0 and i == 0: 64 | print('-- HA-HA-HA-HA-AH-AH-AH-AH --') 65 | -------------------------------------------------------------------------------- /utils/Batch_generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | 6 | 7 | class BatchGenerator(object): 8 | def __init__(self, labels, num_instances, batch_size): 9 | self.labels = labels 10 | self.num_instances = num_instances 11 | self.batch_size = batch_size 12 | self.ids = set(self.labels) 13 | self.num_id = batch_size//num_instances 14 | 15 | self.index_dic = defaultdict(list) 16 | 17 | for index, cat_id in enumerate(self.labels): 18 | 
self.index_dic[cat_id].append(index) 19 | 20 | def __len__(self): 21 | return self.num_id*self.num_instances 22 | 23 | def batch(self): 24 | ret = [] 25 | indices = np.random.choice( list(self.ids), size=self.num_id, replace=False) 26 | # print(indices) 27 | for cat_id in indices: 28 | t = self.index_dic[cat_id] 29 | if len(t) >= self.num_instances: 30 | t = np.random.choice(t, size=self.num_instances, replace=False) 31 | else: 32 | t = np.random.choice(t, size=self.num_instances, replace=True) 33 | ret.extend(t) 34 | return ret 35 | 36 | def get_id(self): 37 | ret = self.batch() 38 | # print(ret) 39 | result = [self.labels[k] for k in ret] 40 | return result 41 | 42 | 43 | def main(): 44 | labels = np.load('/Users/wangxun/Deep_metric/labels.npy') 45 | num_instances = 8 46 | batch_size = 128 47 | Batch = BatchGenerator(labels, num_instances=num_instances, batch_size=batch_size) 48 | print(Batch.batch()) 49 | 50 | if __name__ == '__main__': 51 | main() 52 | print('Hello world') -------------------------------------------------------------------------------- /utils/HyperparamterDisplay.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | 4 | def display(args): 5 | # Display information of current training 6 | print('Learn Rate \t%.1e' % args.lr) 7 | print('Epochs \t%05d' % args.epochs) 8 | print('Log Path \t%s' % args.save_dir) 9 | print('Network \t %s' % args.net) 10 | print('Data Set \t %s' % args.data) 11 | print('Batch Size \t %d' % args.batch_size) 12 | print('Num-Instance \t %d' % args.num_instances) 13 | print('Embedded Dimension \t %d' % args.dim) 14 | 15 | print('Loss Function \t%s' % args.loss) 16 | # print('Number of Neighbour \t%d' % args.k) 17 | print('Alpha \t %d' % args.alpha) 18 | 19 | print('Begin to fine tune %s Network' % args.net) 20 | print(40*'#') 21 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .meters import * 3 | from .sampler import RandomIdentitySampler, FastRandomIdentitySampler 4 | import torch 5 | from .osutils import mkdir_if_missing 6 | from .orthogonal_regularizaton import orth_reg 7 | from .str2nums import chars2nums 8 | from .HyperparamterDisplay import display 9 | from .Batch_generator import BatchGenerator 10 | from .cluster import cluster_ 11 | from .numpy_tozero import to_zero 12 | 13 | def to_numpy(tensor): 14 | if torch.is_tensor(tensor): 15 | return tensor.cpu().numpy() 16 | elif type(tensor).__module__ != 'numpy': 17 | raise ValueError("Cannot convert {} to numpy array" 18 | .format(type(tensor))) 19 | return tensor 20 | 21 | 22 | def to_torch(ndarray): 23 | if type(ndarray).__module__ == 'numpy': 24 | return torch.from_numpy(ndarray) 25 | elif not torch.is_tensor(ndarray): 26 | raise ValueError("Cannot convert {} to torch tensor" 27 | .format(type(ndarray))) 28 | return ndarray 29 | 30 | -------------------------------------------------------------------------------- /utils/cluster.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import, print_function 3 | 4 | import numpy as np 5 | 6 | from sklearn.cluster import KMeans 7 | # from sklearn.mixture import GaussianMixture 8 | 9 | 10 | def cluster_(features, labels, n_clusters): 11 | centers = [] 12 | center_labels = [] 13 | for 
label in set(labels): 14 | X = features[labels == label] 15 | kmeans = KMeans(n_clusters=n_clusters, random_state=None).fit(X) 16 | center_ = kmeans.cluster_centers_ 17 | centers.extend(center_) 18 | center_labels.extend(n_clusters*[label]) 19 | centers = np.conjugate(centers) 20 | centers = normalize(centers) 21 | return centers, center_labels 22 | 23 | 24 | def normalize(X): 25 | norm_inverse = np.diag(1/np.sqrt(np.sum(np.power(X, 2), 1))) 26 | X_norm = np.matmul(norm_inverse, X) 27 | return X_norm 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /utils/map.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | 4 | 5 | class Mazz(): 6 | def __init__(self, m=16, n=16, map_=[[0, 1], [1, 1]]): 7 | self.route = [] 8 | self.position = [0, 0] 9 | self.target = [m, n] 10 | self.map = map_ 11 | 12 | def get_avaliable_direction(self, position): 13 | # position 表示当前位置 14 | # 返回可以走的方向 如[[0, 1]]表示仅可以向下走 15 | avaliable_direction = [] 16 | if self.map[position[0]+1, position[1]] == 1: 17 | avaliable_direction.append([1, 0]) 18 | if self.map[position[0], position[0]+1] == 1: 19 | avaliable_direction.append([0, 1]) 20 | if self.map[position[0], position[0]-1] == 1: 21 | avaliable_direction.append([0, -1]) 22 | if self.map[position[0]-1, position[1]] == 1: 23 | avaliable_direction.append([-1, 0]) 24 | return avaliable_direction 25 | 26 | def get_route(self, position, target, last_route): 27 | avaliable_direction = self.get_avaliable_direction(self, position) 28 | if position == target: 29 | return [] 30 | 31 | for direction in avaliable_direction: 32 | position_ = [position[0]+direction[0], position[1]+direction[1]] 33 | if position not in last_route: 34 | last_route.append(position) 35 | return [direction]+ self.get_route(self, position_, target, last_route) 36 | 37 | def Solution(self): 38 | last_route = [] 39 | return self.get_route(self.position, self.target, last_route) 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current 
value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /utils/numpy_tozero.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | def to_zero(x): 4 | h = x.shape[0] 5 | w = x.shape[1] 6 | for i in range(h): 7 | for j in range(w): 8 | x[i][j] = 0 9 | -------------------------------------------------------------------------------- /utils/orthogonal_regularizaton.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | def orth_reg(net, loss, cof=1): 8 | orth_loss = 0 9 | 10 | if cof == 0: 11 | return orth_loss 12 | 13 | for m in net.modules(): 14 | if isinstance(m, nn.Linear): 15 | w = m.weight 16 | mat_ = torch.matmul(w, w.t()) 17 | diff = mat_ - torch.diag(torch.diag(mat_)) 18 | orth_loss = torch.mean(torch.pow(diff, 2)) 19 | loss = loss + cof*orth_loss 20 | return loss 21 | -------------------------------------------------------------------------------- /utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /utils/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data.sampler import ( 7 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 8 | WeightedRandomSampler) 9 | 10 | 11 | class RandomIdentitySampler(Sampler): 12 | def __init__(self, data_source, num_instances=1): 13 | self.data_source = data_source 14 | self.num_instances = num_instances 15 | self.index_dic = defaultdict(list) 16 | 17 | for index, (_, pid) in enumerate(data_source): 18 | self.index_dic[pid].append(index) 19 | 20 | self.pids = list(self.index_dic.keys()) 21 | self.num_samples = len(self.pids) 22 | # print(len(self)) 23 | # print(self.num_samples) 24 | 25 | def __len__(self): 26 | return self.num_samples * self.num_instances 27 | 28 | def __iter__(self): 29 | indices = torch.randperm(self.num_samples) 30 | ret = [] 31 | for i in indices: 32 | pid = self.pids[i] 33 | t = self.index_dic[pid] 34 | if len(t) >= self.num_instances: 35 | t = np.random.choice(t, size=self.num_instances, replace=False) 36 | else: 37 | t = np.random.choice(t, size=self.num_instances, replace=True) 38 | ret.extend(t) 39 | return iter(ret) 40 | 41 | 42 | class FastRandomIdentitySampler(Sampler): 43 | def __init__(self, data_source, num_instances=1): 44 | self.data_source = data_source 45 | self.num_instances = num_instances 46 | self.index_dic = defaultdict(list) 47 | 48 | # for index, (_, pid) in enumerate(data_source): 49 | # 
self.index_dic[pid].append(index) 50 | 51 | self.index_dic = data_source.Index 52 | 53 | self.pids = list(self.index_dic.keys()) 54 | self.num_samples = len(self.pids) 55 | 56 | def __len__(self): 57 | return self.num_samples * self.num_instances 58 | 59 | def __iter__(self): 60 | indices = torch.randperm(self.num_samples) 61 | ret = [] 62 | for i in indices: 63 | pid = self.pids[i] 64 | t = self.index_dic[pid] 65 | if len(t) >= self.num_instances: 66 | t = np.random.choice(t, size=self.num_instances, replace=False) 67 | else: 68 | t = np.random.choice(t, size=self.num_instances, replace=True) 69 | ret.extend(t) 70 | # print('Done data sampling') 71 | return iter(ret) 72 | -------------------------------------------------------------------------------- /utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | checkpoint = torch.load(fpath) 34 | print("=> Loaded checkpoint '{}'".format(fpath)) 35 | return checkpoint 36 | else: 37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 38 | 39 | 40 | def copy_state_dict(state_dict, model, strip=None): 41 | tgt_state = model.state_dict() 42 | copied_names = set() 43 | for name, param in state_dict.items(): 44 | if strip is not None and name.startswith(strip): 45 | name = name[len(strip):] 46 | if name not in tgt_state: 47 | continue 48 | if isinstance(param, Parameter): 49 | param = param.data 50 | if param.size() != tgt_state[name].size(): 51 | print('mismatch:', name, param.size(), tgt_state[name].size()) 52 | continue 53 | tgt_state[name].copy_(param) 54 | copied_names.add(name) 55 | 56 | missing = set(tgt_state.keys()) - copied_names 57 | if len(missing) > 0: 58 | print("missing keys in state_dict:", missing) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /utils/str2nums.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | def chars2nums(a): 5 | nums_list = ['0'] 6 | for k in a: 7 | # print(k) 8 | if k == ',': 9 | nums_list.append('0') 10 | else: 11 | nums_list[-1] = nums_list[-1] + k 12 | 13 | return [int(char_) for char_ in nums_list] 14 | 15 | --------------------------------------------------------------------------------
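The sketches below are usage notes for the code above; file paths, toy inputs, and hyper-parameter values in them are assumptions for illustration, not values taken from the repository. First, the model factory in models/__init__.py: create looks the requested name up in __factory and forwards the remaining arguments to BN_Inception, which builds a BNInception embedder and optionally loads the pretrained weights from the default checkpoint path. A minimal sketch, assuming pretrained=False so no checkpoint file is needed:

import torch
import models

# Build the embedder without loading the BN-Inception checkpoint (pretrained=False is an
# assumption here; train.py uses pretrained=True with the default weight path).
model = models.create('BN-Inception', dim=512, pretrained=False)
model.eval()

with torch.no_grad():
    images = torch.ones(8, 3, 227, 227)   # same dummy input as main() in BN_Inception.py
    embeddings = model(images)            # 512-d embeddings from the classifier head
print(embeddings.shape)                   # torch.Size([8, 512])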
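train.py builds identity-balanced mini-batches with FastRandomIdentitySampler, which reads the class-to-index table from the data set's Index attribute. A sketch of that wiring, assuming the CUB files are already prepared under the default data root (batch size 80 and 5 instances per class follow run_train_00.sh):

import torch.utils.data
import DataSet
from utils import FastRandomIdentitySampler

data = DataSet.create('cub', ratio=0.16, width=227, origin_width=256, root=None)
train_loader = torch.utils.data.DataLoader(
    data.train, batch_size=80,                                # 80 = 16 classes x 5 instances
    sampler=FastRandomIdentitySampler(data.train, num_instances=5),
    drop_last=True, pin_memory=True, num_workers=4)

inputs, labels = next(iter(train_loader))
print(inputs.shape, labels[:5])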
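utils/Batch_generator.py offers a standalone alternative to the sampler: given a label array, batch() returns batch_size indices covering batch_size // num_instances classes with num_instances samples each, falling back to sampling with replacement when a class is too small. A toy example with made-up labels:

import numpy as np
from utils import BatchGenerator

labels = np.repeat(np.arange(16), 10)            # 16 classes, 10 samples per class (toy data)
gen = BatchGenerator(labels, num_instances=8, batch_size=128)
batch_indices = gen.batch()                      # 128 indices: 16 classes x 8 instances
print(len(batch_indices), len(set(labels[batch_indices])))   # 128 16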
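test.py evaluates by computing a query-gallery similarity matrix and Recall@K; when the gallery is identical with the query set (CUB, Cars), the diagonal is suppressed so an image cannot retrieve itself. A sketch with random stand-in features and labels (a real run gets both from Model2Feature):

import torch
from evaluations import Recall_at_ks, pairwise_similarity

query_feature = torch.nn.functional.normalize(torch.randn(100, 512), dim=1)
gallery_feature = query_feature                       # gallery identical with query
query_labels = gallery_labels = torch.randint(0, 10, (100,)).tolist()

sim_mat = pairwise_similarity(query_feature, gallery_feature)
sim_mat = sim_mat - torch.eye(sim_mat.size(0))        # remove self-matches
recall_ks = Recall_at_ks(sim_mat, query_ids=query_labels, gallery_ids=gallery_labels, data='cub')
print(' '.join('%.4f' % k for k in recall_ks))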
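trainer.py optionally adds the orthogonal regulariser from utils/orthogonal_regularizaton.py: for an nn.Linear it penalises the off-diagonal entries of W·Wᵀ and folds the result into the loss, scaled by --orth_reg. A minimal sketch on a throwaway linear layer:

import torch
import torch.nn as nn
from utils import orth_reg

net = nn.Sequential(nn.Linear(8, 4))
loss = torch.tensor(1.0)
loss = orth_reg(net=net, loss=loss, cof=0.1)   # loss + 0.1 * mean squared off-diagonal of W.Wt
print(loss.item())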
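utils/cluster.py compresses each class into n_clusters k-means centres and L2-normalises them; normalize scales every row of a matrix to unit length. A small numeric check plus a toy clustering call (the features below are random):

import numpy as np
from utils import cluster_
from utils.cluster import normalize

print(normalize(np.array([[3.0, 4.0],
                          [0.0, 2.0]])))       # rows scaled to unit length: [[0.6 0.8] [0. 1.]]

features = np.random.randn(60, 16)
labels = np.repeat(np.arange(3), 20)           # 3 classes, 20 samples each (toy data)
centers, center_labels = cluster_(features, labels, n_clusters=2)
print(centers.shape, len(center_labels))       # (6, 16) 6 -- two centres per class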
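Two small helpers used throughout: AverageMeter (utils/meters.py) keeps the running averages printed by trainer.py, and chars2nums (utils/str2nums.py) parses a comma-separated string of integers. Toy inputs below:

from utils import AverageMeter, chars2nums

meter = AverageMeter()
for v in [1.0, 2.0, 3.0]:
    meter.update(v)
print(meter.avg)                  # 2.0

print(chars2nums('50,100,150'))   # [50, 100, 150]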