├── .gitignore ├── README.md ├── cars196.py ├── cub2011.py ├── data ├── log.txt.margin_cars196 ├── log.txt.margin_cub2011 ├── log.txt.margin_stanfordonlineproducts └── log.txt.triplet ├── inception_v1_googlenet.py ├── model.py ├── resnet18.py ├── resnet50.py ├── sampler.py ├── stanford_online_products.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | *.pyc 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Metric learning [models](./model.py) in PyTorch, recall@1 2 | | |[CUB2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) | [CARS196](http://ai.stanford.edu/~jkrause/cars/car_dataset.html) | [Stanford Online Products](http://cvgl.stanford.edu/projects/lifted_struct/) 3 | |:---:|:---:|:---:|:---:| 4 | | [Margin contrastive loss](https://arxiv.org/abs/1706.07567), semi-hard | [0.58](./data/log.txt.margin_cub2011) @ epoch60 | [0.80](./data/log.txt.margin_cars196) @ epoch60 | [0.7526](./data/log.txt.margin_stanfordonlineproducts) @ epoch90 | 5 | | [Lifted structured embedding](https://arxiv.org/abs/1511.06452) | 6 | | [Triplet loss](https://arxiv.org/abs/1503.03832)| 7 | 8 | Original impl of [Margin contrastive loss](https://arxiv.org/abs/1706.07567) published at: https://github.com/apache/incubator-mxnet/tree/19ede063c4756fa49cfe741e654180aee33991c6/example/gluon/embedding_learning (temporarily removed in https://github.com/apache/incubator-mxnet/pull/20602) 9 | 10 | # Examples 11 | ```shell 12 | # evaluation results are saved in ./data/log.txt 13 | 14 | # train margin contrastive loss on CUB2011 using ResNet-50 15 | python train.py --dataset cub2011 --model margin --base resnet50 16 | 17 | # download GoogLeNet weights and train using LiftedStruct loss 18 | wget -P ./data https://github.com/vadimkantorov/metriclearningbench/releases/download/data/googlenet.h5 19 | python train.py --dataset cub2011 --model liftedstruct --base inception_v1_googlenet 20 | 21 | # evaluate raw final layer embeddings on CUB2011 using ResNet-50 22 | python train.py --dataset cub2011 --model untrained --epochs 1 23 | ``` 24 | 25 | -------------------------------------------------------------------------------- /cars196.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.io 3 | import torch 4 | import torch.utils.data as data 5 | import torchvision 6 | from torchvision.datasets import ImageFolder 7 | from torchvision.datasets import CIFAR10 8 | from torchvision.datasets.folder import default_loader 9 | from torchvision.datasets.utils import download_url 10 | 11 | 12 | class Cars196MetricLearning(ImageFolder, CIFAR10): 13 | base_folder_devkit = 'devkit' 14 | url_devkit = 'http://ai.stanford.edu/~jkrause/cars/car_devkit.tgz' 15 | filename_devkit = 'cars_devkit.tgz' 16 | tgz_md5_devkit = 'c3b158d763b6e2245038c8ad08e45376' 17 | 18 | base_folder_trainims = 'cars_train' 19 | url_trainims = 'http://imagenet.stanford.edu/internal/car196/cars_train.tgz' 20 | filename_trainims = 'cars_ims_train.tgz' 21 | tgz_md5_trainims = '065e5b463ae28d29e77c1b4b166cfe61' 22 | 23 | base_folder_testims = 'cars_test' 24 | url_testims = 'http://imagenet.stanford.edu/internal/car196/cars_test.tgz' 25 | filename_testims = 'cars_ims_test.tgz' 26 | tgz_md5_testims = '4ce7ebf6a94d07f1952d94dd34c4d501' 27 | 28 | url_testanno = 'http://imagenet.stanford.edu/internal/car196/cars_test_annos_withlabels.mat' 29 | filename_testanno = 'cars_test_annos_withlabels.mat' 30 | mat_md5_testanno = 'b0a2b23655a3edd16d84508592a98d10' 31 | 32 | filename_trainanno = 'cars_train_annos.mat' 33 | 34 | base_folder = 'cars_train' 35 | train_list = [ 36 | ['00001.jpg', '8df595812fee3ca9a215e1ad4b0fb0c4'], 37 | ['00002.jpg', '4b9e5efcc3612378ec63a22f618b5028'] 38 | ] 39 | test_list = [] 40 | num_training_classes = 98 41 | 42 | def __init__(self, root, train=False, transform=None, target_transform=None, download=False, **kwargs): 43 | self.root = root 44 | self.transform = transform 45 | self.target_transform = target_transform 46 | self.loader = default_loader 47 | 48 | if download: 49 | self.url, self.filename, self.tgz_md5 = self.url_devkit, self.filename_devkit, self.tgz_md5_devkit 50 | self.download() 51 | 52 | self.url, self.filename, self.tgz_md5 = self.url_trainims, self.filename_trainims, self.tgz_md5_trainims 53 | self.download() 54 | 55 | self.url, self.filename, self.tgz_md5 = self.url_testims, self.filename_testims, self.tgz_md5_testims 56 | self.download() 57 | 58 | download_url(self.url_testanno, os.path.join(root, self.base_folder_devkit), self.filename_testanno, self.mat_md5_testanno) 59 | 60 | if not self._check_integrity(): 61 | raise RuntimeError('Dataset not found or corrupted.' + 62 | ' You can use download=True to download it') 63 | 64 | self.imgs = [(os.path.join(root, self.base_folder_trainims, a[-1][0]), int(a[-2][0]) - 1) for filename in [self.filename_trainanno] for a in scipy.io.loadmat(os.path.join(root, self.base_folder_devkit, filename))['annotations'][0] if (int(a[-2][0]) - 1 < self.num_training_classes) == train] + [(os.path.join(root, self.base_folder_testims, a[-1][0]), int(a[-2][0]) - 1) for filename in [self.filename_testanno] for a in scipy.io.loadmat(os.path.join(root, self.base_folder_devkit, filename))['annotations'][0] if (int(a[-2][0]) - 1 < self.num_training_classes) == train] 65 | -------------------------------------------------------------------------------- /cub2011.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | import torchvision 5 | from torchvision.datasets import ImageFolder 6 | from torchvision.datasets import CIFAR10 7 | 8 | 9 | class CUB2011(ImageFolder, CIFAR10): 10 | base_folder = 'CUB_200_2011/images' 11 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' 12 | filename = 'CUB_200_2011.tgz' 13 | tgz_md5 = '97eceeb196236b17998738112f37df78' 14 | 15 | train_list = [ 16 | ['001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg', '4c84da568f89519f84640c54b7fba7c2'], 17 | ['002.Laysan_Albatross/Laysan_Albatross_0001_545.jpg', 'e7db63424d0e384dba02aacaf298cdc0'], 18 | ] 19 | test_list = [ 20 | ['198.Rock_Wren/Rock_Wren_0001_189289.jpg', '487d082f1fbd58faa7b08aa5ede3cc00'], 21 | ['200.Common_Yellowthroat/Common_Yellowthroat_0003_190521.jpg', '96fd60ce4b4805e64368efc32bf5c6fe'] 22 | ] 23 | 24 | def __init__(self, root, transform=None, target_transform=None, download=False, **kwargs): 25 | self.root = root 26 | if download: 27 | self.download() 28 | 29 | if not self._check_integrity(): 30 | raise RuntimeError('Dataset not found or corrupted.' + 31 | ' You can use download=True to download it') 32 | ImageFolder.__init__(self, os.path.join(root, self.base_folder), 33 | transform=transform, target_transform=target_transform, **kwargs) 34 | 35 | class CUB2011MetricLearning(CUB2011): 36 | num_training_classes = 100 37 | 38 | def __init__(self, root, train=False, transform=None, target_transform=None, download=False, **kwargs): 39 | CUB2011.__init__(self, root, transform=transform, target_transform=target_transform, download=download, **kwargs) 40 | self.classes = self.classes[:self.num_training_classes] if train else self.classes[self.num_training_classes:] 41 | self.class_to_idx = {class_label : class_label_ind for class_label, class_label_ind in self.class_to_idx.items() if class_label in self.classes} 42 | self.imgs = [(image_file_path, class_label_ind) for image_file_path, class_label_ind in self.imgs if class_label_ind in self.class_to_idx.values()] 43 | -------------------------------------------------------------------------------- /data/log.txt.margin_cars196: -------------------------------------------------------------------------------- 1 | loss epoch 0: 0.3424 2 | recall@1 epoch 0: 0.341409 3 | loss epoch 1: 0.2839 4 | recall@1 epoch 1: 0.398844 5 | loss epoch 2: 0.2776 6 | recall@1 epoch 2: 0.433157 7 | loss epoch 3: 0.2715 8 | recall@1 epoch 3: 0.477309 9 | loss epoch 4: 0.2680 10 | recall@1 epoch 4: 0.501537 11 | loss epoch 5: 0.2671 12 | recall@1 epoch 5: 0.517771 13 | loss epoch 6: 0.2604 14 | recall@1 epoch 6: 0.539540 15 | loss epoch 7: 0.2559 16 | recall@1 epoch 7: 0.560817 17 | loss epoch 8: 0.2554 18 | recall@1 epoch 8: 0.591809 19 | loss epoch 9: 0.2493 20 | recall@1 epoch 9: 0.611241 21 | loss epoch 10: 0.2507 22 | recall@1 epoch 10: 0.630919 23 | loss epoch 11: 0.2461 24 | loss epoch 12: 0.2462 25 | loss epoch 13: 0.2410 26 | loss epoch 14: 0.2452 27 | loss epoch 15: 0.2380 28 | recall@1 epoch 15: 0.679375 29 | loss epoch 16: 0.2361 30 | loss epoch 17: 0.2349 31 | loss epoch 18: 0.2368 32 | loss epoch 19: 0.2310 33 | loss epoch 20: 0.2300 34 | recall@1 epoch 20: 0.714918 35 | loss epoch 21: 0.2274 36 | loss epoch 22: 0.2291 37 | loss epoch 23: 0.2283 38 | loss epoch 24: 0.2239 39 | loss epoch 25: 0.2230 40 | recall@1 epoch 25: 0.744312 41 | loss epoch 26: 0.2246 42 | loss epoch 27: 0.2205 43 | loss epoch 28: 0.2207 44 | loss epoch 29: 0.2158 45 | loss epoch 30: 0.2203 46 | recall@1 epoch 30: 0.753905 47 | loss epoch 31: 0.2168 48 | loss epoch 32: 0.2159 49 | loss epoch 33: 0.2174 50 | loss epoch 34: 0.2149 51 | loss epoch 35: 0.2135 52 | recall@1 epoch 35: 0.767556 53 | loss epoch 36: 0.2135 54 | loss epoch 37: 0.2132 55 | loss epoch 38: 0.2106 56 | loss epoch 39: 0.2120 57 | loss epoch 40: 0.2095 58 | recall@1 epoch 40: 0.774198 59 | loss epoch 41: 0.2082 60 | loss epoch 42: 0.2105 61 | loss epoch 43: 0.2118 62 | loss epoch 44: 0.2078 63 | loss epoch 45: 0.2072 64 | recall@1 epoch 45: 0.784774 65 | loss epoch 46: 0.2036 66 | loss epoch 47: 0.2056 67 | loss epoch 48: 0.2028 68 | loss epoch 49: 0.2045 69 | loss epoch 50: 0.2044 70 | recall@1 epoch 50: 0.789325 71 | loss epoch 51: 0.2026 72 | loss epoch 52: 0.2001 73 | loss epoch 53: 0.2035 74 | loss epoch 54: 0.1997 75 | loss epoch 55: 0.1995 76 | recall@1 epoch 55: 0.791293 77 | loss epoch 56: 0.1995 78 | loss epoch 57: 0.1958 79 | loss epoch 58: 0.1992 80 | loss epoch 59: 0.1962 81 | loss epoch 60: 0.1974 82 | recall@1 epoch 60: 0.796704 83 | loss epoch 61: 0.1976 84 | loss epoch 62: 0.1961 85 | loss epoch 63: 0.1941 86 | loss epoch 64: 0.1972 87 | loss epoch 65: 0.1930 88 | recall@1 epoch 65: 0.795228 89 | loss epoch 66: 0.1920 90 | loss epoch 67: 0.1925 91 | loss epoch 68: 0.1902 92 | loss epoch 69: 0.1947 93 | loss epoch 70: 0.1909 94 | recall@1 epoch 70: 0.801254 95 | loss epoch 71: 0.1886 96 | loss epoch 72: 0.1922 97 | loss epoch 73: 0.1879 98 | loss epoch 74: 0.1902 99 | loss epoch 75: 0.1860 100 | recall@1 epoch 75: 0.792522 101 | loss epoch 76: 0.1851 102 | loss epoch 77: 0.1890 103 | loss epoch 78: 0.1876 104 | loss epoch 79: 0.1864 105 | loss epoch 80: 0.1864 106 | recall@1 epoch 80: 0.798672 107 | loss epoch 81: 0.1834 108 | loss epoch 82: 0.1854 109 | loss epoch 83: 0.1838 110 | loss epoch 84: 0.1837 111 | loss epoch 85: 0.1826 112 | recall@1 epoch 85: 0.800885 113 | loss epoch 86: 0.1810 114 | loss epoch 87: 0.1822 115 | loss epoch 88: 0.1807 116 | loss epoch 89: 0.1838 117 | loss epoch 90: 0.1822 118 | recall@1 epoch 90: 0.803222 119 | loss epoch 91: 0.1822 120 | loss epoch 92: 0.1775 121 | loss epoch 93: 0.1812 122 | loss epoch 94: 0.1802 123 | loss epoch 95: 0.1773 124 | recall@1 epoch 95: 0.801623 125 | loss epoch 96: 0.1786 126 | loss epoch 97: 0.1765 127 | loss epoch 98: 0.1797 128 | loss epoch 99: 0.1774 129 | recall@1 epoch 99: 0.799656 130 | -------------------------------------------------------------------------------- /data/log.txt.margin_cub2011: -------------------------------------------------------------------------------- 1 | loss epoch 0: 0.307651839826 2 | recall@1 epoch 0: 0.408676569885 3 | loss epoch 1: 0.260696530666 4 | recall@1 epoch 1: 0.423531397704 5 | loss epoch 2: 0.250390903457 6 | recall@1 epoch 2: 0.436360567184 7 | loss epoch 3: 0.246007021355 8 | recall@1 epoch 3: 0.458980418636 9 | loss epoch 4: 0.241646199771 10 | recall@1 epoch 4: 0.469277515192 11 | loss epoch 5: 0.241376496204 12 | recall@1 epoch 5: 0.479237002026 13 | loss epoch 6: 0.238328721212 14 | recall@1 epoch 6: 0.488521269413 15 | loss epoch 7: 0.234978532338 16 | recall@1 epoch 7: 0.493079000675 17 | loss epoch 8: 0.232177487534 18 | recall@1 epoch 8: 0.499662390277 19 | loss epoch 9: 0.23453499312 20 | recall@1 epoch 9: 0.503376097232 21 | loss epoch 10: 0.230749792379 22 | recall@1 epoch 10: 0.508271438217 23 | loss epoch 11: 0.230309016679 24 | loss epoch 12: 0.224887439414 25 | loss epoch 13: 0.22651492286 26 | loss epoch 14: 0.222068217785 27 | loss epoch 15: 0.225454618425 28 | recall@1 epoch 15: 0.533929777178 29 | loss epoch 16: 0.217896575189 30 | loss epoch 17: 0.219121066772 31 | loss epoch 18: 0.218482105628 32 | loss epoch 19: 0.218714774303 33 | loss epoch 20: 0.216247693676 34 | recall@1 epoch 20: 0.545408507765 35 | loss epoch 21: 0.212131172743 36 | loss epoch 22: 0.215825723889 37 | loss epoch 23: 0.210707748714 38 | loss epoch 24: 0.212570798138 39 | loss epoch 25: 0.2129502128 40 | recall@1 epoch 25: 0.559756920999 41 | loss epoch 26: 0.213876347179 42 | loss epoch 27: 0.212965220213 43 | loss epoch 28: 0.209254210734 44 | loss epoch 29: 0.209565333698 45 | loss epoch 30: 0.208511256005 46 | recall@1 epoch 30: 0.557562457799 47 | loss epoch 31: 0.207103219693 48 | loss epoch 32: 0.203476860147 49 | loss epoch 33: 0.208993546341 50 | loss epoch 34: 0.206787813293 51 | loss epoch 35: 0.204728358142 52 | recall@1 epoch 35: 0.563808237677 53 | loss epoch 36: 0.204224372364 54 | loss epoch 37: 0.202145223384 55 | loss epoch 38: 0.204203057224 56 | loss epoch 39: 0.202038740334 57 | loss epoch 40: 0.198881830534 58 | recall@1 epoch 40: 0.568028359217 59 | loss epoch 41: 0.201480427838 60 | loss epoch 42: 0.200715439151 61 | loss epoch 43: 0.20016783735 62 | loss epoch 44: 0.200699485191 63 | loss epoch 45: 0.199302793845 64 | recall@1 epoch 45: 0.576637407157 65 | loss epoch 46: 0.197816935246 66 | loss epoch 47: 0.194830061625 67 | loss epoch 48: 0.194410285872 68 | loss epoch 49: 0.195210286457 69 | loss epoch 50: 0.189996315409 70 | recall@1 epoch 50: 0.577312626604 71 | loss epoch 51: 0.197133129058 72 | loss epoch 52: 0.191882981554 73 | loss epoch 53: 0.192962919564 74 | loss epoch 54: 0.190905502633 75 | loss epoch 55: 0.192039086119 76 | recall@1 epoch 55: 0.569209993248 77 | loss epoch 56: 0.187494912225 78 | loss epoch 57: 0.193167505705 79 | loss epoch 58: 0.190151484116 80 | loss epoch 59: 0.190349356636 81 | loss epoch 60: 0.18771220776 82 | recall@1 epoch 60: 0.57056043214 83 | loss epoch 61: 0.188167388348 84 | loss epoch 62: 0.184030650103 85 | loss epoch 63: 0.185289305837 86 | loss epoch 64: 0.187186750381 87 | loss epoch 65: 0.184913028194 88 | recall@1 epoch 65: 0.583727211344 89 | loss epoch 66: 0.183452914914 90 | loss epoch 67: 0.185527184735 91 | loss epoch 68: 0.18199729142 92 | loss epoch 69: 0.184882262155 93 | loss epoch 70: 0.183131900495 94 | recall@1 epoch 70: 0.567184334909 95 | loss epoch 71: 0.18158705079 96 | loss epoch 72: 0.184409142836 97 | loss epoch 73: 0.178368283031 98 | loss epoch 74: 0.177207721964 99 | loss epoch 75: 0.180457273579 100 | recall@1 epoch 75: 0.564314652262 101 | -------------------------------------------------------------------------------- /data/log.txt.margin_stanfordonlineproducts: -------------------------------------------------------------------------------- 1 | loss epoch 0: 0.2165 2 | recall@1 epoch 0: 0.626188 3 | loss epoch 1: 0.1980 4 | recall@1 epoch 1: 0.649113 5 | loss epoch 2: 0.1930 6 | recall@1 epoch 2: 0.666204 7 | loss epoch 3: 0.1904 8 | recall@1 epoch 3: 0.675179 9 | loss epoch 4: 0.1869 10 | recall@1 epoch 4: 0.679542 11 | loss epoch 5: 0.1860 12 | recall@1 epoch 5: 0.690154 13 | loss epoch 6: 0.1830 14 | recall@1 epoch 6: 0.696732 15 | loss epoch 7: 0.1824 16 | recall@1 epoch 7: 0.700418 17 | loss epoch 8: 0.1813 18 | recall@1 epoch 8: 0.703526 19 | loss epoch 9: 0.1804 20 | recall@1 epoch 9: 0.704005 21 | loss epoch 10: 0.1786 22 | recall@1 epoch 10: 0.708286 23 | loss epoch 11: 0.1772 24 | loss epoch 12: 0.1766 25 | loss epoch 13: 0.1760 26 | loss epoch 14: 0.1749 27 | loss epoch 15: 0.1745 28 | recall@1 epoch 15: 0.719360 29 | loss epoch 16: 0.1731 30 | loss epoch 17: 0.1726 31 | loss epoch 18: 0.1724 32 | loss epoch 19: 0.1716 33 | loss epoch 20: 0.1709 34 | recall@1 epoch 20: 0.725757 35 | loss epoch 21: 0.1704 36 | loss epoch 22: 0.1698 37 | loss epoch 23: 0.1691 38 | loss epoch 24: 0.1694 39 | loss epoch 25: 0.1688 40 | recall@1 epoch 25: 0.733294 41 | loss epoch 26: 0.1677 42 | loss epoch 27: 0.1670 43 | loss epoch 28: 0.1662 44 | loss epoch 29: 0.1661 45 | loss epoch 30: 0.1665 46 | recall@1 epoch 30: 0.732599 47 | loss epoch 31: 0.1656 48 | loss epoch 32: 0.1652 49 | loss epoch 33: 0.1651 50 | loss epoch 34: 0.1650 51 | loss epoch 35: 0.1652 52 | recall@1 epoch 35: 0.738649 53 | loss epoch 36: 0.1636 54 | loss epoch 37: 0.1638 55 | loss epoch 38: 0.1630 56 | loss epoch 39: 0.1638 57 | loss epoch 40: 0.1627 58 | recall@1 epoch 40: 0.738583 59 | loss epoch 41: 0.1619 60 | loss epoch 42: 0.1621 61 | loss epoch 43: 0.1621 62 | loss epoch 44: 0.1623 63 | loss epoch 45: 0.1614 64 | recall@1 epoch 45: 0.743541 65 | loss epoch 46: 0.1606 66 | loss epoch 47: 0.1611 67 | loss epoch 48: 0.1605 68 | loss epoch 49: 0.1608 69 | loss epoch 50: 0.1613 70 | recall@1 epoch 50: 0.741872 71 | loss epoch 51: 0.1609 72 | loss epoch 52: 0.1593 73 | loss epoch 53: 0.1600 74 | loss epoch 54: 0.1596 75 | loss epoch 55: 0.1595 76 | recall@1 epoch 55: 0.744054 77 | loss epoch 56: 0.1597 78 | loss epoch 57: 0.1589 79 | loss epoch 58: 0.1594 80 | loss epoch 59: 0.1590 81 | loss epoch 60: 0.1581 82 | recall@1 epoch 60: 0.743508 83 | loss epoch 61: 0.1586 84 | loss epoch 62: 0.1586 85 | loss epoch 63: 0.1580 86 | loss epoch 64: 0.1578 87 | loss epoch 65: 0.1570 88 | recall@1 epoch 65: 0.747112 89 | loss epoch 66: 0.1573 90 | loss epoch 67: 0.1570 91 | loss epoch 68: 0.1571 92 | loss epoch 69: 0.1567 93 | loss epoch 70: 0.1576 94 | recall@1 epoch 70: 0.747740 95 | loss epoch 71: 0.1574 96 | loss epoch 72: 0.1564 97 | loss epoch 73: 0.1573 98 | loss epoch 74: 0.1561 99 | loss epoch 75: 0.1566 100 | recall@1 epoch 75: 0.746203 101 | loss epoch 76: 0.1564 102 | loss epoch 77: 0.1557 103 | loss epoch 78: 0.1568 104 | loss epoch 79: 0.1556 105 | loss epoch 80: 0.1560 106 | recall@1 epoch 80: 0.750632 107 | loss epoch 81: 0.1560 108 | loss epoch 82: 0.1552 109 | loss epoch 83: 0.1554 110 | loss epoch 84: 0.1550 111 | loss epoch 85: 0.1559 112 | recall@1 epoch 85: 0.750533 113 | loss epoch 86: 0.1551 114 | loss epoch 87: 0.1552 115 | loss epoch 88: 0.1549 116 | loss epoch 89: 0.1551 117 | loss epoch 90: 0.1544 118 | recall@1 epoch 90: 0.752649 119 | loss epoch 91: 0.1539 120 | loss epoch 92: 0.1544 121 | loss epoch 93: 0.1536 122 | loss epoch 94: 0.1550 123 | loss epoch 95: 0.1540 124 | recall@1 epoch 95: 0.752516 125 | loss epoch 96: 0.1534 126 | loss epoch 97: 0.1540 127 | loss epoch 98: 0.1535 128 | loss epoch 99: 0.1538 129 | recall@1 epoch 99: 0.749277 130 | -------------------------------------------------------------------------------- /data/log.txt.triplet: -------------------------------------------------------------------------------- 1 | loss epoch 0: 0.692732459501 2 | recall@1 epoch 0: 0.410533423363 3 | loss epoch 1: 0.566342098233 4 | recall@1 epoch 1: 0.424037812289 5 | loss epoch 2: 0.535647478117 6 | recall@1 epoch 2: 0.428764348413 7 | loss epoch 3: 0.496225292916 8 | recall@1 epoch 3: 0.434166103984 9 | loss epoch 4: 0.47738218016 10 | recall@1 epoch 4: 0.435010128292 11 | loss epoch 5: 0.452147184183 12 | recall@1 epoch 5: 0.439567859554 13 | loss epoch 6: 0.442743059734 14 | recall@1 epoch 6: 0.436529372046 15 | loss epoch 7: 0.410216862093 16 | recall@1 epoch 7: 0.440411883862 17 | loss epoch 8: 0.428484538813 18 | recall@1 epoch 8: 0.442099932478 19 | loss epoch 9: 0.404211970937 20 | recall@1 epoch 9: 0.439736664416 21 | loss epoch 10: 0.402375747652 22 | recall@1 epoch 10: 0.443281566509 23 | loss epoch 11: 0.389410010984 24 | recall@1 epoch 11: 0.442437542201 25 | loss epoch 12: 0.396608246085 26 | recall@1 epoch 12: 0.446151249156 27 | loss epoch 13: 0.381904043905 28 | recall@1 epoch 13: 0.446995273464 29 | loss epoch 14: 0.382338195877 30 | recall@1 epoch 14: 0.448345712357 31 | loss epoch 15: 0.376436312561 32 | recall@1 epoch 15: 0.444969615125 33 | loss epoch 16: 0.370241520521 34 | recall@1 epoch 16: 0.449527346388 35 | loss epoch 17: 0.362908354272 36 | recall@1 epoch 17: 0.445644834571 37 | loss epoch 18: 0.351351717082 38 | recall@1 epoch 18: 0.446488858879 39 | loss epoch 19: 0.364051500092 40 | recall@1 epoch 19: 0.446320054018 41 | loss epoch 20: 0.357174557674 42 | recall@1 epoch 20: 0.447839297772 43 | loss epoch 21: 0.363504252842 44 | recall@1 epoch 21: 0.446995273464 45 | loss epoch 22: 0.364821989413 46 | recall@1 epoch 22: 0.451553004727 47 | loss epoch 23: 0.339786096597 48 | recall@1 epoch 23: 0.453578663065 49 | loss epoch 24: 0.352994238715 50 | recall@1 epoch 24: 0.451553004727 51 | loss epoch 25: 0.333510527792 52 | recall@1 epoch 25: 0.451721809588 53 | loss epoch 26: 0.340962251081 54 | recall@1 epoch 26: 0.448852126941 55 | loss epoch 27: 0.35343186201 56 | recall@1 epoch 27: 0.44868332208 57 | loss epoch 28: 0.343101203118 58 | recall@1 epoch 28: 0.449696151249 59 | loss epoch 29: 0.330306629124 60 | recall@1 epoch 29: 0.449358541526 61 | loss epoch 30: 0.328538669192 62 | recall@1 epoch 30: 0.450540175557 63 | loss epoch 31: 0.336667052916 64 | recall@1 epoch 31: 0.448514517218 65 | loss epoch 32: 0.3378175131 66 | recall@1 epoch 32: 0.450371370695 67 | loss epoch 33: 0.341602200723 68 | recall@1 epoch 33: 0.453578663065 69 | loss epoch 34: 0.32409546751 70 | recall@1 epoch 34: 0.45087778528 71 | loss epoch 35: 0.325712782695 72 | recall@1 epoch 35: 0.450202565834 73 | loss epoch 36: 0.335531536328 74 | recall@1 epoch 36: 0.450033760972 75 | loss epoch 37: 0.319976006837 76 | recall@1 epoch 37: 0.448852126941 77 | loss epoch 38: 0.327589080709 78 | recall@1 epoch 38: 0.449864956111 79 | loss epoch 39: 0.327446717445 80 | recall@1 epoch 39: 0.448345712357 81 | loss epoch 40: 0.33329237463 82 | recall@1 epoch 40: 0.449527346388 83 | loss epoch 41: 0.32932183762 84 | recall@1 epoch 41: 0.45189061445 85 | loss epoch 42: 0.311598845798 86 | recall@1 epoch 42: 0.451384199865 87 | loss epoch 43: 0.332838869937 88 | recall@1 epoch 43: 0.450708980419 89 | loss epoch 44: 0.31082314431 90 | recall@1 epoch 44: 0.451046590142 91 | loss epoch 45: 0.326535865017 92 | recall@1 epoch 45: 0.448852126941 93 | loss epoch 46: 0.306733998915 94 | recall@1 epoch 46: 0.448345712357 95 | loss epoch 47: 0.315368418784 96 | recall@1 epoch 47: 0.448852126941 97 | loss epoch 48: 0.312686819907 98 | recall@1 epoch 48: 0.449189736664 99 | loss epoch 49: 0.317843409012 100 | recall@1 epoch 49: 0.448514517218 101 | loss epoch 50: 0.326787239507 102 | recall@1 epoch 50: 0.45087778528 103 | loss epoch 51: 0.31320472256 104 | recall@1 epoch 51: 0.44868332208 105 | loss epoch 52: 0.323794039209 106 | recall@1 epoch 52: 0.449527346388 107 | loss epoch 53: 0.311905582638 108 | recall@1 epoch 53: 0.45087778528 109 | loss epoch 54: 0.301235658805 110 | recall@1 epoch 54: 0.451721809588 111 | loss epoch 55: 0.320410002833 112 | recall@1 epoch 55: 0.450708980419 113 | loss epoch 56: 0.310281305695 114 | recall@1 epoch 56: 0.451721809588 115 | -------------------------------------------------------------------------------- /inception_v1_googlenet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | import torch.nn as nn 4 | 5 | class inception_v1_googlenet(nn.Sequential): 6 | output_size = 1024 7 | input_side = 227 8 | rescale = 255.0 9 | rgb_mean = [122.7717, 115.9465, 102.9801] 10 | rgb_std = [1, 1, 1] 11 | 12 | def __init__(self): 13 | super(inception_v1_googlenet, self).__init__(OrderedDict([ 14 | ('conv1', nn.Sequential(OrderedDict([ 15 | ('7x7_s2', nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3))), 16 | ('relu1', nn.ReLU(True)), 17 | ('pool1', nn.MaxPool2d((3, 3), (2, 2), ceil_mode = True)), 18 | ('lrn1', nn.CrossMapLRN2d(5, 0.0001, 0.75, 1)) 19 | ]))), 20 | 21 | ('conv2', nn.Sequential(OrderedDict([ 22 | ('3x3_reduce', nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0))), 23 | ('relu1', nn.ReLU(True)), 24 | ('3x3', nn.Conv2d(64, 192, (3, 3), (1, 1), (1, 1))), 25 | ('relu2', nn.ReLU(True)), 26 | ('lrn2', nn.CrossMapLRN2d(5, 0.0001, 0.75, 1)), 27 | ('pool2', nn.MaxPool2d((3, 3), (2, 2), ceil_mode = True)) 28 | ]))), 29 | 30 | ('inception_3a', InceptionModule(192, 64, 96, 128, 16, 32, 32)), 31 | ('inception_3b', InceptionModule(256, 128, 128, 192, 32, 96, 64)), 32 | 33 | ('pool3', nn.MaxPool2d((3, 3), (2, 2), ceil_mode = True)), 34 | 35 | ('inception_4a', InceptionModule(480, 192, 96, 208, 16, 48, 64)), 36 | ('inception_4b', InceptionModule(512, 160, 112, 224, 24, 64, 64)), 37 | ('inception_4c', InceptionModule(512, 128, 128, 256, 24, 64, 64)), 38 | ('inception_4d', InceptionModule(512, 112, 144, 288, 32, 64, 64)), 39 | ('inception_4e', InceptionModule(528, 256, 160, 320, 32, 128, 128)), 40 | 41 | ('pool4', nn.MaxPool2d((3, 3), (2, 2), ceil_mode = True)), 42 | 43 | ('inception_5a', InceptionModule(832, 256, 160, 320, 32, 128, 128)), 44 | ('inception_5b', InceptionModule(832, 384, 192, 384, 48, 128, 128)), 45 | 46 | ('pool5', nn.AvgPool2d((7, 7), (1, 1), ceil_mode = True)), 47 | 48 | #('drop5', nn.Dropout(0.4)) 49 | ])) 50 | 51 | class InceptionModule(nn.Module): 52 | def __init__(self, inplane, outplane_a1x1, outplane_b3x3_reduce, outplane_b3x3, outplane_c5x5_reduce, outplane_c5x5, outplane_pool_proj): 53 | super(InceptionModule, self).__init__() 54 | a = nn.Sequential(OrderedDict([ 55 | ('1x1', nn.Conv2d(inplane, outplane_a1x1, (1, 1), (1, 1), (0, 0))), 56 | ('1x1_relu', nn.ReLU(True)) 57 | ])) 58 | 59 | b = nn.Sequential(OrderedDict([ 60 | ('3x3_reduce', nn.Conv2d(inplane, outplane_b3x3_reduce, (1, 1), (1, 1), (0, 0))), 61 | ('3x3_relu1', nn.ReLU(True)), 62 | ('3x3', nn.Conv2d(outplane_b3x3_reduce, outplane_b3x3, (3, 3), (1, 1), (1, 1))), 63 | ('3x3_relu2', nn.ReLU(True)) 64 | ])) 65 | 66 | c = nn.Sequential(OrderedDict([ 67 | ('5x5_reduce', nn.Conv2d(inplane, outplane_c5x5_reduce, (1, 1), (1, 1), (0, 0))), 68 | ('5x5_relu1', nn.ReLU(True)), 69 | ('5x5', nn.Conv2d(outplane_c5x5_reduce, outplane_c5x5, (5, 5), (1, 1), (2, 2))), 70 | ('5x5_relu2', nn.ReLU(True)) 71 | ])) 72 | 73 | d = nn.Sequential(OrderedDict([ 74 | ('pool_pool', nn.MaxPool2d((3, 3), (1, 1), (1, 1))), 75 | ('pool_proj', nn.Conv2d(inplane, outplane_pool_proj, (1, 1), (1, 1), (0, 0))), 76 | ('pool_relu', nn.ReLU(True)) 77 | ])) 78 | 79 | for container in [a, b, c, d]: 80 | for name, module in container.named_children(): 81 | self.add_module(name, module) 82 | 83 | self.branches = [a, b, c, d] 84 | 85 | def forward(self, input): 86 | return torch.cat([branch(input) for branch in self.branches], 1) 87 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def topk_mask(input, dim, K = 10, **kwargs): 6 | index = input.topk(max(1, min(K, input.size(dim))), dim = dim, **kwargs)[1] 7 | return torch.zeros_like(input).scatter(dim, index, 1.0) 8 | 9 | def pdist(A, squared = False, eps = 1e-4): 10 | prod = torch.mm(A, A.t()) 11 | norm = prod.diag().unsqueeze(1).expand_as(prod) 12 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 13 | return res if squared else res.clamp(min = eps).sqrt() 14 | 15 | class Model(nn.Module): 16 | def __init__(self, base_model, num_classes, embedding_size = 128): 17 | super(Model, self).__init__() 18 | self.base_model = base_model 19 | self.num_classes = num_classes 20 | self.embedder = nn.Linear(base_model.output_size, embedding_size) 21 | 22 | def forward(self, input): 23 | return self.embedder(F.relu(self.base_model(input).view(len(input), -1))) 24 | 25 | criterion = None 26 | optimizer = torch.optim.SGD 27 | optimizer_params = dict(lr = 1e-4, momentum = 0.9, weight_decay = 2e-4) 28 | lr_scheduler_params = dict(step_size = float('inf'), gamma = 0.1) 29 | 30 | class Untrained(Model): 31 | def forward(self, input): 32 | return self.base_model(input).view(input.size(0), -1).detach() 33 | 34 | class LiftedStruct(Model): 35 | def criterion(self, embeddings, labels, margin = 1.0, eps = 1e-4): 36 | d = pdist(embeddings, squared = False, eps = eps) 37 | pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(d) 38 | neg_i = torch.mul((margin - d).exp(), 1 - pos).sum(1).expand_as(d) 39 | return torch.sum(F.relu(pos.triu(1) * ((neg_i + neg_i.t()).log() + d)).pow(2)) / (pos.sum() - len(d)) 40 | 41 | class Triplet(Model): 42 | def criterion(self, embeddings, labels, margin = 1.0): 43 | d = pdist(embeddings) 44 | pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(d) - torch.autograd.Variable(torch.eye(len(d))).type_as(d) 45 | T = d.unsqueeze(1).expand(*(len(d),) * 3) 46 | M = pos.unsqueeze(1).expand_as(T) * (1 - pos.unsqueeze(2).expand_as(T)) 47 | return (M * F.relu(T - T.transpose(1, 2) + margin)).sum() / M.sum() 48 | 49 | optimizer_params = dict(lr = 1e-4, momentum = 0.9, weight_decay = 5e-4) 50 | lr_scheduler_params = dict(step_size = 30, gamma = 0.5) 51 | 52 | class TripletRatio(Model): 53 | def criterion(self, embeddings, labels, margin = 0.1, eps = 1e-4): 54 | d = pdist(embeddings, squared = False, eps = eps) 55 | pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(d) 56 | T = d.unsqueeze(1).expand(*(len(d),) * 3) 57 | M = pos.unsqueeze(1).expand_as(T) * (1 - pos.unsqueeze(2).expand_as(T)) 58 | return (M * T.div(T.transpose(1, 2) + margin)).sum() / M.sum() 59 | 60 | class Pddm(Model): 61 | def __init__(self, base_model, num_classes, d = 1024): 62 | nn.Module.__init__(self) 63 | self.base_model = base_model 64 | #self.embedder = nn.Linear(base_model.output_size, d) 65 | self.embedder = lambda x: x #nn.Linear(base_model.output_size, d) 66 | self.wu = nn.Linear(d, d) 67 | self.wv = nn.Linear(d, d) 68 | self.wc = nn.Linear(2 * d, d) 69 | self.ws = nn.Linear(d, 1) 70 | 71 | def forward(self, input): 72 | return F.normalize(Model.forward(self, input)) 73 | 74 | def criterion(self, embeddings, labels, Alpha = 0.5, Beta = 1.0, Lambda = 0.5): 75 | #embeddings = embeddings * topk_mask(embeddings, dim = 1, K = 512) 76 | d = pdist(embeddings, squared = True) 77 | pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(embeddings) - torch.autograd.Variable(torch.eye(len(d))).type_as(embeddings) 78 | 79 | f1, f2 = [embeddings.unsqueeze(dim).expand(len(embeddings), *embeddings.size()) for dim in [0, 1]] 80 | u = (f1 - f2).abs() 81 | v = (f1 + f2) / 2 82 | u_ = F.normalize(F.relu(F.dropout(self.wu(u.view(-1, u.size(-1))), training = self.training))) 83 | v_ = F.normalize(F.relu(F.dropout(self.wv(v.view(-1, v.size(-1))), training = self.training))) 84 | s = self.ws(F.relu(F.dropout(self.wc(torch.cat((u_, v_), -1)), training = self.training))).view_as(d) 85 | 86 | sneg = s * (1 - pos) 87 | i, j = min([(s[i, j], (i, j)) for i, j in pos.nonzero()])[1] 88 | k, l = sneg.max(1, keepdim = True)[1][[i, j], ...].squeeze(1) 89 | 90 | E_m = F.relu(Alpha - s[i, j] + s[i, k]) + F.relu(Alpha - s[i, j] + s[j, l]) 91 | E_e = F.relu(Beta + d[i, j] - d[i, k]) + F.relu(Beta + d[i, j] - d[j, l]) 92 | 93 | return E_m + Lambda * E_e 94 | 95 | optimizer_params = dict(lr = 1e-4, momentum = 0.9, weight_decay = 5e-4) 96 | lr_scheduler_params = dict(step_size = 10, gamma = 0.1) 97 | 98 | class Margin(Model): 99 | def forward(self, input): 100 | return F.normalize(Model.forward(self, input)) 101 | 102 | def criterion(self, embeddings, labels, alpha = 0.2, beta = 1.2, distance_threshold = 0.5, inf = 1e6, eps = 1e-6, distance_weighted_sampling = False): 103 | d = pdist(embeddings) 104 | pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(d) - torch.eye(len(d)).type_as(d) 105 | num_neg = int(pos.sum() / len(pos)) 106 | if distance_weighted_sampling: 107 | neg = torch.zeros_like(pos).scatter_(1, torch.multinomial((d.clamp(min = distance_threshold).pow(embeddings.size(-1) - 2) * (1 - d.clamp(min = distance_threshold).pow(2) / 4).pow(0.5 * (embeddings.size(-1) - 3))).reciprocal().masked_fill_(pos + torch.eye(len(d)).type_as(d) > 0, eps), replacement = False, num_samples = num_neg), 1) 108 | else: 109 | neg = topk_mask(d + inf * ((pos > 0) + (d < distance_threshold)).type_as(d), dim = 1, largest = False, K = num_neg) 110 | L = F.relu(alpha + (pos * 2 - 1) * (d - beta)) 111 | M = ((pos + neg > 0) * (L > 0)).float() 112 | return (M * L).sum() / M.sum() 113 | 114 | optimizer = torch.optim.Adam 115 | optimizer_params = dict(lr = 1e-3, weight_decay = 1e-4, base_model_lr_mult = 1e-2) 116 | #optimizer_params = dict(lr = 1e-3, momentum = 0.9, weight_decay = 5e-4, base_model_lr_mult = 1) 117 | #lr_scheduler_params = dict(step_size = 10, gamma = 0.5) 118 | -------------------------------------------------------------------------------- /resnet18.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | class resnet18(nn.Sequential): 6 | output_size = 512 7 | input_side = 224 8 | rescale = 1 9 | rgb_mean = [0.485, 0.456, 0.406] 10 | rgb_std = [0.229, 0.224, 0.225] 11 | 12 | def __init__(self, dilation = False): 13 | super(resnet18, self).__init__() 14 | pretrained = torchvision.models.resnet18(pretrained = True) 15 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, pretrained.modules()): 16 | module.eval() 17 | module.train = lambda _: None 18 | 19 | if dilation: 20 | pretrained.layer4[0].conv1.dilation = (2, 2) 21 | pretrained.layer4[0].conv1.padding = (2, 2) 22 | pretrained.layer4[0].conv1.stride = (1, 1) 23 | pretrained.layer4[0].downsample[0].stride = (1, 1) 24 | 25 | for module_name in ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']: 26 | self.add_module(module_name, getattr(pretrained, module_name)) 27 | -------------------------------------------------------------------------------- /resnet50.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | class resnet50(nn.Sequential): 6 | output_size = 2048 7 | input_side = 224 8 | rescale = 1 9 | rgb_mean = [0.485, 0.456, 0.406] 10 | rgb_std = [0.229, 0.224, 0.225] 11 | 12 | def __init__(self, dilation = False): 13 | super(resnet50, self).__init__() 14 | pretrained = torchvision.models.resnet50(pretrained = True) 15 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, pretrained.modules()): 16 | module.eval() 17 | module.train = lambda _: None 18 | 19 | if dilation: 20 | pretrained.layer4[0].conv1.dilation = (2, 2) 21 | pretrained.layer4[0].conv1.padding = (2, 2) 22 | pretrained.layer4[0].conv1.stride = (1, 1) 23 | pretrained.layer4[0].downsample[0].stride = (1, 1) 24 | 25 | for module_name in ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']: 26 | self.add_module(module_name, getattr(pretrained, module_name)) 27 | -------------------------------------------------------------------------------- /sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import itertools 4 | 5 | def index_dataset(dataset): 6 | return {c : [example_idx for example_idx, (image_file_name, class_label_ind) in enumerate(dataset.imgs) if class_label_ind == c] for c in set(dict(dataset.imgs).values())} 7 | 8 | def sample_from_class(images_by_class, class_label_ind): 9 | return images_by_class[class_label_ind][random.randrange(len(images_by_class[class_label_ind]))] 10 | 11 | def simple(batch_size, dataset, prob_other = 0.5): 12 | '''lazy sampling, not like in lifted_struct. they add to the pool all postiive combinations, then compute the average number of positive pairs per image, then sample for every image the same number of negative pairs''' 13 | images_by_class = index_dataset(dataset) 14 | for batch_idx in xrange(int(math.ceil(len(dataset) * 1.0 / batch_size))): 15 | example_indices = [] 16 | for i in range(0, batch_size, 2): 17 | perm = random.sample(images_by_class.keys(), 2) 18 | example_indices += [sample_from_class(images_by_class, perm[0]), sample_from_class(images_by_class, perm[0 if i == 0 or random.random() > prob_other else 1])] 19 | yield example_indices[:batch_size] 20 | 21 | def triplet(batch_size, dataset): 22 | images_by_class = index_dataset(dataset) 23 | for batch_idx in xrange(int(math.ceil(len(dataset) * 1.0 / batch_size))): 24 | example_indices = [] 25 | for i in range(0, batch_size, 3): 26 | perm = random.sample(images_by_class.keys(), 2) 27 | example_indices += [sample_from_class(images_by_class, perm[0]), sample_from_class(images_by_class, perm[0]), sample_from_class(images_by_class, perm[1])] 28 | yield example_indices[:batch_size] 29 | 30 | def npairs(batch_size, dataset, K = 4): 31 | images_by_class = index_dataset(dataset) 32 | for batch_idx in xrange(int(math.ceil(len(dataset) * 1.0 / batch_size))): 33 | example_indices = [sample_from_class(images_by_class, class_label_ind) for k in range(int(math.ceil(batch_size * 1.0 / K))) for class_label_ind in [random.choice(images_by_class.keys())] for i in range(K)] 34 | yield example_indices[:batch_size] 35 | -------------------------------------------------------------------------------- /stanford_online_products.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | import torchvision 5 | from torchvision.datasets import ImageFolder 6 | from torchvision.datasets import CIFAR10 7 | from torchvision.datasets.utils import download_url 8 | from torchvision.datasets.folder import default_loader 9 | 10 | class StanfordOnlineProducts(ImageFolder, CIFAR10): 11 | base_folder = 'Stanford_Online_Products' 12 | url = 'ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip' 13 | filename = 'Stanford_Online_Products.zip' 14 | zip_md5 = '7f73d41a2f44250d4779881525aea32e' 15 | 16 | train_list = [ 17 | ['bicycle_final/111265328556_0.JPG', '77420a4db9dd9284378d7287a0729edb'], 18 | ['chair_final/111182689872_0.JPG', 'ce78d10ed68560f4ea5fa1bec90206ba'] 19 | ] 20 | test_list = [ 21 | ['table_final/111194782300_0.JPG', '8203e079b5c134161bbfa7ee2a43a0a1'], 22 | ['toaster_final/111157129195_0.JPG', 'd6c24ee8c05d986cafffa6af82ae224e'] 23 | ] 24 | num_training_classes = 11318 25 | 26 | def __init__(self, root, train=False, transform=None, target_transform=None, download=False, **kwargs): 27 | self.root = root 28 | self.transform = transform 29 | self.target_transform = target_transform 30 | self.loader = default_loader 31 | 32 | if download: 33 | self.download() 34 | 35 | if not self._check_integrity(): 36 | raise RuntimeError('Dataset not found or corrupted.' + 37 | ' You can use download=True to download it') 38 | 39 | self.imgs = [(os.path.join(root, self.base_folder, path), int(class_id) - 1) for i, (image_id, class_id, super_class_id, path) in enumerate(map(str.split, open(os.path.join(root, self.base_folder, 'Ebay_{}.txt'.format('train' if train else 'test'))))) if i > 1] 40 | 41 | def download(self): 42 | import zipfile 43 | 44 | if self._check_integrity(): 45 | print('Files already downloaded and verified') 46 | return 47 | 48 | root = self.root 49 | download_url(self.url, root, self.filename, self.zip_md5) 50 | 51 | # extract file 52 | cwd = os.getcwd() 53 | os.chdir(root) 54 | with zipfile.ZipFile(self.filename, "r") as zip: 55 | zip.extractall() 56 | os.chdir(cwd) 57 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import argparse 6 | import itertools 7 | import hickle 8 | import torch 9 | import torchvision.transforms as T 10 | 11 | import cub2011 12 | import cars196 13 | import stanford_online_products 14 | import inception_v1_googlenet 15 | import resnet18 16 | import resnet50 17 | import model 18 | import sampler 19 | 20 | parser = argparse.ArgumentParser() 21 | LookupChoices = type('', (argparse.Action, ), dict(__call__ = lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v]))) 22 | parser.add_argument('--dataset', choices = dict(cub2011 = cub2011.CUB2011MetricLearning, cars196 = cars196.Cars196MetricLearning, stanford_online_products = stanford_online_products.StanfordOnlineProducts), default = cub2011.CUB2011MetricLearning, action = LookupChoices) 23 | parser.add_argument('--base', choices = dict(inception_v1_googlenet = inception_v1_googlenet.inception_v1_googlenet, resnet18 = resnet18.resnet18, resnet50 = resnet50.resnet50), default = resnet50.resnet50, action = LookupChoices) 24 | parser.add_argument('--model', choices = dict(liftedstruct = model.LiftedStruct, triplet = model.Triplet, tripletratio = model.TripletRatio, pddm = model.Pddm, untrained = model.Untrained, margin = model.Margin), default = model.Margin, action = LookupChoices) 25 | parser.add_argument('--sampler', choices = dict(simple = sampler.simple, triplet = sampler.triplet, npairs = sampler.npairs), default = sampler.npairs, action = LookupChoices) 26 | parser.add_argument('--data', default = 'data') 27 | parser.add_argument('--log', default = 'data/log.txt') 28 | parser.add_argument('--seed', default = 1, type = int) 29 | parser.add_argument('--threads', default = 16, type = int) 30 | parser.add_argument('--epochs', default = 100, type = int) 31 | parser.add_argument('--batch', default = 128, type = int) 32 | opts = parser.parse_args() 33 | 34 | for set_random_seed in [random.seed, torch.manual_seed, torch.cuda.manual_seed_all]: 35 | set_random_seed(opts.seed) 36 | 37 | def recall(embeddings, labels, K = 1): 38 | prod = torch.mm(embeddings, embeddings.t()) 39 | norm = prod.diag().unsqueeze(1).expand_as(prod) 40 | D = norm + norm.t() - 2 * prod 41 | knn_inds = D.topk(1 + K, dim = 1, largest = False)[1][:, 1:] 42 | return (labels.unsqueeze(-1).expand_as(knn_inds) == labels[knn_inds.flatten()].view_as(knn_inds)).max(1)[0].float().mean() 43 | 44 | base_model = opts.base() 45 | base_model_weights_path = os.path.join(opts.data, opts.base.__name__ + '.h5') 46 | if os.path.exists(base_model_weights_path): 47 | base_model.load_state_dict({k : torch.from_numpy(v) for k, v in hickle.load(base_model_weights_path).items()}) 48 | 49 | normalize = T.Compose([ 50 | T.ToTensor(), 51 | T.Lambda(lambda x: x * base_model.rescale), 52 | T.Normalize(mean = base_model.rgb_mean, std = base_model.rgb_std), 53 | T.Lambda(lambda x: x[[2, 1, 0], ...]) 54 | ]) 55 | 56 | dataset_train = opts.dataset(opts.data, train = True, transform = transforms.Compose([ 57 | T.RandomSizedCrop(base_model.input_side), 58 | T.RandomHorizontalFlip(), 59 | normalize 60 | ]), download = True) 61 | dataset_eval = opts.dataset(opts.data, train = False, transform = transforms.Compose([ 62 | T.Scale(256), 63 | T.CenterCrop(base_model.input_side), 64 | normalize 65 | ]), download = True) 66 | 67 | adapt_sampler = lambda batch, dataset, sampler, **kwargs: type('', (torch.utils.data.Sampler, ), dict(__len__ = dataset.__len__, __iter__ = lambda _: itertools.chain.from_iterable(sampler(batch, dataset, **kwargs))))() 68 | loader_train = torch.utils.data.DataLoader(dataset_train, sampler = adapt_sampler(opts.batch, dataset_train, opts.sampler), num_workers = opts.threads, batch_size = opts.batch, drop_last = True, pin_memory = True) 69 | loader_eval = torch.utils.data.DataLoader(dataset_eval, shuffle = False, num_workers = opts.threads, batch_size = opts.batch, pin_memory = True) 70 | 71 | model = opts.model(base_model, dataset_train.num_training_classes).cuda() 72 | model_weights, model_biases, base_model_weights, base_model_biases = [[p for k, p in model.named_parameters() if p.requires_grad and ('bias' in k) == is_bias and ('base' in k) == is_base] for is_base in [False, True] for is_bias in [False, True]] 73 | 74 | base_model_lr_mult = model.optimizer_params.pop('base_model_lr_mult', 1.0) 75 | optimizer = model.optimizer([dict(params = base_model_weights, lr = base_model_lr_mult * model.optimizer_params['lr']), dict(params = base_model_biases, lr = base_model_lr_mult * model.optimizer_params['lr'], weight_decay = 0.0), dict(params = model_biases, weight_decay = 0.0)], **model.optimizer_params) 76 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **model.lr_scheduler_params) 77 | 78 | log = open(opts.log, 'w', 0) 79 | for epoch in range(opts.epochs): 80 | scheduler.step() 81 | model.train() 82 | loss_all, norm_all = [], [] 83 | for batch_idx, batch in enumerate(loader_train if model.criterion is not None else []): 84 | tic = time.time() 85 | images, labels = [tensor.cuda() for tensor in batch] 86 | loss = model.criterion(model(images), labels) 87 | loss_all.append(float(loss)) 88 | optimizer.zero_grad() 89 | loss.backward() 90 | optimizer.step() 91 | print('train {:>3}.{:05} loss {:.04f} hz {:.02f}'.format(epoch, batch_idx, loss_all[-1], len(images) / (time.time() - tic))) 92 | log.write('loss epoch {}: {:.04f}\n'.format(epoch, torch.Tensor(loss_all or [0.0]).mean())) 93 | 94 | if epoch < 10 or epoch % 5 == 0 or epoch == opts.epochs - 1: 95 | model.eval() 96 | embeddings_all, labels_all = [], [] 97 | for batch_idx, batch in enumerate(loader_eval): 98 | tic = time.time() 99 | images, labels = [tensor.cuda() for tensor in batch] 100 | with torch.no_grad(): 101 | output = model(images) 102 | embeddings_all.append(output.data.cpu()) 103 | labels_all.append(labels.data.cpu()) 104 | print('eval {:>3}.{:05} hz {:.02f}'.format(epoch, batch_idx, len(images) / (time.time() - tic))) 105 | log.write('recall@1 epoch {}: {:.06f}\n'.format(epoch, recall(torch.cat(embeddings_all), torch.cat(labels_all)))) 106 | --------------------------------------------------------------------------------