├── .gitignore ├── Icon ├── LICENSE ├── README.md ├── bag_of_tricks_config.yaml ├── data_v1 ├── Icon ├── __init__.py ├── dukemtmc.py ├── gta.py ├── market1501.py └── sampler.py ├── data_v2 ├── Icon ├── __init__.py ├── datamanager.py ├── datasets │ ├── Icon │ ├── __init__.py │ ├── dataset.py │ ├── image │ │ ├── Icon │ │ ├── __init__.py │ │ ├── aicity24.py │ │ ├── cuhk01.py │ │ ├── cuhk02.py │ │ ├── cuhk03.py │ │ ├── cuhk03_detected.py │ │ ├── cuhk03_labeled.py │ │ ├── cuhk03_splited.py │ │ ├── dukemtmcreid.py │ │ ├── grid.py │ │ ├── ilids.py │ │ ├── market1501.py │ │ ├── mot17.py │ │ ├── msmt17.py │ │ ├── prid.py │ │ ├── sensereid.py │ │ └── viper.py │ ├── utils.py │ └── video │ │ ├── Icon │ │ ├── __init__.py │ │ ├── dukemtmcvidreid.py │ │ ├── ilidsvid.py │ │ ├── mars.py │ │ └── prid2011.py ├── sampler.py ├── transforms.py └── utils.py ├── engine_v1.py ├── engine_v2.py ├── engine_v3.py ├── lmbn_config.yaml ├── loss ├── Icon ├── __init__.py ├── center_loss.py ├── focal_loss.py ├── grouploss.py ├── multi_similarity_loss.py ├── osm_caa_loss.py ├── ranked_loss.py └── triplet.py ├── main.py ├── model ├── Icon ├── __init__.py ├── attention.py ├── bnneck.py ├── c.py ├── g_c.py ├── g_p.py ├── lmbn_n.py ├── lmbn_n_drop_no_bnneck.py ├── lmbn_n_no_drop.py ├── lmbn_r.py ├── lmbn_r_no_drop.py ├── mcn.py ├── mgn.py ├── osnet.py ├── p.py ├── pcb.py ├── pyramid.py ├── resnet50.py ├── resnet50_ibn.py └── se_resnet.py ├── optim ├── Icon ├── __init__.py ├── n_adam.py ├── nadam.py ├── warmup_cosine_scheduler.py └── warmup_scheduler.py ├── option.py ├── requirements.txt └── utils ├── Icon ├── LightMB.png ├── __init__.py ├── functions.py ├── model_complexity.py ├── random_erasing.py ├── rank_cylib ├── Makefile ├── __init__.py ├── rank_cy.pyx ├── setup.py └── test_cython.py ├── re_ranking.py ├── utility.py ├── visualize_actmap.py └── visualize_rank.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # jixunbo 10 | "Icon\r" 11 | experiment 12 | .DS_Store 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | # pytype static type analyzer 141 | .pytype/ 142 | 143 | # Cython debug symbols 144 | cython_debug/ 145 | 146 | # static files generated from Django application using `collectstatic` 147 | media 148 | static -------------------------------------------------------------------------------- /Icon : -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/Icon -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 jixunbo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /bag_of_tricks_config.yaml: -------------------------------------------------------------------------------- 1 | T: 3 2 | act: relu 3 | amsgrad: false 4 | batchid: 16 5 | batchimage: 4 6 | batchtest: 32 7 | beta1: 0.9 8 | beta2: 0.999 9 | bnneck: true 10 | config: '' 11 | cpu: false 12 | cutout: false 13 | dampening: 0 14 | data_test: DukeMTMC 15 | data_train: DukeMTMC 16 | datadir: /content/ReIDataset/ 17 | decay_type: step_40_70 18 | drop_block: false 19 | epochs: 120 20 | epsilon: 1.0e-08 21 | feat_inference: after 22 | feats: 256 23 | gamma: 0.1 24 | h_ratio: 0.33 25 | height: 384 26 | if_labelsmooth: true 27 | loss: 1*CrossEntropy+1*Triplet 28 | lr: 0.00035 29 | lr_decay: 60 30 | margin: 0.3 31 | model: ResNet50 32 | momentum: 0.9 33 | nGPU: 1 34 | nThread: 4 35 | nesterov: true 36 | num_anchors: 1 37 | num_classes: 702 38 | optimizer: ADAM 39 | parts: 2 40 | pcb_different_lr: true 41 | pool: avg 42 | probability: 0.5 43 | random_erasing: true 44 | reset: false 45 | sampler: true 46 | test_every: 10 47 | w_ratio: 1.0 48 | warmup: linear 49 | weight_decay: 0.0005 50 | width: 128 51 | -------------------------------------------------------------------------------- /data_v1/Icon : -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/data_v1/Icon -------------------------------------------------------------------------------- /data_v1/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from torchvision import transforms 3 | from utils.random_erasing import RandomErasing, Cutout 4 | from .sampler import RandomSampler, RandomIdentitySampler 5 | from torch.utils.data import dataloader 6 | 7 | 8 | class Data: 9 | def __init__(self, args): 10 | 11 | # train_list = [ 12 | # transforms.Resize((args.height, args.width), interpolation=3), 13 | # transforms.RandomHorizontalFlip(), 14 | # transforms.ToTensor(), 15 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ 16 | # 0.229, 0.224, 0.225]) 17 | # ] 18 | 19 | train_list = [ 20 | transforms.Resize((args.height, args.width), interpolation=3), 21 | transforms.Pad(10), 22 | transforms.RandomCrop((args.height, args.width)), 23 | transforms.RandomHorizontalFlip(), 24 | transforms.ToTensor(), 25 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 26 | ] 27 | if args.random_erasing: 28 | train_list.append(RandomErasing( 29 | probability=args.probability, mean=[0.485, 0.456, 0.406])) 30 | print('Using random_erasing augmentation.') 31 | if args.cutout: 32 | train_list.append(Cutout(mean=[0.485, 0.456, 0.406])) 33 | print('Using cutout augmentation.') 34 | 35 | train_transform = transforms.Compose(train_list) 36 | 37 | test_transform = transforms.Compose([ 38 | transforms.Resize((args.height, args.width), interpolation=3), 39 | transforms.ToTensor(), 40 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ 41 | 0.229, 0.224, 0.225]) 42 | ]) 43 | if not args.test_only and args.model == 'MGN': 44 | module_train = import_module('data.' 
+ args.data_train.lower()) 45 | self.trainset = getattr(module_train, args.data_train)( 46 | args, train_transform, 'train') 47 | self.train_loader = dataloader.DataLoader(self.trainset, 48 | sampler=RandomIdentitySampler( 49 | self.trainset, args.batchid * args.batchimage, args.batchimage), 50 | # shuffle=True, 51 | batch_size=args.batchid * args.batchimage, 52 | num_workers=args.nThread) 53 | # elif not args.test_only and args.model in ['ResNet50','PCB'] and args.loss.split('*')[1]=='CrossEntropy': 54 | # module_train = import_module('data.' + args.data_train.lower()) 55 | # self.trainset = getattr(module_train, args.data_train)( 56 | # args, train_transform, 'train') 57 | # self.train_loader = dataloader.DataLoader(self.trainset, 58 | # shuffle=True, 59 | # batch_size=args.batchid * args.batchimage, 60 | # num_workers=args.nThread) 61 | elif not args.test_only and args.model in ['ResNet50', 'PCB', 'PCB_v', 'PCB_conv', 'BB_2_db','BB', 'MGDB','MGDB_v2','MGDB_v3','BB_2_v3','BB_2', 'PCB_conv_modi_2', 'BB_2_conv','BB_2_cat', 'BB_4_cat','PCB_conv_modi', 'Pyramid','PLR'] and bool(args.sampler): 62 | 63 | module_train = import_module('data.' + args.data_train.lower()) 64 | self.trainset = getattr(module_train, args.data_train)( 65 | args, train_transform, 'train') 66 | # self.train_loader = dataloader.DataLoader(self.trainset, 67 | # sampler=RandomSampler( 68 | # self.trainset, args.batchid, batch_image=args.batchimage), 69 | # # shuffle=True, 70 | # batch_size=args.batchid * args.batchimage, 71 | # num_workers=args.nThread, 72 | # drop_last=True) 73 | self.train_loader = dataloader.DataLoader(self.trainset, 74 | sampler=RandomIdentitySampler( 75 | self.trainset, args.batchid * args.batchimage, args.batchimage), 76 | # shuffle=True, 77 | batch_size=args.batchid * args.batchimage, 78 | num_workers=args.nThread) 79 | 80 | elif not args.test_only and args.model not in ['MGN', 'ResNet50', 'PCB','BB_2_db', 'PCB_v', 'PCB_conv','MGDB', 'PCB_conv_modi_2', 'PCB_conv_modi', 'BB', 'BB_2','BB_2_cat','BB_4_cat','PLR']: 81 | raise Exception( 82 | 'DataLoader for {} not designed'.format(args.model)) 83 | else: 84 | self.train_loader = None 85 | 86 | if args.data_test in ['Market1501', 'DukeMTMC', 'GTA']: 87 | module = import_module('data.' 
+ args.data_train.lower()) 88 | self.galleryset = getattr(module, args.data_test)( 89 | args, test_transform, 'test') 90 | self.queryset = getattr(module, args.data_test)( 91 | args, test_transform, 'query') 92 | 93 | else: 94 | raise Exception() 95 | # print(len(self.trainset)) 96 | 97 | self.test_loader = dataloader.DataLoader( 98 | self.galleryset, batch_size=args.batchtest, num_workers=args.nThread) 99 | self.query_loader = dataloader.DataLoader( 100 | self.queryset, batch_size=args.batchtest, num_workers=args.nThread) 101 | -------------------------------------------------------------------------------- /data_v1/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import dataset 2 | from torchvision.datasets.folder import default_loader 3 | import os 4 | import re 5 | 6 | 7 | class DukeMTMC(dataset.Dataset): 8 | def __init__(self, args, transform, dtype): 9 | 10 | self.transform = transform 11 | self.loader = default_loader 12 | 13 | data_path = args.datadir 14 | if dtype == 'train': 15 | data_path += '/bounding_box_train' 16 | elif dtype == 'test': 17 | data_path += '/bounding_box_test' 18 | else: 19 | data_path += '/query' 20 | 21 | self.imgs = [path for path in self.list_pictures(data_path) if self.id(path) != -1] 22 | 23 | self._id2label = {_id: idx for idx, _id in enumerate(self.unique_ids)} 24 | 25 | def __getitem__(self, index): 26 | path = self.imgs[index] 27 | target = self._id2label[self.id(path)] 28 | 29 | img = self.loader(path) 30 | if self.transform is not None: 31 | img = self.transform(img) 32 | 33 | return img, target 34 | 35 | def __len__(self): 36 | return len(self.imgs) 37 | 38 | @staticmethod 39 | def id(file_path): 40 | """ 41 | :param file_path: unix style file path 42 | :return: person id 43 | """ 44 | return int(file_path.split('/')[-1].split('_')[0]) 45 | 46 | @staticmethod 47 | def camera(file_path): 48 | """ 49 | :param file_path: unix style file path 50 | :return: camera id 51 | """ 52 | return int(file_path.split('/')[-1].split('_')[1][1]) 53 | 54 | @property 55 | def ids(self): 56 | """ 57 | :return: person id list corresponding to dataset image paths 58 | """ 59 | return [self.id(path) for path in self.imgs] 60 | 61 | @property 62 | def unique_ids(self): 63 | """ 64 | :return: unique person ids in ascending order 65 | """ 66 | return sorted(set(self.ids)) 67 | 68 | @property 69 | def cameras(self): 70 | """ 71 | :return: camera id list corresponding to dataset image paths 72 | """ 73 | return [self.camera(path) for path in self.imgs] 74 | 75 | @staticmethod 76 | def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm|npy'): 77 | assert os.path.isdir( 78 | directory), 'dataset is not exists!{}'.format(directory) 79 | 80 | return sorted([os.path.join(root, f) 81 | for root, _, files in os.walk(directory) for f in files 82 | if re.match(r'([\w]+\.(?:' + ext + '))', f)]) 83 | 84 | -------------------------------------------------------------------------------- /data_v1/gta.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import dataset 2 | from torchvision.datasets.folder import default_loader 3 | import os 4 | import re 5 | 6 | 7 | class GTA(dataset.Dataset): 8 | def __init__(self, args, transform, dtype): 9 | 10 | self.transform = transform 11 | self.loader = default_loader 12 | 13 | data_path = args.datadir 14 | if dtype == 'train': 15 | data_path += '/train' 16 | elif dtype == 'test': 17 | data_path += '/gallery' 18 | else: 19 
| data_path += '/query' 20 | 21 | self.imgs = [path for path in self.list_pictures(data_path) if self.id(path) != -1] 22 | 23 | self._id2label = {_id: idx for idx, _id in enumerate(self.unique_ids)} 24 | print('{} classes.'.format(len(self.unique_ids))) 25 | 26 | def __getitem__(self, index): 27 | path = self.imgs[index] 28 | target = self._id2label[self.id(path)] 29 | 30 | img = self.loader(path) 31 | if self.transform is not None: 32 | img = self.transform(img) 33 | 34 | return img, target 35 | 36 | def __len__(self): 37 | return len(self.imgs) 38 | 39 | @staticmethod 40 | def id(file_path): 41 | """ 42 | :param file_path: unix style file path 43 | :return: person id 44 | """ 45 | return int(file_path.split('/')[-1].split('_')[4]) 46 | 47 | @staticmethod 48 | def camera(file_path): 49 | """ 50 | :param file_path: unix style file path 51 | :return: camera id 52 | """ 53 | return int(file_path.split('/')[-1].split('_')[5][-1]) 54 | 55 | @property 56 | def ids(self): 57 | """ 58 | :return: person id list corresponding to dataset image paths 59 | """ 60 | return [self.id(path) for path in self.imgs] 61 | 62 | @property 63 | def unique_ids(self): 64 | """ 65 | :return: unique person ids in ascending order 66 | """ 67 | return sorted(set(self.ids)) 68 | 69 | 70 | @property 71 | def cameras(self): 72 | """ 73 | :return: camera id list corresponding to dataset image paths 74 | """ 75 | return [self.camera(path) for path in self.imgs] 76 | 77 | @staticmethod 78 | def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm|npy'): 79 | assert os.path.isdir( 80 | directory), 'dataset is not exists!{}'.format(directory) 81 | imgs=[] 82 | 83 | for d in os.listdir(directory): 84 | if os.path.isdir(os.path.join(directory,d)): 85 | for file in os.listdir(os.path.join(directory,d)): 86 | if file.split('.')[-1] == 'jpeg': 87 | imgs.append(os.path.join(directory,d,file)) 88 | return imgs 89 | 90 | # return sorted([os.path.join(root, f) 91 | # for root, _, files in os.walk(directory) for f in files 92 | # if re.match(r'([\w]+\.(?:' + ext + '))', f)]) 93 | if __name__ == '__main__': 94 | dataset = GTA 95 | -------------------------------------------------------------------------------- /data_v1/market1501.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import dataset 2 | from torchvision.datasets.folder import default_loader 3 | import os 4 | import re 5 | 6 | 7 | class Market1501(dataset.Dataset): 8 | def __init__(self, args, transform, dtype): 9 | 10 | self.transform = transform 11 | self.loader = default_loader 12 | 13 | data_path = args.datadir 14 | if dtype == 'train': 15 | data_path += '/bounding_box_train' 16 | elif dtype == 'test': 17 | data_path += '/bounding_box_test' 18 | else: 19 | data_path += '/query' 20 | 21 | self.imgs = [path for path in self.list_pictures(data_path) if self.id(path) != -1] 22 | 23 | self._id2label = {_id: idx for idx, _id in enumerate(self.unique_ids)} 24 | 25 | def __getitem__(self, index): 26 | path = self.imgs[index] 27 | target = self._id2label[self.id(path)] 28 | 29 | img = self.loader(path) 30 | if self.transform is not None: 31 | img = self.transform(img) 32 | 33 | return img, target 34 | 35 | def __len__(self): 36 | return len(self.imgs) 37 | 38 | @staticmethod 39 | def id(file_path): 40 | """ 41 | :param file_path: unix style file path 42 | :return: person id 43 | """ 44 | return int(file_path.split('/')[-1].split('_')[0]) 45 | 46 | @staticmethod 47 | def camera(file_path): 48 | """ 49 | :param file_path: 
unix style file path 50 | :return: camera id 51 | """ 52 | return int(file_path.split('/')[-1].split('_')[1][1]) 53 | 54 | @property 55 | def ids(self): 56 | """ 57 | :return: person id list corresponding to dataset image paths 58 | """ 59 | return [self.id(path) for path in self.imgs] 60 | 61 | @property 62 | def unique_ids(self): 63 | """ 64 | :return: unique person ids in ascending order 65 | """ 66 | return sorted(set(self.ids)) 67 | 68 | @property 69 | def cameras(self): 70 | """ 71 | :return: camera id list corresponding to dataset image paths 72 | """ 73 | return [self.camera(path) for path in self.imgs] 74 | 75 | @staticmethod 76 | def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm|npy'): 77 | assert os.path.isdir( 78 | directory), 'dataset is not exists!{}'.format(directory) 79 | 80 | return sorted([os.path.join(root, f) 81 | for root, _, files in os.walk(directory) for f in files 82 | if re.match(r'([\w]+\.(?:' + ext + '))', f)]) 83 | 84 | -------------------------------------------------------------------------------- /data_v1/sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | import copy 3 | import collections 4 | import numpy as np 5 | from torch.utils.data import sampler 6 | 7 | 8 | class RandomSampler(sampler.Sampler): 9 | def __init__(self, data_source, batch_id, batch_image): 10 | super(RandomSampler, self).__init__(data_source) 11 | 12 | self.data_source = data_source 13 | self.batch_image = batch_image 14 | self.batch_id = batch_id 15 | 16 | self._id2index = collections.defaultdict(list) 17 | for idx, path in enumerate(data_source.imgs): 18 | _id = data_source.id(path) 19 | self._id2index[_id].append(idx) 20 | 21 | def __iter__(self): 22 | unique_ids = self.data_source.unique_ids 23 | random.shuffle(unique_ids) 24 | 25 | imgs = [] 26 | for _id in unique_ids: 27 | imgs.extend(self._sample(self._id2index[_id], self.batch_image)) 28 | return iter(imgs) 29 | 30 | def __len__(self): 31 | return len(self._id2index) * self.batch_image 32 | 33 | @staticmethod 34 | def _sample(population, k): 35 | if len(population) < k: 36 | population = population * k 37 | return random.sample(population, k) 38 | 39 | 40 | class RandomIdentitySampler(sampler.Sampler): 41 | """ 42 | Randomly sample N identities, then for each identity, 43 | randomly sample K instances, therefore batch size is N*K. 44 | Args: 45 | - data_source (list): list of (img_path, pid, camid). 46 | - num_instances (int): number of instances per identity in a batch. 47 | - batch_size (int): number of examples in a batch. 
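Example (illustrative sketch mirroring how data_v1/__init__.py builds its train loader; the concrete numbers are taken from bag_of_tricks_config.yaml, i.e. batchid=16, batchimage=4, nThread=4, and ``trainset`` is assumed to be one of the data_v1 dataset objects above, which expose ``imgs`` and ``id()``)::

    # assumed imports: from data_v1.sampler import RandomIdentitySampler
    #                  from torch.utils.data import dataloader
    # N identities * K instances per batch, e.g. 16 * 4 = 64 images
    sampler = RandomIdentitySampler(trainset, batch_size=16 * 4, num_instances=4)
    train_loader = dataloader.DataLoader(trainset,
                                         sampler=sampler,
                                         batch_size=16 * 4,
                                         num_workers=4)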
48 | """ 49 | 50 | def __init__(self, data_source, batch_size, num_instances): 51 | self.data_source = data_source 52 | self.batch_size = batch_size 53 | self.num_instances = num_instances 54 | self.num_pids_per_batch = self.batch_size // self.num_instances 55 | self.index_dic = collections.defaultdict(list) 56 | for index, path in enumerate(self.data_source.imgs): 57 | _id = data_source.id(path) 58 | self.index_dic[_id].append(index) 59 | self.pids = list(self.index_dic.keys()) 60 | 61 | # estimate number of examples in an epoch 62 | self.length = 0 63 | for pid in self.pids: 64 | idxs = self.index_dic[pid] 65 | num = len(idxs) 66 | if num < self.num_instances: 67 | num = self.num_instances 68 | self.length += num - num % self.num_instances 69 | 70 | def __iter__(self): 71 | batch_idxs_dict = collections.defaultdict(list) 72 | 73 | for pid in self.pids: 74 | idxs = copy.deepcopy(self.index_dic[pid]) 75 | if len(idxs) < self.num_instances: 76 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 77 | random.shuffle(idxs) 78 | batch_idxs = [] 79 | for idx in idxs: 80 | batch_idxs.append(idx) 81 | if len(batch_idxs) == self.num_instances: 82 | batch_idxs_dict[pid].append(batch_idxs) 83 | batch_idxs = [] 84 | 85 | avai_pids = copy.deepcopy(self.pids) 86 | final_idxs = [] 87 | 88 | while len(avai_pids) >= self.num_pids_per_batch: 89 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 90 | for pid in selected_pids: 91 | batch_idxs = batch_idxs_dict[pid].pop(0) 92 | final_idxs.extend(batch_idxs) 93 | if len(batch_idxs_dict[pid]) == 0: 94 | avai_pids.remove(pid) 95 | 96 | self.length = len(final_idxs) 97 | return iter(final_idxs) 98 | 99 | def __len__(self): 100 | return self.length 101 | 102 | 103 | class a_RandomIdentitySampler(sampler.Sampler): 104 | """ 105 | Randomly sample N identities, then for each identity, 106 | randomly sample K instances, therefore batch size is N*K. 107 | Args: 108 | - data_source (list): list of (img_path, pid, camid). 109 | - num_instances (int): number of instances per identity in a batch. 110 | - batch_size (int): number of examples in a batch. 
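Example (sketch; unlike RandomIdentitySampler above, which calls ``data_source.id(path)``, this variant reads the pid from ``path[1]``, so ``data_source.imgs`` is assumed to be a list of ``(img_path, pid, camid)`` tuples in the data_v2 style)::

    sampler = a_RandomIdentitySampler(trainset, batch_size=16 * 4, num_instances=4)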
111 | """ 112 | 113 | def __init__(self, data_source, batch_size, num_instances): 114 | self.data_source = data_source 115 | self.batch_size = batch_size 116 | self.num_instances = num_instances 117 | self.num_pids_per_batch = self.batch_size // self.num_instances 118 | self.index_dic = collections.defaultdict(list) 119 | for index, path in enumerate(self.data_source.imgs): 120 | _id = path[1] 121 | self.index_dic[_id].append(index) 122 | self.pids = list(self.index_dic.keys()) 123 | 124 | # estimate number of examples in an epoch 125 | self.length = 0 126 | for pid in self.pids: 127 | idxs = self.index_dic[pid] 128 | num = len(idxs) 129 | if num < self.num_instances: 130 | num = self.num_instances 131 | self.length += num - num % self.num_instances 132 | 133 | def __iter__(self): 134 | batch_idxs_dict = collections.defaultdict(list) 135 | 136 | for pid in self.pids: 137 | idxs = copy.deepcopy(self.index_dic[pid]) 138 | if len(idxs) < self.num_instances: 139 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 140 | random.shuffle(idxs) 141 | batch_idxs = [] 142 | for idx in idxs: 143 | batch_idxs.append(idx) 144 | if len(batch_idxs) == self.num_instances: 145 | batch_idxs_dict[pid].append(batch_idxs) 146 | batch_idxs = [] 147 | 148 | avai_pids = copy.deepcopy(self.pids) 149 | final_idxs = [] 150 | 151 | while len(avai_pids) >= self.num_pids_per_batch: 152 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 153 | for pid in selected_pids: 154 | batch_idxs = batch_idxs_dict[pid].pop(0) 155 | final_idxs.extend(batch_idxs) 156 | if len(batch_idxs_dict[pid]) == 0: 157 | avai_pids.remove(pid) 158 | 159 | self.length = len(final_idxs) 160 | return iter(final_idxs) 161 | 162 | def __len__(self): 163 | return self.length 164 | -------------------------------------------------------------------------------- /data_v2/Icon : -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/data_v2/Icon -------------------------------------------------------------------------------- /data_v2/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .datasets import Dataset, ImageDataset, VideoDataset 5 | from .datasets import register_image_dataset 6 | from .datasets import register_video_dataset 7 | from .datamanager import ImageDataManager, VideoDataManager -------------------------------------------------------------------------------- /data_v2/datasets/Icon : -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/data_v2/datasets/Icon -------------------------------------------------------------------------------- /data_v2/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .dataset import Dataset, ImageDataset, VideoDataset 5 | from .image import * 6 | from .video import * 7 | 8 | 9 | __image_datasets = { 10 | "market1501": Market1501, 11 | "cuhk03": CUHK03, 12 | "cuhk03_detected": CUHK03_Detected, 13 | "cuhk03_labeled": CUHK03_Labeled, 14 | "cuhk03_splited": CUHK03_splited, 15 | "dukemtmc": DukeMTMCreID, 16 | "msmt17": MSMT17, 17 | "viper": VIPeR, 18 | "grid": 
GRID, 19 | "cuhk01": CUHK01, 20 | "ilids": iLIDS, 21 | "sensereid": SenseReID, 22 | "prid": PRID, 23 | "cuhk02": CUHK02, 24 | "mot17": MOT17, 25 | "veri": VeRi, 26 | "aicity24": AICity24Balanced, 27 | } 28 | 29 | 30 | __video_datasets = { 31 | "mars": Mars, 32 | "ilidsvid": iLIDSVID, 33 | "prid2011": PRID2011, 34 | "dukemtmcvidreid": DukeMTMCVidReID, 35 | } 36 | 37 | 38 | def init_image_dataset(name, **kwargs): 39 | """Initializes an image dataset.""" 40 | avai_datasets = list(__image_datasets.keys()) 41 | if name not in avai_datasets: 42 | raise ValueError( 43 | 'Invalid dataset name. Received "{}", ' 44 | "but expected to be one of {}".format(name, avai_datasets) 45 | ) 46 | return __image_datasets[name](**kwargs) 47 | 48 | 49 | def init_video_dataset(name, **kwargs): 50 | """Initializes a video dataset.""" 51 | avai_datasets = list(__video_datasets.keys()) 52 | if name not in avai_datasets: 53 | raise ValueError( 54 | 'Invalid dataset name. Received "{}", ' 55 | "but expected to be one of {}".format(name, avai_datasets) 56 | ) 57 | return __video_datasets[name](**kwargs) 58 | 59 | 60 | def register_image_dataset(name, dataset): 61 | """Registers a new image dataset. 62 | 63 | Args: 64 | name (str): key corresponding to the new dataset. 65 | dataset (Dataset): the new dataset class. 66 | 67 | Examples:: 68 | 69 | import torchreid 70 | import NewDataset 71 | torchreid.data.register_image_dataset('new_dataset', NewDataset) 72 | # single dataset case 73 | datamanager = torchreid.data.ImageDataManager( 74 | root='reid-data', 75 | sources='new_dataset' 76 | ) 77 | # multiple dataset case 78 | datamanager = torchreid.data.ImageDataManager( 79 | root='reid-data', 80 | sources=['new_dataset', 'dukemtmcreid'] 81 | ) 82 | """ 83 | global __image_datasets 84 | curr_datasets = list(__image_datasets.keys()) 85 | if name in curr_datasets: 86 | raise ValueError( 87 | "The given name already exists, please choose " 88 | "another name excluding {}".format(curr_datasets) 89 | ) 90 | __image_datasets[name] = dataset 91 | 92 | 93 | def register_video_dataset(name, dataset): 94 | """Registers a new video dataset. 95 | 96 | Args: 97 | name (str): key corresponding to the new dataset. 98 | dataset (Dataset): the new dataset class. 
99 | 100 | Examples:: 101 | 102 | import torchreid 103 | import NewDataset 104 | torchreid.data.register_video_dataset('new_dataset', NewDataset) 105 | # single dataset case 106 | datamanager = torchreid.data.VideoDataManager( 107 | root='reid-data', 108 | sources='new_dataset' 109 | ) 110 | # multiple dataset case 111 | datamanager = torchreid.data.VideoDataManager( 112 | root='reid-data', 113 | sources=['new_dataset', 'ilidsvid'] 114 | ) 115 | """ 116 | global __video_datasets 117 | curr_datasets = list(__video_datasets.keys()) 118 | if name in curr_datasets: 119 | raise ValueError( 120 | "The given name already exists, please choose " 121 | "another name excluding {}".format(curr_datasets) 122 | ) 123 | __video_datasets[name] = dataset 124 | -------------------------------------------------------------------------------- /data_v2/datasets/image/Icon : -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/data_v2/datasets/image/Icon -------------------------------------------------------------------------------- /data_v2/datasets/image/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from ..dataset import ImageDataset 5 | from .market1501 import Market1501 6 | from .dukemtmcreid import DukeMTMCreID 7 | from .cuhk03 import CUHK03 8 | from .msmt17 import MSMT17 9 | from .viper import VIPeR 10 | from .grid import GRID 11 | from .cuhk01 import CUHK01 12 | from .ilids import iLIDS 13 | from .sensereid import SenseReID 14 | from .prid import PRID 15 | from .cuhk02 import CUHK02 16 | from .cuhk03_detected import CUHK03_Detected 17 | from .cuhk03_labeled import CUHK03_Labeled 18 | from .cuhk03_splited import CUHK03_splited 19 | from .msmt17 import MSMT17 20 | from .mot17 import MOT17 21 | from .veri import VeRi 22 | from .aicity24 import AICity24Balanced 23 | -------------------------------------------------------------------------------- /data_v2/datasets/image/aicity24.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import re 3 | 4 | from .. import ImageDataset 5 | 6 | 7 | class AICity24Balanced(ImageDataset): 8 | """AICity24 dataset. 9 | 10 | Cropped from the tracking dataset of the AI City Challenge 2024 with balanced 11 | distribution of identities. For each identity, we choose 100 images for the 12 | training set. Then, for query and gallery, we choose 108 images for each 13 | identitiy in the validation set, where 8 images are used for query and 100 14 | images are used for gallery. 15 | 16 | We choose a "hard" setting here, where there is no shared camera between 17 | query and gallery, i.e., for one identity in the query set, the same identity 18 | in the gallery will be from different cameras. 
19 | 20 | | subset | # ids | # images | # cameras | 21 | |:---------|:--------|:-----------|:------------| 22 | | train | 1012 | 101200 | 350 | 23 | | query | 518 | 4144 | 20 | 24 | | gallery | 518 | 51800 | 154 | 25 | 26 | """ 27 | 28 | _junk_pids = [0, -1] 29 | dataset_dir = "AICITY24" 30 | dataset_name = "AICity24" 31 | 32 | def __init__(self, root="datasets", **kwargs): 33 | # self.root = osp.abspath(osp.expanduser(root)) 34 | self.seed = 0 35 | self.root = root 36 | self.data_dir = osp.join(self.root, self.dataset_dir) 37 | 38 | self.train_dir = osp.join(self.data_dir, "bounding_box_train") 39 | self.query_dir = osp.join(self.data_dir, "bounding_box_query") 40 | self.gallery_dir = osp.join(self.data_dir, "bounding_box_test") 41 | 42 | self.valid_train_stems = [] 43 | with open(osp.join(self.data_dir, "valid_train_stems.txt"), "r") as f: 44 | for line in f: 45 | self.valid_train_stems.append(line.strip()) 46 | 47 | self.valid_test_stems = [] 48 | with open(osp.join(self.data_dir, "valid_test_stems.txt"), "r") as f: 49 | for line in f: 50 | self.valid_test_stems.append(line.strip()) 51 | 52 | self.valid_query_stems = [] 53 | with open(osp.join(self.data_dir, "valid_query_stems.txt"), "r") as f: 54 | for line in f: 55 | self.valid_query_stems.append(line.strip()) 56 | 57 | required_files = [ 58 | self.data_dir, 59 | self.train_dir, 60 | self.query_dir, 61 | self.gallery_dir, 62 | ] 63 | 64 | train = self.process_dir(self.train_dir, mode="train") 65 | query = self.process_dir(self.query_dir, mode="query") 66 | gallery = self.process_dir(self.gallery_dir, mode="test") 67 | 68 | super(AICity24Balanced, self).__init__(train, query, gallery, **kwargs) 69 | 70 | def process_dir(self, dir_path, mode="train"): 71 | if mode == "train": 72 | img_paths = [osp.join(dir_path, f + ".jpg") for f in self.valid_train_stems] 73 | elif mode == "test": 74 | img_paths = [osp.join(dir_path, f + ".jpg") for f in self.valid_test_stems] 75 | elif mode == "query": 76 | img_paths = [osp.join(dir_path, f + ".jpg") for f in self.valid_query_stems] 77 | else: 78 | raise ValueError("Invalid mode") 79 | 80 | pattern = re.compile(r"([-\d]+)_s([-\d]+)c([-\d]+)") 81 | 82 | data = [] 83 | 84 | for img_path in img_paths: 85 | pid, sceneid, camid = map(int, pattern.search(img_path).groups()) 86 | data.append((img_path, pid, camid)) 87 | 88 | return data 89 | -------------------------------------------------------------------------------- /data_v2/datasets/image/cuhk01.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import zipfile 10 | import numpy as np 11 | 12 | from .. import ImageDataset 13 | from ..utils import read_json, write_json 14 | 15 | 16 | class CUHK01(ImageDataset): 17 | """CUHK01. 18 | 19 | Reference: 20 | Li et al. Human Reidentification with Transferred Metric Learning. ACCV 2012. 21 | 22 | URL: ``_ 23 | 24 | Dataset statistics: 25 | - identities: 971. 26 | - images: 3884. 27 | - cameras: 4. 
28 | """ 29 | dataset_dir = 'cuhk01' 30 | dataset_url = None 31 | 32 | def __init__(self, root='', split_id=0, **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.zip_path = osp.join(self.dataset_dir, 'CUHK01.zip') 38 | self.campus_dir = osp.join(self.dataset_dir, 'campus') 39 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 40 | 41 | self.extract_file() 42 | 43 | required_files = [ 44 | self.dataset_dir, 45 | self.campus_dir 46 | ] 47 | self.check_before_run(required_files) 48 | 49 | self.prepare_split() 50 | splits = read_json(self.split_path) 51 | if split_id >= len(splits): 52 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 53 | split = splits[split_id] 54 | 55 | train = split['train'] 56 | query = split['query'] 57 | gallery = split['gallery'] 58 | 59 | train = [tuple(item) for item in train] 60 | query = [tuple(item) for item in query] 61 | gallery = [tuple(item) for item in gallery] 62 | 63 | super(CUHK01, self).__init__(train, query, gallery, **kwargs) 64 | 65 | def extract_file(self): 66 | if not osp.exists(self.campus_dir): 67 | print('Extracting files') 68 | zip_ref = zipfile.ZipFile(self.zip_path, 'r') 69 | zip_ref.extractall(self.dataset_dir) 70 | zip_ref.close() 71 | 72 | def prepare_split(self): 73 | """ 74 | Image name format: 0001001.png, where first four digits represent identity 75 | and last four digits represent cameras. Camera 1&2 are considered the same 76 | view and camera 3&4 are considered the same view. 77 | """ 78 | if not osp.exists(self.split_path): 79 | print('Creating 10 random splits of train ids and test ids') 80 | img_paths = sorted(glob.glob(osp.join(self.campus_dir, '*.png'))) 81 | img_list = [] 82 | pid_container = set() 83 | for img_path in img_paths: 84 | img_name = osp.basename(img_path) 85 | pid = int(img_name[:4]) - 1 86 | camid = (int(img_name[4:7]) - 1) // 2 # result is either 0 or 1 87 | img_list.append((img_path, pid, camid)) 88 | pid_container.add(pid) 89 | 90 | num_pids = len(pid_container) 91 | num_train_pids = num_pids // 2 92 | 93 | splits = [] 94 | for _ in range(10): 95 | order = np.arange(num_pids) 96 | np.random.shuffle(order) 97 | train_idxs = order[:num_train_pids] 98 | train_idxs = np.sort(train_idxs) 99 | idx2label = {idx: label for label, idx in enumerate(train_idxs)} 100 | 101 | train, test_a, test_b = [], [], [] 102 | for img_path, pid, camid in img_list: 103 | if pid in train_idxs: 104 | train.append((img_path, idx2label[pid], camid)) 105 | else: 106 | if camid == 0: 107 | test_a.append((img_path, pid, camid)) 108 | else: 109 | test_b.append((img_path, pid, camid)) 110 | 111 | # use cameraA as query and cameraB as gallery 112 | split = { 113 | 'train': train, 114 | 'query': test_a, 115 | 'gallery': test_b, 116 | 'num_train_pids': num_train_pids, 117 | 'num_query_pids': num_pids - num_train_pids, 118 | 'num_gallery_pids': num_pids - num_train_pids 119 | } 120 | splits.append(split) 121 | 122 | # use cameraB as query and cameraA as gallery 123 | split = { 124 | 'train': train, 125 | 'query': test_b, 126 | 'gallery': test_a, 127 | 'num_train_pids': num_train_pids, 128 | 'num_query_pids': num_pids - num_train_pids, 129 | 'num_gallery_pids': num_pids - num_train_pids 130 | } 131 | splits.append(split) 132 | 133 | print('Totally {} splits are created'.format(len(splits))) 134 | write_json(splits, 
self.split_path) 135 | print('Split file saved to {}'.format(self.split_path)) 136 | -------------------------------------------------------------------------------- /data_v2/datasets/image/cuhk02.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | 10 | from .. import ImageDataset 11 | 12 | 13 | class CUHK02(ImageDataset): 14 | """CUHK02. 15 | 16 | Reference: 17 | Li and Wang. Locally Aligned Feature Transforms across Views. CVPR 2013. 18 | 19 | URL: ``_ 20 | 21 | Dataset statistics: 22 | - 5 camera view pairs each with two cameras 23 | - 971, 306, 107, 193 and 239 identities from P1 - P5 24 | - totally 1,816 identities 25 | - image format is png 26 | 27 | Protocol: Use P1 - P4 for training and P5 for evaluation. 28 | """ 29 | dataset_dir = 'cuhk02' 30 | cam_pairs = ['P1', 'P2', 'P3', 'P4', 'P5'] 31 | test_cam_pair = 'P5' 32 | 33 | def __init__(self, root='', **kwargs): 34 | self.root = osp.abspath(osp.expanduser(root)) 35 | self.dataset_dir = osp.join(self.root, self.dataset_dir, 'Dataset') 36 | 37 | required_files = [self.dataset_dir] 38 | self.check_before_run(required_files) 39 | 40 | train, query, gallery = self.get_data_list() 41 | 42 | super(CUHK02, self).__init__(train, query, gallery, **kwargs) 43 | 44 | def get_data_list(self): 45 | num_train_pids, camid = 0, 0 46 | train, query, gallery = [], [], [] 47 | 48 | for cam_pair in self.cam_pairs: 49 | cam_pair_dir = osp.join(self.dataset_dir, cam_pair) 50 | 51 | cam1_dir = osp.join(cam_pair_dir, 'cam1') 52 | cam2_dir = osp.join(cam_pair_dir, 'cam2') 53 | 54 | impaths1 = glob.glob(osp.join(cam1_dir, '*.png')) 55 | impaths2 = glob.glob(osp.join(cam2_dir, '*.png')) 56 | 57 | if cam_pair == self.test_cam_pair: 58 | # add images to query 59 | for impath in impaths1: 60 | pid = osp.basename(impath).split('_')[0] 61 | pid = int(pid) 62 | query.append((impath, pid, camid)) 63 | camid += 1 64 | 65 | # add images to gallery 66 | for impath in impaths2: 67 | pid = osp.basename(impath).split('_')[0] 68 | pid = int(pid) 69 | gallery.append((impath, pid, camid)) 70 | camid += 1 71 | 72 | else: 73 | pids1 = [osp.basename(impath).split('_')[0] for impath in impaths1] 74 | pids2 = [osp.basename(impath).split('_')[0] for impath in impaths2] 75 | pids = set(pids1 + pids2) 76 | pid2label = {pid: label+num_train_pids for label, pid in enumerate(pids)} 77 | 78 | # add images to train from cam1 79 | for impath in impaths1: 80 | pid = osp.basename(impath).split('_')[0] 81 | pid = pid2label[pid] 82 | train.append((impath, pid, camid)) 83 | camid += 1 84 | 85 | # add images to train from cam1 86 | for impath in impaths1: 87 | pid = osp.basename(impath).split('_')[0] 88 | pid = pid2label[pid] 89 | train.append((impath, pid, camid)) 90 | camid += 1 91 | num_train_pids += len(pids) 92 | 93 | return train, query, gallery 94 | -------------------------------------------------------------------------------- /data_v2/datasets/image/cuhk03_detected.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | 11 | from .. 
import ImageDataset 12 | 13 | 14 | class CUHK03_Detected(ImageDataset): 15 | """DukeMTMC-reID. 16 | 17 | Reference: 18 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 1404 (train + query). 25 | - images:16522 (train) + 2228 (query) + 17661 (gallery). 26 | - cameras: 8. 27 | """ 28 | dataset_dir = 'cuhk03' 29 | 30 | def __init__(self, root='', **kwargs): 31 | self.root = osp.abspath(osp.expanduser(root)) 32 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 33 | self.train_dir = osp.join(self.dataset_dir, 'CUHK03_detected/bounding_box_train') 34 | self.query_dir = osp.join(self.dataset_dir, 'CUHK03_detected/query') 35 | self.gallery_dir = osp.join(self.dataset_dir, 'CUHK03_detected/bounding_box_test') 36 | 37 | required_files = [ 38 | self.dataset_dir, 39 | self.train_dir, 40 | self.query_dir, 41 | self.gallery_dir 42 | ] 43 | self.check_before_run(required_files) 44 | 45 | train = self.process_dir(self.train_dir, relabel=True) 46 | query = self.process_dir(self.query_dir, relabel=False) 47 | gallery = self.process_dir(self.gallery_dir, relabel=False) 48 | 49 | super(CUHK03_Detected, self).__init__(train, query, gallery, **kwargs) 50 | 51 | def process_dir(self, dir_path, relabel=False): 52 | img_paths = glob.glob(osp.join(dir_path, '*.png')) 53 | pattern = re.compile(r'([-\d]+)_c(\d)') 54 | 55 | pid_container = set() 56 | for img_path in img_paths: 57 | pid, _ = map(int, pattern.search(img_path).groups()) 58 | pid_container.add(pid) 59 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 60 | 61 | data = [] 62 | for img_path in img_paths: 63 | pid, camid = map(int, pattern.search(img_path).groups()) 64 | #assert 1 <= camid <= 8 65 | camid -= 1 # index starts from 0 66 | if relabel: pid = pid2label[pid] 67 | data.append((img_path, pid, camid)) 68 | 69 | return data 70 | -------------------------------------------------------------------------------- /data_v2/datasets/image/cuhk03_labeled.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | 11 | from .. import ImageDataset 12 | 13 | 14 | class CUHK03_Labeled(ImageDataset): 15 | """DukeMTMC-reID. 16 | 17 | Reference: 18 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 1404 (train + query). 25 | - images:16522 (train) + 2228 (query) + 17661 (gallery). 26 | - cameras: 8. 
27 | """ 28 | dataset_dir = 'cuhk03' 29 | 30 | def __init__(self, root='', **kwargs): 31 | self.root = osp.abspath(osp.expanduser(root)) 32 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 33 | self.train_dir = osp.join(self.dataset_dir, 'CUHK03_labeled/bounding_box_train') 34 | self.query_dir = osp.join(self.dataset_dir, 'CUHK03_labeled/query') 35 | self.gallery_dir = osp.join(self.dataset_dir, 'CUHK03_labeled/bounding_box_test') 36 | 37 | required_files = [ 38 | self.dataset_dir, 39 | self.train_dir, 40 | self.query_dir, 41 | self.gallery_dir 42 | ] 43 | self.check_before_run(required_files) 44 | 45 | train = self.process_dir(self.train_dir, relabel=True) 46 | query = self.process_dir(self.query_dir, relabel=False) 47 | gallery = self.process_dir(self.gallery_dir, relabel=False) 48 | 49 | super(CUHK03_Labeled, self).__init__(train, query, gallery, **kwargs) 50 | 51 | def process_dir(self, dir_path, relabel=False): 52 | img_paths = glob.glob(osp.join(dir_path, '*.png')) 53 | pattern = re.compile(r'([-\d]+)_c(\d)') 54 | 55 | pid_container = set() 56 | for img_path in img_paths: 57 | pid, _ = map(int, pattern.search(img_path).groups()) 58 | pid_container.add(pid) 59 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 60 | 61 | data = [] 62 | for img_path in img_paths: 63 | pid, camid = map(int, pattern.search(img_path).groups()) 64 | #assert 1 <= camid <= 8 65 | camid -= 1 # index starts from 0 66 | if relabel: pid = pid2label[pid] 67 | data.append((img_path, pid, camid)) 68 | 69 | return data 70 | -------------------------------------------------------------------------------- /data_v2/datasets/image/cuhk03_splited.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | 9 | from .. import ImageDataset 10 | from ..utils import mkdir_if_missing, read_json, write_json 11 | 12 | 13 | class CUHK03_splited(ImageDataset): 14 | """CUHK03. 15 | 16 | Reference: 17 | Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014. 18 | 19 | URL: ``_ 20 | 21 | Dataset statistics: 22 | - identities: 1360. 23 | - images: 13164. 24 | - cameras: 6. 25 | - splits: 20 (classic). 
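Example (illustrative sketch; the root path is an assumption, and the pre-built split .json/.mat files referenced in ``__init__`` below must already exist under the ``CUHK03`` directory)::

    dataset = CUHK03_splited(root='reid-data', split_id=0,
                             cuhk03_labeled=False, cuhk03_classic_split=False)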
26 | """ 27 | dataset_dir = 'CUHK03' 28 | #dataset_url = None 29 | 30 | def __init__(self, root='', split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, **kwargs): 31 | self.root = osp.abspath(osp.expanduser(root)) 32 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 33 | #self.download_dataset(self.dataset_dir, self.dataset_url) 34 | 35 | # self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release') 36 | # self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat') 37 | 38 | self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected') 39 | self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled') 40 | 41 | self.split_classic_det_json_path = osp.join( 42 | self.dataset_dir, 'splits_classic_detected.json') 43 | self.split_classic_lab_json_path = osp.join( 44 | self.dataset_dir, 'splits_classic_labeled.json') 45 | 46 | self.split_new_det_json_path = osp.join( 47 | self.dataset_dir, 'splits_new_detected.json') 48 | self.split_new_lab_json_path = osp.join( 49 | self.dataset_dir, 'splits_new_labeled.json') 50 | 51 | self.split_new_det_mat_path = osp.join( 52 | self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat') 53 | self.split_new_lab_mat_path = osp.join( 54 | self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat') 55 | 56 | required_files = [ 57 | self.dataset_dir, 58 | # self.data_dir, 59 | # self.raw_mat_path, 60 | self.split_new_det_mat_path, 61 | self.split_new_lab_mat_path 62 | ] 63 | self.check_before_run(required_files) 64 | 65 | # self.preprocess_split() 66 | 67 | if cuhk03_labeled: 68 | split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path 69 | else: 70 | split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path 71 | 72 | splits = read_json(split_path) 73 | assert split_id < len(splits), 'Condition split_id ({}) < len(splits) ({}) is false'.format( 74 | split_id, len(splits)) 75 | split = splits[split_id] 76 | 77 | train = split['train'] 78 | query = split['query'] 79 | gallery = split['gallery'] 80 | new_train_list = [] 81 | new_query_list = [] 82 | new_gallery_list = [] 83 | for item in train: 84 | new_train_list.append( 85 | [osp.join(self.dataset_dir, item[0][31:]), item[1], item[2]]) 86 | for item in query: 87 | new_query_list.append( 88 | [osp.join(self.dataset_dir, item[0][31:]), item[1], item[2]]) 89 | for item in gallery: 90 | new_gallery_list.append( 91 | [osp.join(self.dataset_dir, item[0][31:]), item[1], item[2]]) 92 | super(CUHK03_splited, self).__init__(new_train_list, 93 | new_query_list, new_gallery_list, **kwargs) 94 | -------------------------------------------------------------------------------- /data_v2/datasets/image/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | 11 | from .. import ImageDataset 12 | 13 | 14 | class DukeMTMCreID(ImageDataset): 15 | """DukeMTMC-reID. 16 | 17 | Reference: 18 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 1404 (train + query). 
25 | - images:16522 (train) + 2228 (query) + 17661 (gallery). 26 | - cameras: 8. 27 | """ 28 | dataset_dir = '' 29 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 30 | 31 | def __init__(self, root='', **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 36 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 37 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 38 | 39 | required_files = [ 40 | self.dataset_dir, 41 | self.train_dir, 42 | self.query_dir, 43 | self.gallery_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | train = self.process_dir(self.train_dir, relabel=True) 48 | query = self.process_dir(self.query_dir, relabel=False) 49 | gallery = self.process_dir(self.gallery_dir, relabel=False) 50 | 51 | super(DukeMTMCreID, self).__init__(train, query, gallery, **kwargs) 52 | 53 | def process_dir(self, dir_path, relabel=False): 54 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 55 | pattern = re.compile(r'([-\d]+)_c(\d)') 56 | 57 | pid_container = set() 58 | for img_path in img_paths: 59 | pid, _ = map(int, pattern.search(img_path).groups()) 60 | pid_container.add(pid) 61 | ####### modified ######## 62 | pid2label = {pid:label for label, pid in enumerate(sorted(pid_container))} 63 | # pid2label = {pid:label for label, pid in enumerate(pid_container)} 64 | 65 | ######################### 66 | 67 | data = [] 68 | for img_path in img_paths: 69 | pid, camid = map(int, pattern.search(img_path).groups()) 70 | assert 1 <= camid <= 8 71 | camid -= 1 # index starts from 0 72 | if relabel: pid = pid2label[pid] 73 | data.append((img_path, pid, camid)) 74 | 75 | return data -------------------------------------------------------------------------------- /data_v2/datasets/image/grid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | from scipy.io import loadmat 10 | 11 | from .. import ImageDataset 12 | from ..utils import read_json, write_json 13 | 14 | 15 | class GRID(ImageDataset): 16 | """GRID. 17 | 18 | Reference: 19 | Loy et al. Multi-camera activity correlation analysis. CVPR 2009. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 250. 25 | - images: 1275. 26 | - cameras: 8. 
27 | """ 28 | dataset_dir = 'grid' 29 | dataset_url = 'http://personal.ie.cuhk.edu.hk/~ccloy/files/datasets/underground_reid.zip' 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.probe_path = osp.join(self.dataset_dir, 'underground_reid', 'probe') 37 | self.gallery_path = osp.join(self.dataset_dir, 'underground_reid', 'gallery') 38 | self.split_mat_path = osp.join(self.dataset_dir, 'underground_reid', 'features_and_partitions.mat') 39 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 40 | 41 | required_files = [ 42 | self.dataset_dir, 43 | self.probe_path, 44 | self.gallery_path, 45 | self.split_mat_path 46 | ] 47 | self.check_before_run(required_files) 48 | 49 | self.prepare_split() 50 | splits = read_json(self.split_path) 51 | if split_id >= len(splits): 52 | raise ValueError('split_id exceeds range, received {}, ' 53 | 'but expected between 0 and {}'.format(split_id, len(splits)-1)) 54 | split = splits[split_id] 55 | 56 | train = split['train'] 57 | query = split['query'] 58 | gallery = split['gallery'] 59 | 60 | train = [tuple(item) for item in train] 61 | query = [tuple(item) for item in query] 62 | gallery = [tuple(item) for item in gallery] 63 | 64 | super(GRID, self).__init__(train, query, gallery, **kwargs) 65 | 66 | def prepare_split(self): 67 | if not osp.exists(self.split_path): 68 | print('Creating 10 random splits') 69 | split_mat = loadmat(self.split_mat_path) 70 | trainIdxAll = split_mat['trainIdxAll'][0] # length = 10 71 | probe_img_paths = sorted(glob.glob(osp.join(self.probe_path, '*.jpeg'))) 72 | gallery_img_paths = sorted(glob.glob(osp.join(self.gallery_path, '*.jpeg'))) 73 | 74 | splits = [] 75 | for split_idx in range(10): 76 | train_idxs = trainIdxAll[split_idx][0][0][2][0].tolist() 77 | assert len(train_idxs) == 125 78 | idx2label = {idx: label for label, idx in enumerate(train_idxs)} 79 | 80 | train, query, gallery = [], [], [] 81 | 82 | # processing probe folder 83 | for img_path in probe_img_paths: 84 | img_name = osp.basename(img_path) 85 | img_idx = int(img_name.split('_')[0]) 86 | camid = int(img_name.split('_')[1]) - 1 # index starts from 0 87 | if img_idx in train_idxs: 88 | train.append((img_path, idx2label[img_idx], camid)) 89 | else: 90 | query.append((img_path, img_idx, camid)) 91 | 92 | # process gallery folder 93 | for img_path in gallery_img_paths: 94 | img_name = osp.basename(img_path) 95 | img_idx = int(img_name.split('_')[0]) 96 | camid = int(img_name.split('_')[1]) - 1 # index starts from 0 97 | if img_idx in train_idxs: 98 | train.append((img_path, idx2label[img_idx], camid)) 99 | else: 100 | gallery.append((img_path, img_idx, camid)) 101 | 102 | split = { 103 | 'train': train, 104 | 'query': query, 105 | 'gallery': gallery, 106 | 'num_train_pids': 125, 107 | 'num_query_pids': 125, 108 | 'num_gallery_pids': 900 109 | } 110 | splits.append(split) 111 | 112 | print('Totally {} splits are created'.format(len(splits))) 113 | write_json(splits, self.split_path) 114 | print('Split file saved to {}'.format(self.split_path)) -------------------------------------------------------------------------------- /data_v2/datasets/image/ilids.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 
| import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import numpy as np 10 | import copy 11 | import random 12 | from collections import defaultdict 13 | 14 | from .. import ImageDataset 15 | from ..utils import read_json, write_json 16 | 17 | 18 | class iLIDS(ImageDataset): 19 | """QMUL-iLIDS. 20 | 21 | Reference: 22 | Zheng et al. Associating Groups of People. BMVC 2009. 23 | 24 | Dataset statistics: 25 | - identities: 119. 26 | - images: 476. 27 | - cameras: 8 (not explicitly provided). 28 | """ 29 | dataset_dir = 'ilids' 30 | dataset_url = 'http://www.eecs.qmul.ac.uk/~jason/data/i-LIDS_Pedestrian.tgz' 31 | 32 | def __init__(self, root='', split_id=0, **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.data_dir = osp.join(self.dataset_dir, 'i-LIDS_Pedestrian/Persons') 38 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.data_dir 43 | ] 44 | self.check_before_run(required_files) 45 | 46 | self.prepare_split() 47 | splits = read_json(self.split_path) 48 | if split_id >= len(splits): 49 | raise ValueError('split_id exceeds range, received {}, but ' 50 | 'expected between 0 and {}'.format(split_id, len(splits)-1)) 51 | split = splits[split_id] 52 | 53 | train, query, gallery = self.process_split(split) 54 | 55 | super(iLIDS, self).__init__(train, query, gallery, **kwargs) 56 | 57 | def prepare_split(self): 58 | if not osp.exists(self.split_path): 59 | print('Creating splits ...') 60 | 61 | paths = glob.glob(osp.join(self.data_dir, '*.jpg')) 62 | img_names = [osp.basename(path) for path in paths] 63 | num_imgs = len(img_names) 64 | assert num_imgs == 476, 'There should be 476 images, but ' \ 65 | 'got {}, please check the data'.format(num_imgs) 66 | 67 | # store image names 68 | # image naming format: 69 | # the first four digits denote the person ID 70 | # the last four digits denote the sequence index 71 | pid_dict = defaultdict(list) 72 | for img_name in img_names: 73 | pid = int(img_name[:4]) 74 | pid_dict[pid].append(img_name) 75 | pids = list(pid_dict.keys()) 76 | num_pids = len(pids) 77 | assert num_pids == 119, 'There should be 119 identities, ' \ 78 | 'but got {}, please check the data'.format(num_pids) 79 | 80 | num_train_pids = int(num_pids * 0.5) 81 | num_test_pids = num_pids - num_train_pids # supposed to be 60 82 | 83 | splits = [] 84 | for _ in range(10): 85 | # randomly choose num_train_pids train IDs and num_test_pids test IDs 86 | pids_copy = copy.deepcopy(pids) 87 | random.shuffle(pids_copy) 88 | train_pids = pids_copy[:num_train_pids] 89 | test_pids = pids_copy[num_train_pids:] 90 | 91 | train = [] 92 | query = [] 93 | gallery = [] 94 | 95 | # for train IDs, all images are used in the train set. 96 | for pid in train_pids: 97 | img_names = pid_dict[pid] 98 | train.extend(img_names) 99 | 100 | # for each test ID, randomly choose two images, one for 101 | # query and the other one for gallery. 
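# (illustrative example, not from the original source) a test identity such
# as 0034 with images ['00340001.jpg', '00340003.jpg', '00340007.jpg'] might
# contribute '00340003.jpg' to the query set and '00340007.jpg' to the
# gallery set for this split; the file names are made up following the
# naming scheme described above (first four digits = person ID).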
102 | for pid in test_pids: 103 | img_names = pid_dict[pid] 104 | samples = random.sample(img_names, 2) 105 | query.append(samples[0]) 106 | gallery.append(samples[1]) 107 | 108 | split = {'train': train, 'query': query, 'gallery': gallery} 109 | splits.append(split) 110 | 111 | print('Totally {} splits are created'.format(len(splits))) 112 | write_json(splits, self.split_path) 113 | print('Split file is saved to {}'.format(self.split_path)) 114 | 115 | def get_pid2label(self, img_names): 116 | pid_container = set() 117 | for img_name in img_names: 118 | pid = int(img_name[:4]) 119 | pid_container.add(pid) 120 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 121 | return pid2label 122 | 123 | def parse_img_names(self, img_names, pid2label=None): 124 | data = [] 125 | 126 | for img_name in img_names: 127 | pid = int(img_name[:4]) 128 | if pid2label is not None: 129 | pid = pid2label[pid] 130 | camid = int(img_name[4:7]) - 1 # 0-based 131 | img_path = osp.join(self.data_dir, img_name) 132 | data.append((img_path, pid, camid)) 133 | 134 | return data 135 | 136 | def process_split(self, split): 137 | train, query, gallery = [], [], [] 138 | train_pid2label = self.get_pid2label(split['train']) 139 | train = self.parse_img_names(split['train'], train_pid2label) 140 | query = self.parse_img_names(split['query']) 141 | gallery = self.parse_img_names(split['gallery']) 142 | return train, query, gallery -------------------------------------------------------------------------------- /data_v2/datasets/image/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | import warnings 11 | 12 | from .. import ImageDataset 13 | 14 | 15 | class Market1501(ImageDataset): 16 | """Market1501. 17 | 18 | Reference: 19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 1501 (+1 for background). 25 | - images: 12936 (train) + 3368 (query) + 15913 (gallery). 26 | """ 27 | _junk_pids = [0, -1] 28 | dataset_dir = 'Market-1501' 29 | dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip' 30 | 31 | def __init__(self, root='', market1501_500k=False, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | # allow alternative directory structure 37 | self.data_dir = self.dataset_dir 38 | data_dir = osp.join(self.data_dir, 'Market-1501-v15.09.15') 39 | if osp.isdir(data_dir): 40 | self.data_dir = data_dir 41 | else: 42 | warnings.warn('The current data structure is deprecated. 
Please ' 43 | 'put data folders such as "bounding_box_train" under ' 44 | '"Market-1501-v15.09.15".') 45 | 46 | self.train_dir = osp.join(self.data_dir, 'bounding_box_train') 47 | self.query_dir = osp.join(self.data_dir, 'query') 48 | self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test') 49 | self.extra_gallery_dir = osp.join(self.data_dir, 'images') 50 | self.market1501_500k = market1501_500k 51 | 52 | required_files = [ 53 | self.data_dir, 54 | self.train_dir, 55 | self.query_dir, 56 | self.gallery_dir 57 | ] 58 | if self.market1501_500k: 59 | required_files.append(self.extra_gallery_dir) 60 | self.check_before_run(required_files) 61 | 62 | train = self.process_dir(self.train_dir, relabel=True) 63 | query = self.process_dir(self.query_dir, relabel=False) 64 | gallery = self.process_dir(self.gallery_dir, relabel=False) 65 | if self.market1501_500k: 66 | gallery += self.process_dir(self.extra_gallery_dir, relabel=False) 67 | 68 | super(Market1501, self).__init__(train, query, gallery, **kwargs) 69 | 70 | def process_dir(self, dir_path, relabel=False): 71 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 72 | pattern = re.compile(r'([-\d]+)_c(\d)') 73 | 74 | pid_container = set() 75 | for img_path in img_paths: 76 | pid, _ = map(int, pattern.search(img_path).groups()) 77 | if pid == -1: 78 | continue # junk images are just ignored 79 | pid_container.add(pid) 80 | ####### modified ####### 81 | pid2label = {pid:label for label, pid in enumerate(sorted(pid_container))} 82 | # pid2label = {pid:label for label, pid in enumerate(pid_container)} 83 | 84 | ######################## 85 | 86 | data = [] 87 | for img_path in img_paths: 88 | pid, camid = map(int, pattern.search(img_path).groups()) 89 | if pid == -1: 90 | continue # junk images are just ignored 91 | assert 0 <= pid <= 1501 # pid == 0 means background 92 | assert 1 <= camid <= 6 93 | camid -= 1 # index starts from 0 94 | if relabel: 95 | pid = pid2label[pid] 96 | data.append((img_path, pid, camid)) 97 | 98 | return data -------------------------------------------------------------------------------- /data_v2/datasets/image/mot17.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | 11 | from .. import ImageDataset 12 | 13 | 14 | class MOT17(ImageDataset): 15 | """DukeMTMC-reID. 16 | 17 | Reference: 18 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 1404 (train + query). 25 | - images:16522 (train) + 2228 (query) + 17661 (gallery). 26 | - cameras: 8. 
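Note: the reference list and statistics above appear to be carried over
from the DukeMTMC-reID description. As implemented below, this class
indexes a single ``train`` folder and reuses it for train, query and
gallery, so the numbers above do not describe the data actually loaded.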
27 | """ 28 | dataset_dir = '' 29 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 30 | 31 | def __init__(self, root='', **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | self.train_dir = osp.join(self.dataset_dir, 'train') 36 | # self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 37 | # self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 38 | 39 | required_files = [ 40 | self.dataset_dir, 41 | self.train_dir, 42 | # self.query_dir, 43 | # self.gallery_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | train = self.process_dir(self.train_dir, relabel=True) 48 | query = self.process_dir(self.train_dir, relabel=False) 49 | gallery = self.process_dir(self.train_dir, relabel=False) 50 | # query = self.process_dir(self.query_dir, relabel=False) 51 | # gallery = self.process_dir(self.gallery_dir, relabel=False) 52 | 53 | super(MOT17, self).__init__(train,query,gallery, **kwargs) 54 | 55 | def process_dir(self, dir_path, relabel=False): 56 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 57 | pattern = re.compile(r'([-\d]+)_(\d)') 58 | 59 | pid_container = set() 60 | for img_path in img_paths: 61 | pid, _ = map(int, pattern.search(img_path).groups()) 62 | pid_container.add(pid) 63 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 64 | 65 | data = [] 66 | for img_path in img_paths: 67 | pid, camid = map(int, pattern.search(img_path).groups()) 68 | assert 1 <= camid <= 8 69 | camid -= 1 # index starts from 0 70 | if relabel: pid = pid2label[pid] 71 | data.append((img_path, pid, camid)) 72 | 73 | return data -------------------------------------------------------------------------------- /data_v2/datasets/image/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | 11 | from .. import ImageDataset 12 | 13 | 14 | class MSMT17(ImageDataset): 15 | """ 16 | MSMT17. 17 | 18 | Reference: 19 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 4101. 25 | - images: 32621 (train) + 11659 (query) + 82161 (gallery). 26 | - cameras: 15. 
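Note: as implemented below, this loader expects pre-arranged
``bounding_box_train``, ``query`` and ``bounding_box_test`` folders under
``msmt17/MSMT17`` containing Market-style file names such as
``0001_c2_000001.jpg`` (a made-up example that matches the regex in
``process_dir``); no list files are parsed.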
27 | """ 28 | 29 | dataset_dir = 'msmt17' 30 | 31 | def __init__(self, root='', **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.train_dir = osp.join(self.dataset_dir, 'MSMT17/bounding_box_train') 35 | self.query_dir = osp.join(self.dataset_dir, 'MSMT17/query') 36 | self.gallery_dir = osp.join(self.dataset_dir, 'MSMT17/bounding_box_test') 37 | 38 | required_files = [ 39 | self.dataset_dir, 40 | self.train_dir, 41 | self.query_dir, 42 | self.gallery_dir 43 | ] 44 | self.check_before_run(required_files) 45 | 46 | train = self.process_dir(self.train_dir, relabel=True) 47 | query = self.process_dir(self.query_dir, relabel=False) 48 | gallery = self.process_dir(self.gallery_dir, relabel=False) 49 | 50 | super(MSMT17, self).__init__(train, query, gallery, **kwargs) 51 | 52 | def process_dir(self, dir_path, relabel=False): 53 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 54 | pattern = re.compile(r'([-\d]+)_c(\d)') 55 | 56 | pid_container = set() 57 | for img_path in img_paths: 58 | pid, _ = map(int, pattern.search(img_path).groups()) 59 | pid_container.add(pid) 60 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 61 | 62 | data = [] 63 | for img_path in img_paths: 64 | pid, camid = map(int, pattern.search(img_path).groups()) 65 | assert 1 <= camid <= 15 66 | camid -= 1 # index starts from 0 67 | if relabel: pid = pid2label[pid] 68 | data.append((img_path, pid, camid)) 69 | 70 | return data 71 | -------------------------------------------------------------------------------- /data_v2/datasets/image/prid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import random 9 | 10 | from .. import ImageDataset 11 | from ..utils import read_json, write_json 12 | 13 | 14 | class PRID(ImageDataset): 15 | """PRID (single-shot version of prid-2011) 16 | 17 | Reference: 18 | Hirzer et al. Person Re-Identification by Descriptive and Discriminative 19 | Classification. SCIA 2011. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - Two views. 25 | - View A captures 385 identities. 26 | - View B captures 749 identities. 27 | - 200 identities appear in both views. 
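Note: only the 200 identities visible in both views are used to build the
ten random train/test splits (100 train + 100 test); the remaining
view-B-only identities (person_0201 to person_0749) are appended to the
gallery as distractors in ``process_split``.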
28 | """ 29 | dataset_dir = 'prid2011' 30 | dataset_url = None 31 | 32 | def __init__(self, root='', split_id=0, **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.cam_a_dir = osp.join(self.dataset_dir, 'prid_2011', 'single_shot', 'cam_a') 38 | self.cam_b_dir = osp.join(self.dataset_dir, 'prid_2011', 'single_shot', 'cam_b') 39 | self.split_path = osp.join(self.dataset_dir, 'splits_single_shot.json') 40 | 41 | required_files = [ 42 | self.dataset_dir, 43 | self.cam_a_dir, 44 | self.cam_b_dir 45 | ] 46 | self.check_before_run(required_files) 47 | 48 | self.prepare_split() 49 | splits = read_json(self.split_path) 50 | if split_id >= len(splits): 51 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 52 | split = splits[split_id] 53 | 54 | train, query, gallery = self.process_split(split) 55 | 56 | super(PRID, self).__init__(train, query, gallery, **kwargs) 57 | 58 | def prepare_split(self): 59 | if not osp.exists(self.split_path): 60 | print('Creating splits ...') 61 | 62 | splits = [] 63 | for _ in range(10): 64 | # randomly sample 100 IDs for train and use the rest 100 IDs for test 65 | # (note: there are only 200 IDs appearing in both views) 66 | pids = [i for i in range(1, 201)] 67 | train_pids = random.sample(pids, 100) 68 | train_pids.sort() 69 | test_pids = [i for i in pids if i not in train_pids] 70 | split = {'train': train_pids, 'test': test_pids} 71 | splits.append(split) 72 | 73 | print('Totally {} splits are created'.format(len(splits))) 74 | write_json(splits, self.split_path) 75 | print('Split file is saved to {}'.format(self.split_path)) 76 | 77 | def process_split(self, split): 78 | train, query, gallery = [], [], [] 79 | train_pids = split['train'] 80 | test_pids = split['test'] 81 | 82 | train_pid2label = {pid: label for label, pid in enumerate(train_pids)} 83 | 84 | # train 85 | train = [] 86 | for pid in train_pids: 87 | img_name = 'person_' + str(pid).zfill(4) + '.png' 88 | pid = train_pid2label[pid] 89 | img_a_path = osp.join(self.cam_a_dir, img_name) 90 | train.append((img_a_path, pid, 0)) 91 | img_b_path = osp.join(self.cam_b_dir, img_name) 92 | train.append((img_b_path, pid, 1)) 93 | 94 | # query and gallery 95 | query, gallery = [], [] 96 | for pid in test_pids: 97 | img_name = 'person_' + str(pid).zfill(4) + '.png' 98 | img_a_path = osp.join(self.cam_a_dir, img_name) 99 | query.append((img_a_path, pid, 0)) 100 | img_b_path = osp.join(self.cam_b_dir, img_name) 101 | gallery.append((img_b_path, pid, 1)) 102 | for pid in range(201, 750): 103 | img_name = 'person_' + str(pid).zfill(4) + '.png' 104 | img_b_path = osp.join(self.cam_b_dir, img_name) 105 | gallery.append((img_b_path, pid, 1)) 106 | 107 | return train, query, gallery -------------------------------------------------------------------------------- /data_v2/datasets/image/sensereid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import copy 10 | 11 | from .. import ImageDataset 12 | 13 | 14 | class SenseReID(ImageDataset): 15 | """SenseReID. 16 | 17 | This dataset is used for test purpose only. 18 | 19 | Reference: 20 | Zhao et al. 
Spindle Net: Person Re-identification with Human Body 21 | Region Guided Feature Decomposition and Fusion. CVPR 2017. 22 | 23 | URL: ``_ 24 | 25 | Dataset statistics: 26 | - query: 522 ids, 1040 images. 27 | - gallery: 1717 ids, 3388 images. 28 | """ 29 | dataset_dir = 'sensereid' 30 | dataset_url = None 31 | 32 | def __init__(self, root='', **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.query_dir = osp.join(self.dataset_dir, 'SenseReID', 'test_probe') 38 | self.gallery_dir = osp.join(self.dataset_dir, 'SenseReID', 'test_gallery') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.query_dir, 43 | self.gallery_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | query = self.process_dir(self.query_dir) 48 | gallery = self.process_dir(self.gallery_dir) 49 | 50 | # relabel 51 | g_pids = set() 52 | for _, pid, _ in gallery: 53 | g_pids.add(pid) 54 | pid2label = {pid: i for i, pid in enumerate(g_pids)} 55 | 56 | query = [(img_path, pid2label[pid], camid) for img_path, pid, camid in query] 57 | gallery = [(img_path, pid2label[pid], camid) for img_path, pid, camid in gallery] 58 | train = copy.deepcopy(query) + copy.deepcopy(gallery) # dummy variable 59 | 60 | super(SenseReID, self).__init__(train, query, gallery, **kwargs) 61 | 62 | def process_dir(self, dir_path): 63 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 64 | data = [] 65 | 66 | for img_path in img_paths: 67 | img_name = osp.splitext(osp.basename(img_path))[0] 68 | pid, camid = img_name.split('_') 69 | pid, camid = int(pid), int(camid) 70 | data.append((img_path, pid, camid)) 71 | 72 | return data -------------------------------------------------------------------------------- /data_v2/datasets/image/viper.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import numpy as np 10 | 11 | from .. import ImageDataset 12 | from ..utils import read_json, write_json 13 | 14 | 15 | class VIPeR(ImageDataset): 16 | """VIPeR. 17 | 18 | Reference: 19 | Gray et al. Evaluating appearance models for recognition, reacquisition, and tracking. PETS 2007. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 632. 25 | - images: 632 x 2 = 1264. 26 | - cameras: 2. 
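Example (illustrative; assumes the archive has already been downloaded or
extracted under ``root/viper``):

    >>> dataset = VIPeR(root='/path/to/datasets', split_id=0)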
27 | """ 28 | dataset_dir = 'viper' 29 | dataset_url = 'http://users.soe.ucsc.edu/~manduchi/VIPeR.v1.0.zip' 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.cam_a_dir = osp.join(self.dataset_dir, 'VIPeR', 'cam_a') 37 | self.cam_b_dir = osp.join(self.dataset_dir, 'VIPeR', 'cam_b') 38 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.cam_a_dir, 43 | self.cam_b_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | self.prepare_split() 48 | splits = read_json(self.split_path) 49 | if split_id >= len(splits): 50 | raise ValueError('split_id exceeds range, received {}, ' 51 | 'but expected between 0 and {}'.format(split_id, len(splits)-1)) 52 | split = splits[split_id] 53 | 54 | train = split['train'] 55 | query = split['query'] # query and gallery share the same images 56 | gallery = split['gallery'] 57 | 58 | train = [tuple(item) for item in train] 59 | query = [tuple(item) for item in query] 60 | gallery = [tuple(item) for item in gallery] 61 | 62 | super(VIPeR, self).__init__(train, query, gallery, **kwargs) 63 | 64 | def prepare_split(self): 65 | if not osp.exists(self.split_path): 66 | print('Creating 10 random splits of train ids and test ids') 67 | 68 | cam_a_imgs = sorted(glob.glob(osp.join(self.cam_a_dir, '*.bmp'))) 69 | cam_b_imgs = sorted(glob.glob(osp.join(self.cam_b_dir, '*.bmp'))) 70 | assert len(cam_a_imgs) == len(cam_b_imgs) 71 | num_pids = len(cam_a_imgs) 72 | print('Number of identities: {}'.format(num_pids)) 73 | num_train_pids = num_pids // 2 74 | 75 | """ 76 | In total, there will be 20 splits because each random split creates two 77 | sub-splits, one using cameraA as query and cameraB as gallery 78 | while the other using cameraB as query and cameraA as gallery. 79 | Therefore, results should be averaged over 20 splits (split_id=0~19). 80 | 81 | In practice, a model trained on split_id=0 can be applied to split_id=0&1 82 | as split_id=0&1 share the same training data (so on and so forth). 
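With the 632 identities of VIPeR this gives num_train_pids = 316, and each
of the two sub-splits built from one permutation contains 316 query and
316 gallery images (one image per test identity and camera).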
83 | """ 84 | splits = [] 85 | for _ in range(10): 86 | order = np.arange(num_pids) 87 | np.random.shuffle(order) 88 | train_idxs = order[:num_train_pids] 89 | test_idxs = order[num_train_pids:] 90 | assert not bool(set(train_idxs) & set(test_idxs)), 'Error: train and test overlap' 91 | 92 | train = [] 93 | for pid, idx in enumerate(train_idxs): 94 | cam_a_img = cam_a_imgs[idx] 95 | cam_b_img = cam_b_imgs[idx] 96 | train.append((cam_a_img, pid, 0)) 97 | train.append((cam_b_img, pid, 1)) 98 | 99 | test_a = [] 100 | test_b = [] 101 | for pid, idx in enumerate(test_idxs): 102 | cam_a_img = cam_a_imgs[idx] 103 | cam_b_img = cam_b_imgs[idx] 104 | test_a.append((cam_a_img, pid, 0)) 105 | test_b.append((cam_b_img, pid, 1)) 106 | 107 | # use cameraA as query and cameraB as gallery 108 | split = { 109 | 'train': train, 110 | 'query': test_a, 111 | 'gallery': test_b, 112 | 'num_train_pids': num_train_pids, 113 | 'num_query_pids': num_pids - num_train_pids, 114 | 'num_gallery_pids': num_pids - num_train_pids 115 | } 116 | splits.append(split) 117 | 118 | # use cameraB as query and cameraA as gallery 119 | split = { 120 | 'train': train, 121 | 'query': test_b, 122 | 'gallery': test_a, 123 | 'num_train_pids': num_train_pids, 124 | 'num_query_pids': num_pids - num_train_pids, 125 | 'num_gallery_pids': num_pids - num_train_pids 126 | } 127 | splits.append(split) 128 | 129 | print('Totally {} splits are created'.format(len(splits))) 130 | write_json(splits, self.split_path) 131 | print('Split file saved to {}'.format(self.split_path)) -------------------------------------------------------------------------------- /data_v2/datasets/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | __all__ = ['mkdir_if_missing', 'check_isfile', 'read_json', 'write_json', 6 | 'set_random_seed', 'download_url', 'read_image', 'collect_env_info'] 7 | 8 | import sys 9 | import os 10 | import os.path as osp 11 | import time 12 | import errno 13 | import json 14 | from collections import OrderedDict 15 | import warnings 16 | import random 17 | import numpy as np 18 | import PIL 19 | from PIL import Image 20 | 21 | import torch 22 | 23 | 24 | def mkdir_if_missing(dirname): 25 | """Creates dirname if it is missing.""" 26 | if not osp.exists(dirname): 27 | try: 28 | os.makedirs(dirname) 29 | except OSError as e: 30 | if e.errno != errno.EEXIST: 31 | raise 32 | 33 | 34 | def check_isfile(fpath): 35 | """Checks if the given path is a file. 36 | 37 | Args: 38 | fpath (str): file path. 39 | 40 | Returns: 41 | bool 42 | """ 43 | isfile = osp.isfile(fpath) 44 | if not isfile: 45 | warnings.warn('No file found at "{}"'.format(fpath)) 46 | return isfile 47 | 48 | 49 | def read_json(fpath): 50 | """Reads json file from a path.""" 51 | with open(fpath, 'r') as f: 52 | obj = json.load(f) 53 | return obj 54 | 55 | 56 | def write_json(obj, fpath): 57 | """Writes to a json file.""" 58 | mkdir_if_missing(osp.dirname(fpath)) 59 | with open(fpath, 'w') as f: 60 | json.dump(obj, f, indent=4, separators=(',', ': ')) 61 | 62 | 63 | def set_random_seed(seed): 64 | random.seed(seed) 65 | np.random.seed(seed) 66 | torch.manual_seed(seed) 67 | torch.cuda.manual_seed_all(seed) 68 | 69 | 70 | def download_url(url, dst): 71 | """Downloads file from a url to a destination. 72 | 73 | Args: 74 | url (str): url to download file. 75 | dst (str): destination path. 
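Example (illustrative; the URL and destination are placeholders):
    >>> download_url('http://example.com/archive.zip', '/tmp/archive.zip')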
76 | """ 77 | from six.moves import urllib 78 | print('* url="{}"'.format(url)) 79 | print('* destination="{}"'.format(dst)) 80 | 81 | def _reporthook(count, block_size, total_size): 82 | global start_time 83 | if count == 0: 84 | start_time = time.time() 85 | return 86 | duration = time.time() - start_time 87 | progress_size = int(count * block_size) 88 | speed = int(progress_size / (1024 * duration)) 89 | percent = int(count * block_size * 100 / total_size) 90 | sys.stdout.write('\r...%d%%, %d MB, %d KB/s, %d seconds passed' % 91 | (percent, progress_size / (1024 * 1024), speed, duration)) 92 | sys.stdout.flush() 93 | 94 | urllib.request.urlretrieve(url, dst, _reporthook) 95 | sys.stdout.write('\n') 96 | 97 | 98 | def read_image(path): 99 | """Reads image from path using ``PIL.Image``. 100 | 101 | Args: 102 | path (str): path to an image. 103 | 104 | Returns: 105 | PIL image 106 | """ 107 | got_img = False 108 | if not osp.exists(path): 109 | raise IOError('"{}" does not exist'.format(path)) 110 | while not got_img: 111 | try: 112 | img = Image.open(path).convert('RGB') 113 | got_img = True 114 | except IOError: 115 | print('IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'.format(img_path)) 116 | pass 117 | return img 118 | 119 | 120 | def collect_env_info(): 121 | """Returns env info as a string. 122 | 123 | Code source: github.com/facebookresearch/maskrcnn-benchmark 124 | """ 125 | from torch.utils.collect_env import get_pretty_env_info 126 | env_str = get_pretty_env_info() 127 | env_str += '\n Pillow ({})'.format(PIL.__version__) 128 | return env_str 129 | -------------------------------------------------------------------------------- /data_v2/datasets/video/Icon : -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/data_v2/datasets/video/Icon -------------------------------------------------------------------------------- /data_v2/datasets/video/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .mars import Mars 5 | from .ilidsvid import iLIDSVID 6 | from .prid2011 import PRID2011 7 | from .dukemtmcvidreid import DukeMTMCVidReID -------------------------------------------------------------------------------- /data_v2/datasets/video/dukemtmcvidreid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import warnings 10 | 11 | from .. import VideoDataset 12 | from ..utils import read_json, write_json 13 | 14 | 15 | class DukeMTMCVidReID(VideoDataset): 16 | """DukeMTMCVidReID. 17 | 18 | Reference: 19 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, 20 | Multi-Camera Tracking. ECCVW 2016. 21 | - Wu et al. Exploit the Unknown Gradually: One-Shot Video-Based Person 22 | Re-Identification by Stepwise Learning. CVPR 2018. 23 | 24 | URL: ``_ 25 | 26 | Dataset statistics: 27 | - identities: 702 (train) + 702 (test). 28 | - tracklets: 2196 (train) + 2636 (test). 
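Note: tracklets with fewer than ``min_seq_len`` images are skipped in
``process_dir``; with the default ``min_seq_len=0`` nothing is filtered out.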
29 | """ 30 | dataset_dir = 'dukemtmc-vidreid' 31 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-VideoReID.zip' 32 | 33 | def __init__(self, root='', min_seq_len=0, **kwargs): 34 | self.root = osp.abspath(osp.expanduser(root)) 35 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 36 | self.download_dataset(self.dataset_dir, self.dataset_url) 37 | 38 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/train') 39 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/query') 40 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/gallery') 41 | self.split_train_json_path = osp.join(self.dataset_dir, 'split_train.json') 42 | self.split_query_json_path = osp.join(self.dataset_dir, 'split_query.json') 43 | self.split_gallery_json_path = osp.join(self.dataset_dir, 'split_gallery.json') 44 | self.min_seq_len = min_seq_len 45 | 46 | required_files = [ 47 | self.dataset_dir, 48 | self.train_dir, 49 | self.query_dir, 50 | self.gallery_dir 51 | ] 52 | self.check_before_run(required_files) 53 | 54 | train = self.process_dir(self.train_dir, self.split_train_json_path, relabel=True) 55 | query = self.process_dir(self.query_dir, self.split_query_json_path, relabel=False) 56 | gallery = self.process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False) 57 | 58 | super(DukeMTMCVidReID, self).__init__(train, query, gallery, **kwargs) 59 | 60 | def process_dir(self, dir_path, json_path, relabel): 61 | if osp.exists(json_path): 62 | split = read_json(json_path) 63 | return split['tracklets'] 64 | 65 | print('=> Generating split json file (** this might take a while **)') 66 | pdirs = glob.glob(osp.join(dir_path, '*')) # avoid .DS_Store 67 | print('Processing "{}" with {} person identities'.format(dir_path, len(pdirs))) 68 | 69 | pid_container = set() 70 | for pdir in pdirs: 71 | pid = int(osp.basename(pdir)) 72 | pid_container.add(pid) 73 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 74 | 75 | tracklets = [] 76 | for pdir in pdirs: 77 | pid = int(osp.basename(pdir)) 78 | if relabel: 79 | pid = pid2label[pid] 80 | tdirs = glob.glob(osp.join(pdir, '*')) 81 | for tdir in tdirs: 82 | raw_img_paths = glob.glob(osp.join(tdir, '*.jpg')) 83 | num_imgs = len(raw_img_paths) 84 | 85 | if num_imgs < self.min_seq_len: 86 | continue 87 | 88 | img_paths = [] 89 | for img_idx in range(num_imgs): 90 | # some tracklet starts from 0002 instead of 0001 91 | img_idx_name = 'F' + str(img_idx+1).zfill(4) 92 | res = glob.glob(osp.join(tdir, '*' + img_idx_name + '*.jpg')) 93 | if len(res) == 0: 94 | warnings.warn('Index name {} in {} is missing, skip'.format(img_idx_name, tdir)) 95 | continue 96 | img_paths.append(res[0]) 97 | img_name = osp.basename(img_paths[0]) 98 | if img_name.find('_') == -1: 99 | # old naming format: 0001C6F0099X30823.jpg 100 | camid = int(img_name[5]) - 1 101 | else: 102 | # new naming format: 0001_C6_F0099_X30823.jpg 103 | camid = int(img_name[6]) - 1 104 | img_paths = tuple(img_paths) 105 | tracklets.append((img_paths, pid, camid)) 106 | 107 | print('Saving split to {}'.format(json_path)) 108 | split_dict = {'tracklets': tracklets} 109 | write_json(split_dict, json_path) 110 | 111 | return tracklets -------------------------------------------------------------------------------- /data_v2/datasets/video/ilidsvid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ 
import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | from scipy.io import loadmat 10 | 11 | from .. import VideoDataset 12 | from ..utils import read_json, write_json 13 | 14 | 15 | class iLIDSVID(VideoDataset): 16 | """iLIDS-VID. 17 | 18 | Reference: 19 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 300. 25 | - tracklets: 600. 26 | - cameras: 2. 27 | """ 28 | dataset_dir = 'ilids-vid' 29 | dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar' 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.data_dir = osp.join(self.dataset_dir, 'i-LIDS-VID') 37 | self.split_dir = osp.join(self.dataset_dir, 'train-test people splits') 38 | self.split_mat_path = osp.join(self.split_dir, 'train_test_splits_ilidsvid.mat') 39 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 40 | self.cam_1_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam1') 41 | self.cam_2_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam2') 42 | 43 | required_files = [ 44 | self.dataset_dir, 45 | self.data_dir, 46 | self.split_dir 47 | ] 48 | self.check_before_run(required_files) 49 | 50 | self.prepare_split() 51 | splits = read_json(self.split_path) 52 | if split_id >= len(splits): 53 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 54 | split = splits[split_id] 55 | train_dirs, test_dirs = split['train'], split['test'] 56 | 57 | train = self.process_data(train_dirs, cam1=True, cam2=True) 58 | query = self.process_data(test_dirs, cam1=True, cam2=False) 59 | gallery = self.process_data(test_dirs, cam1=False, cam2=True) 60 | 61 | super(iLIDSVID, self).__init__(train, query, gallery, **kwargs) 62 | 63 | def prepare_split(self): 64 | if not osp.exists(self.split_path): 65 | print('Creating splits ...') 66 | mat_split_data = loadmat(self.split_mat_path)['ls_set'] 67 | 68 | num_splits = mat_split_data.shape[0] 69 | num_total_ids = mat_split_data.shape[1] 70 | assert num_splits == 10 71 | assert num_total_ids == 300 72 | num_ids_each = num_total_ids // 2 73 | 74 | # pids in mat_split_data are indices, so we need to transform them 75 | # to real pids 76 | person_cam1_dirs = sorted(glob.glob(osp.join(self.cam_1_path, '*'))) 77 | person_cam2_dirs = sorted(glob.glob(osp.join(self.cam_2_path, '*'))) 78 | 79 | person_cam1_dirs = [osp.basename(item) for item in person_cam1_dirs] 80 | person_cam2_dirs = [osp.basename(item) for item in person_cam2_dirs] 81 | 82 | # make sure persons in one camera view can be found in the other camera view 83 | assert set(person_cam1_dirs) == set(person_cam2_dirs) 84 | 85 | splits = [] 86 | for i_split in range(num_splits): 87 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14. 
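# e.g. with the 300 identities of iLIDS-VID, num_ids_each = 150, so each of
# the 10 splits trains on 150 person directories and tests on the other 150.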
88 | train_idxs = sorted(list(mat_split_data[i_split, num_ids_each:])) 89 | test_idxs = sorted(list(mat_split_data[i_split, :num_ids_each])) 90 | 91 | train_idxs = [int(i)-1 for i in train_idxs] 92 | test_idxs = [int(i)-1 for i in test_idxs] 93 | 94 | # transform pids to person dir names 95 | train_dirs = [person_cam1_dirs[i] for i in train_idxs] 96 | test_dirs = [person_cam1_dirs[i] for i in test_idxs] 97 | 98 | split = {'train': train_dirs, 'test': test_dirs} 99 | splits.append(split) 100 | 101 | print('Totally {} splits are created, following Wang et al. ECCV\'14'.format(len(splits))) 102 | print('Split file is saved to {}'.format(self.split_path)) 103 | write_json(splits, self.split_path) 104 | 105 | def process_data(self, dirnames, cam1=True, cam2=True): 106 | tracklets = [] 107 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 108 | 109 | for dirname in dirnames: 110 | if cam1: 111 | person_dir = osp.join(self.cam_1_path, dirname) 112 | img_names = glob.glob(osp.join(person_dir, '*.png')) 113 | assert len(img_names) > 0 114 | img_names = tuple(img_names) 115 | pid = dirname2pid[dirname] 116 | tracklets.append((img_names, pid, 0)) 117 | 118 | if cam2: 119 | person_dir = osp.join(self.cam_2_path, dirname) 120 | img_names = glob.glob(osp.join(person_dir, '*.png')) 121 | assert len(img_names) > 0 122 | img_names = tuple(img_names) 123 | pid = dirname2pid[dirname] 124 | tracklets.append((img_names, pid, 1)) 125 | 126 | return tracklets -------------------------------------------------------------------------------- /data_v2/datasets/video/mars.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | from scipy.io import loadmat 9 | import warnings 10 | 11 | from .. import VideoDataset 12 | 13 | 14 | class Mars(VideoDataset): 15 | """MARS. 16 | 17 | Reference: 18 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016. 19 | 20 | URL: ``_ 21 | 22 | Dataset statistics: 23 | - identities: 1261. 24 | - tracklets: 8298 (train) + 1980 (query) + 9330 (gallery). 25 | - cameras: 6. 
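Note: ``dataset_url`` is ``None``, so MARS has to be downloaded manually;
the loader expects ``bbox_train``, ``bbox_test`` and an ``info`` folder
(with ``train_name.txt``, ``test_name.txt``, ``tracks_train_info.mat``,
``tracks_test_info.mat`` and ``query_IDX.mat``) under ``root/mars``.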
26 | """ 27 | dataset_dir = 'mars' 28 | dataset_url = None 29 | 30 | def __init__(self, root='', **kwargs): 31 | self.root = osp.abspath(osp.expanduser(root)) 32 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 33 | self.download_dataset(self.dataset_dir, self.dataset_url) 34 | 35 | self.train_name_path = osp.join(self.dataset_dir, 'info/train_name.txt') 36 | self.test_name_path = osp.join(self.dataset_dir, 'info/test_name.txt') 37 | self.track_train_info_path = osp.join(self.dataset_dir, 'info/tracks_train_info.mat') 38 | self.track_test_info_path = osp.join(self.dataset_dir, 'info/tracks_test_info.mat') 39 | self.query_IDX_path = osp.join(self.dataset_dir, 'info/query_IDX.mat') 40 | 41 | required_files = [ 42 | self.dataset_dir, 43 | self.train_name_path, 44 | self.test_name_path, 45 | self.track_train_info_path, 46 | self.track_test_info_path, 47 | self.query_IDX_path 48 | ] 49 | self.check_before_run(required_files) 50 | 51 | train_names = self.get_names(self.train_name_path) 52 | test_names = self.get_names(self.test_name_path) 53 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) 54 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) 55 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) 56 | query_IDX -= 1 # index from 0 57 | track_query = track_test[query_IDX,:] 58 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 59 | track_gallery = track_test[gallery_IDX,:] 60 | 61 | train = self.process_data(train_names, track_train, home_dir='bbox_train', relabel=True) 62 | query = self.process_data(test_names, track_query, home_dir='bbox_test', relabel=False) 63 | gallery = self.process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False) 64 | 65 | super(Mars, self).__init__(train, query, gallery, **kwargs) 66 | 67 | def get_names(self, fpath): 68 | names = [] 69 | with open(fpath, 'r') as f: 70 | for line in f: 71 | new_line = line.rstrip() 72 | names.append(new_line) 73 | return names 74 | 75 | def process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 76 | assert home_dir in ['bbox_train', 'bbox_test'] 77 | num_tracklets = meta_data.shape[0] 78 | pid_list = list(set(meta_data[:,2].tolist())) 79 | num_pids = len(pid_list) 80 | 81 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 82 | tracklets = [] 83 | 84 | for tracklet_idx in range(num_tracklets): 85 | data = meta_data[tracklet_idx,...] 86 | start_index, end_index, pid, camid = data 87 | if pid == -1: 88 | continue # junk images are just ignored 89 | assert 1 <= camid <= 6 90 | if relabel: pid = pid2label[pid] 91 | camid -= 1 # index starts from 0 92 | img_names = names[start_index - 1:end_index] 93 | 94 | # make sure image names correspond to the same person 95 | pnames = [img_name[:4] for img_name in img_names] 96 | assert len(set(pnames)) == 1, 'Error: a single tracklet contains different person images' 97 | 98 | # make sure all images are captured under the same camera 99 | camnames = [img_name[5] for img_name in img_names] 100 | assert len(set(camnames)) == 1, 'Error: images are captured under different cameras!' 
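# The two checks above rely on the MARS frame-name layout in which the
# first four characters encode the person id and the sixth character the
# camera, e.g. something like '0001C1T0001F001.jpg' (illustrative name,
# not taken from the dataset itself).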
101 | 102 | # append image names with directory information 103 | img_paths = [osp.join(self.dataset_dir, home_dir, img_name[:4], img_name) for img_name in img_names] 104 | if len(img_paths) >= min_seq_len: 105 | img_paths = tuple(img_paths) 106 | tracklets.append((img_paths, pid, camid)) 107 | 108 | return tracklets 109 | 110 | def combine_all(self): 111 | warnings.warn('Some query IDs do not appear in gallery. Therefore, combineall ' 112 | 'does not make any difference to Mars') -------------------------------------------------------------------------------- /data_v2/datasets/video/prid2011.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | 10 | from .. import VideoDataset 11 | from ..utils import read_json, write_json 12 | 13 | 14 | class PRID2011(VideoDataset): 15 | """PRID2011. 16 | 17 | Reference: 18 | Hirzer et al. Person Re-Identification by Descriptive and 19 | Discriminative Classification. SCIA 2011. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 200. 25 | - tracklets: 400. 26 | - cameras: 2. 27 | """ 28 | dataset_dir = 'prid2011' 29 | dataset_url = None 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.split_path = osp.join(self.dataset_dir, 'splits_prid2011.json') 37 | self.cam_a_dir = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_a') 38 | self.cam_b_dir = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_b') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.cam_a_dir, 43 | self.cam_b_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | splits = read_json(self.split_path) 48 | if split_id >= len(splits): 49 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 50 | split = splits[split_id] 51 | train_dirs, test_dirs = split['train'], split['test'] 52 | 53 | train = self.process_dir(train_dirs, cam1=True, cam2=True) 54 | query = self.process_dir(test_dirs, cam1=True, cam2=False) 55 | gallery = self.process_dir(test_dirs, cam1=False, cam2=True) 56 | 57 | super(PRID2011, self).__init__(train, query, gallery, **kwargs) 58 | 59 | def process_dir(self, dirnames, cam1=True, cam2=True): 60 | tracklets = [] 61 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 62 | 63 | for dirname in dirnames: 64 | if cam1: 65 | person_dir = osp.join(self.cam_a_dir, dirname) 66 | img_names = glob.glob(osp.join(person_dir, '*.png')) 67 | assert len(img_names) > 0 68 | img_names = tuple(img_names) 69 | pid = dirname2pid[dirname] 70 | tracklets.append((img_names, pid, 0)) 71 | 72 | if cam2: 73 | person_dir = osp.join(self.cam_b_dir, dirname) 74 | img_names = glob.glob(osp.join(person_dir, '*.png')) 75 | assert len(img_names) > 0 76 | img_names = tuple(img_names) 77 | pid = dirname2pid[dirname] 78 | tracklets.append((img_names, pid, 1)) 79 | 80 | return tracklets -------------------------------------------------------------------------------- /data_v2/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | from 
collections import defaultdict 5 | import numpy as np 6 | import copy 7 | import random 8 | 9 | import torch 10 | from torch.utils.data.sampler import Sampler, RandomSampler 11 | 12 | 13 | class RandomIdentitySampler(Sampler): 14 | """Randomly samples N identities each with K instances. 15 | 16 | Args: 17 | data_source (list): contains tuples of (img_path(s), pid, camid). 18 | batch_size (int): batch size. 19 | num_instances (int): number of instances per identity in a batch. 20 | """ 21 | def __init__(self, data_source, batch_size, num_instances): 22 | if batch_size < num_instances: 23 | raise ValueError('batch_size={} must be no less ' 24 | 'than num_instances={}'.format(batch_size, num_instances)) 25 | 26 | self.data_source = data_source 27 | self.batch_size = batch_size 28 | self.num_instances = num_instances 29 | self.num_pids_per_batch = self.batch_size // self.num_instances 30 | self.index_dic = defaultdict(list) 31 | for index, (_, pid, _) in enumerate(self.data_source): 32 | self.index_dic[pid].append(index) 33 | self.pids = list(self.index_dic.keys()) 34 | 35 | # estimate number of examples in an epoch 36 | # TODO: improve precision 37 | self.length = 0 38 | for pid in self.pids: 39 | idxs = self.index_dic[pid] 40 | num = len(idxs) 41 | if num < self.num_instances: 42 | num = self.num_instances 43 | self.length += num - num % self.num_instances 44 | 45 | def __iter__(self): 46 | batch_idxs_dict = defaultdict(list) 47 | 48 | for pid in self.pids: 49 | idxs = copy.deepcopy(self.index_dic[pid]) 50 | if len(idxs) < self.num_instances: 51 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 52 | random.shuffle(idxs) 53 | batch_idxs = [] 54 | for idx in idxs: 55 | batch_idxs.append(idx) 56 | if len(batch_idxs) == self.num_instances: 57 | batch_idxs_dict[pid].append(batch_idxs) 58 | batch_idxs = [] 59 | 60 | avai_pids = copy.deepcopy(self.pids) 61 | final_idxs = [] 62 | 63 | while len(avai_pids) >= self.num_pids_per_batch: 64 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 65 | for pid in selected_pids: 66 | batch_idxs = batch_idxs_dict[pid].pop(0) 67 | final_idxs.extend(batch_idxs) 68 | if len(batch_idxs_dict[pid]) == 0: 69 | avai_pids.remove(pid) 70 | 71 | self.length = len(final_idxs) 72 | return iter(final_idxs) 73 | 74 | def __len__(self): 75 | return self.length 76 | 77 | 78 | def build_train_sampler(data_source, train_sampler, batch_size=32, num_instances=4, **kwargs): 79 | """Builds a training sampler. 80 | 81 | Args: 82 | data_source (list): contains tuples of (img_path(s), pid, camid). 83 | train_sampler (str): sampler name (default: ``RandomSampler``). 84 | batch_size (int, optional): batch size. Default is 32. 85 | num_instances (int, optional): number of instances per identity in a 86 | batch (for ``RandomIdentitySampler``). Default is 4. 
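Example (illustrative; ``train_set`` is assumed to be a list of
``(img_path, pid, camid)`` tuples):
    >>> sampler = build_train_sampler(train_set, 'RandomIdentitySampler',
    ...                               batch_size=48, num_instances=8)
    >>> # each batch then holds 48 // 8 = 6 identities with 8 images each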
87 | """ 88 | if train_sampler == 'RandomIdentitySampler': 89 | sampler = RandomIdentitySampler(data_source, batch_size, num_instances) 90 | 91 | else: 92 | sampler = RandomSampler(data_source) 93 | 94 | return sampler 95 | -------------------------------------------------------------------------------- /data_v2/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | __all__ = ['mkdir_if_missing', 'check_isfile', 'read_json', 'write_json', 6 | 'set_random_seed', 'download_url', 'read_image', 'collect_env_info'] 7 | 8 | import sys 9 | import os 10 | import os.path as osp 11 | import time 12 | import errno 13 | import json 14 | from collections import OrderedDict 15 | import warnings 16 | import random 17 | import numpy as np 18 | import PIL 19 | from PIL import Image 20 | 21 | import torch 22 | 23 | 24 | def mkdir_if_missing(dirname): 25 | """Creates dirname if it is missing.""" 26 | if not osp.exists(dirname): 27 | try: 28 | os.makedirs(dirname) 29 | except OSError as e: 30 | if e.errno != errno.EEXIST: 31 | raise 32 | 33 | 34 | def check_isfile(fpath): 35 | """Checks if the given path is a file. 36 | 37 | Args: 38 | fpath (str): file path. 39 | 40 | Returns: 41 | bool 42 | """ 43 | isfile = osp.isfile(fpath) 44 | if not isfile: 45 | warnings.warn('No file found at "{}"'.format(fpath)) 46 | return isfile 47 | 48 | 49 | def read_json(fpath): 50 | """Reads json file from a path.""" 51 | with open(fpath, 'r') as f: 52 | obj = json.load(f) 53 | return obj 54 | 55 | 56 | def write_json(obj, fpath): 57 | """Writes to a json file.""" 58 | mkdir_if_missing(osp.dirname(fpath)) 59 | with open(fpath, 'w') as f: 60 | json.dump(obj, f, indent=4, separators=(',', ': ')) 61 | 62 | 63 | def set_random_seed(seed): 64 | random.seed(seed) 65 | np.random.seed(seed) 66 | torch.manual_seed(seed) 67 | torch.cuda.manual_seed_all(seed) 68 | 69 | 70 | def download_url(url, dst): 71 | """Downloads file from a url to a destination. 72 | 73 | Args: 74 | url (str): url to download file. 75 | dst (str): destination path. 76 | """ 77 | from six.moves import urllib 78 | print('* url="{}"'.format(url)) 79 | print('* destination="{}"'.format(dst)) 80 | 81 | def _reporthook(count, block_size, total_size): 82 | global start_time 83 | if count == 0: 84 | start_time = time.time() 85 | return 86 | duration = time.time() - start_time 87 | progress_size = int(count * block_size) 88 | speed = int(progress_size / (1024 * duration)) 89 | percent = int(count * block_size * 100 / total_size) 90 | sys.stdout.write('\r...%d%%, %d MB, %d KB/s, %d seconds passed' % 91 | (percent, progress_size / (1024 * 1024), speed, duration)) 92 | sys.stdout.flush() 93 | 94 | urllib.request.urlretrieve(url, dst, _reporthook) 95 | sys.stdout.write('\n') 96 | 97 | 98 | def read_image(path): 99 | """Reads image from path using ``PIL.Image``. 100 | 101 | Args: 102 | path (str): path to an image. 103 | 104 | Returns: 105 | PIL image 106 | """ 107 | got_img = False 108 | if not osp.exists(path): 109 | raise IOError('"{}" does not exist'.format(path)) 110 | while not got_img: 111 | try: 112 | img = Image.open(path).convert('RGB') 113 | got_img = True 114 | except IOError: 115 | print('IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'.format(img_path)) 116 | pass 117 | return img 118 | 119 | 120 | def collect_env_info(): 121 | """Returns env info as a string. 
122 | 123 | Code source: github.com/facebookresearch/maskrcnn-benchmark 124 | """ 125 | from torch.utils.collect_env import get_pretty_env_info 126 | env_str = get_pretty_env_info() 127 | env_str += '\n Pillow ({})'.format(PIL.__version__) 128 | return env_str 129 | -------------------------------------------------------------------------------- /lmbn_config.yaml: -------------------------------------------------------------------------------- 1 | T: 3 2 | act: relu 3 | amsgrad: false 4 | batchid: 6 5 | batchimage: 8 6 | batchtest: 32 7 | beta1: 0.9 8 | beta2: 0.999 9 | bnneck: false 10 | config: "" 11 | cosine_annealing: false 12 | cpu: false 13 | cuhk03_labeled: false 14 | cutout: false 15 | dampening: 0 16 | data_test: market1501 17 | data_train: market1501 18 | datadir: /media/CityFlow/reid/reid/ 19 | decay_type: step_50_80_110 20 | drop_block: false 21 | epochs: 2 22 | epsilon: 1.0e-08 23 | feat_inference: after 24 | feats: 512 25 | gamma: 0.1 26 | h_ratio: 0.3 27 | height: 384 28 | if_labelsmooth: true 29 | loss: 0.5*CrossEntropy+0.5*MSLoss 30 | lr: 0.0006 31 | lr_decay: 60 32 | margin: 0.7 33 | model: LMBN_n 34 | momentum: 0.9 35 | nGPU: 1 36 | nThread: 8 37 | nep_id: "" 38 | nesterov: false 39 | num_anchors: 1 40 | num_classes: 751 41 | optimizer: ADAM 42 | parts: 6 43 | pcb_different_lr: true 44 | pool: avg 45 | probability: 0.5 46 | random_erasing: true 47 | reset: false 48 | sampler: true 49 | test_every: 10 50 | w_cosine_annealing: true 51 | w_ratio: 1.0 52 | warmup: constant 53 | weight_decay: 0.0005 54 | width: 128 55 | wandb: false 56 | wandb_name: "" 57 | -------------------------------------------------------------------------------- /loss/Icon : -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/loss/Icon -------------------------------------------------------------------------------- /loss/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | """Center loss. 9 | Reference: 10 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 11 | Args: 12 | num_classes (int): number of classes. 13 | feat_dim (int): feature dimension. 14 | """ 15 | 16 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 17 | super(CenterLoss, self).__init__() 18 | self.num_classes = num_classes 19 | self.feat_dim = feat_dim 20 | self.use_gpu = use_gpu 21 | 22 | if self.use_gpu: 23 | self.centers = nn.Parameter(torch.randn( 24 | self.num_classes, self.feat_dim).cuda()) 25 | else: 26 | self.centers = nn.Parameter( 27 | torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (num_classes). 
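Note: despite the wording above, ``labels`` is expected to have shape
(batch_size,) with one class index per sample, as enforced by the assert
below and shown in the ``__main__`` example at the bottom of this file.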
34 | """ 35 | assert x.size(0) == labels.size( 36 | 0), "features.size(0) is not equal to labels.size(0)" 37 | 38 | batch_size = x.size(0) 39 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 40 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand( 41 | self.num_classes, batch_size).t() 42 | distmat.addmm_(1, -2, x, self.centers.t()) 43 | 44 | classes = torch.arange(self.num_classes).long() 45 | if self.use_gpu: 46 | classes = classes.cuda() 47 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 48 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 49 | 50 | dist = distmat * mask.float() 51 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size 52 | #dist = [] 53 | # for i in range(batch_size): 54 | # value = distmat[i][mask[i]] 55 | # value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 56 | # dist.append(value) 57 | #dist = torch.cat(dist) 58 | #loss = dist.mean() 59 | return loss 60 | 61 | 62 | if __name__ == '__main__': 63 | use_gpu = False 64 | center_loss = CenterLoss(use_gpu=use_gpu) 65 | features = torch.rand(16, 2048) 66 | targets = torch.Tensor( 67 | [0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 68 | if use_gpu: 69 | features = torch.rand(16, 2048).cuda() 70 | targets = torch.Tensor( 71 | [0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 72 | 73 | loss = center_loss(features, targets) 74 | print(loss) 75 | -------------------------------------------------------------------------------- /loss/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # from kornia.utils import one_hot 6 | 7 | 8 | # based on: 9 | # https://github.com/zhezh/focalloss/blob/master/focalloss.py 10 | 11 | def focal_loss( 12 | input: torch.Tensor, 13 | target: torch.Tensor, 14 | alpha: float, 15 | gamma: float = 2.0, 16 | reduction: str = 'none', 17 | eps: float = 1e-8) -> torch.Tensor: 18 | r"""Function that computes Focal loss. 19 | See :class:`~kornia.losses.FocalLoss` for details. 20 | """ 21 | if not torch.is_tensor(input): 22 | raise TypeError("Input type is not a torch.Tensor. Got {}" 23 | .format(type(input))) 24 | 25 | if not len(input.shape) >= 2: 26 | raise ValueError("Invalid input shape, we expect BxCx*. Got: {}" 27 | .format(input.shape)) 28 | 29 | if input.size(0) != target.size(0): 30 | raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).' 31 | .format(input.size(0), target.size(0))) 32 | 33 | n = input.size(0) 34 | out_size = (n,) + input.size()[2:] 35 | if target.size()[1:] != input.size()[2:]: 36 | raise ValueError('Expected target size {}, got {}'.format( 37 | out_size, target.size())) 38 | 39 | if not input.device == target.device: 40 | raise ValueError( 41 | "input and target must be in the same device. 
Got: {}" .format( 42 | input.device, target.device)) 43 | 44 | # compute softmax over the classes axis 45 | input_soft: torch.Tensor = F.softmax(input, dim=1) + eps 46 | 47 | # create the labels one hot tensor 48 | # target_one_hot: torch.Tensor = one_hot( 49 | # target, num_classes=input.shape[1], 50 | # device=input.device, dtype=input.dtype) 51 | target_one_hot: torch.Tensor = F.one_hot( 52 | target, num_classes=input.shape[1]) 53 | 54 | # compute the actual focal loss 55 | weight = torch.pow(-input_soft + 1., gamma) 56 | 57 | focal = -alpha * weight * torch.log(input_soft) 58 | 59 | loss_tmp = torch.sum(target_one_hot * focal, dim=1) 60 | 61 | if reduction == 'none': 62 | loss = loss_tmp 63 | elif reduction == 'mean': 64 | loss = torch.mean(loss_tmp) 65 | elif reduction == 'sum': 66 | loss = torch.sum(loss_tmp) 67 | else: 68 | raise NotImplementedError("Invalid reduction mode: {}" 69 | .format(reduction)) 70 | return loss 71 | 72 | 73 | class FocalLoss(nn.Module): 74 | r"""Criterion that computes Focal loss. 75 | According to [1], the Focal loss is computed as follows: 76 | .. math:: 77 | \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t) 78 | where: 79 | - :math:`p_t` is the model's estimated probability for each class. 80 | Arguments: 81 | alpha (float): Weighting factor :math:`\alpha \in [0, 1]`. 82 | gamma (float): Focusing parameter :math:`\gamma >= 0`. 83 | reduction (str, optional): Specifies the reduction to apply to the 84 | output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, 85 | ‘mean’: the sum of the output will be divided by the number of elements 86 | in the output, ‘sum’: the output will be summed. Default: ‘none’. 87 | Shape: 88 | - Input: :math:`(N, C, *)` where C = number of classes. 89 | - Target: :math:`(N, *)` where each value is 90 | :math:`0 ≤ targets[i] ≤ C−1`. 91 | Examples: 92 | >>> N = 5 # num_classes 93 | >>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'} 94 | >>> loss = kornia.losses.FocalLoss(**kwargs) 95 | >>> input = torch.randn(1, N, 3, 5, requires_grad=True) 96 | >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N) 97 | >>> output = loss(input, target) 98 | >>> output.backward() 99 | References: 100 | [1] https://arxiv.org/abs/1708.02002 101 | """ 102 | 103 | def __init__(self, alpha: float = 1.0, gamma: float = 2.0, 104 | reduction: str = 'none') -> None: 105 | super(FocalLoss, self).__init__() 106 | self.alpha: float = alpha 107 | self.gamma: float = gamma 108 | self.reduction: str = reduction 109 | self.eps: float = 1e-6 110 | 111 | def forward( # type: ignore 112 | self, 113 | input: torch.Tensor, 114 | target: torch.Tensor) -> torch.Tensor: 115 | return focal_loss(input, target, self.alpha, self.gamma, self.reduction, self.eps) 116 | -------------------------------------------------------------------------------- /loss/multi_similarity_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 
7 | 8 | import torch 9 | from torch import nn 10 | 11 | # from ret_benchmark.losses.registry import LOSS 12 | 13 | 14 | # @LOSS.register('ms_loss') 15 | class MultiSimilarityLoss(nn.Module): 16 | def __init__(self, margin=0.1): 17 | super(MultiSimilarityLoss, self).__init__() 18 | self.thresh = 0.5 19 | self.margin = margin 20 | 21 | self.scale_pos = 2.0 22 | self.scale_neg = 40.0 23 | 24 | def forward(self, feats, labels): 25 | assert feats.size(0) == labels.size(0), \ 26 | f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}" 27 | batch_size = feats.size(0) 28 | feats = nn.functional.normalize(feats, p=2, dim=1) 29 | 30 | # Shape: batchsize * batch size 31 | sim_mat = torch.matmul(feats, torch.t(feats)) 32 | 33 | epsilon = 1e-5 34 | loss = list() 35 | 36 | # for i in range(batch_size): 37 | # # print(i,'ccccc') 38 | # pos_pair_ = sim_mat[i][labels == labels[i]] 39 | # # print(pos_pair_.shape) 40 | # pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon] 41 | # neg_pair_ = sim_mat[i][labels != labels[i]] 42 | 43 | # neg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)] 44 | # pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)] 45 | 46 | # if len(neg_pair) < 1 or len(pos_pair) < 1: 47 | # continue 48 | 49 | # # weighting step 50 | # pos_loss = 1.0 / self.scale_pos * torch.log( 51 | # 1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh)))) 52 | # neg_loss = 1.0 / self.scale_neg * torch.log( 53 | # 1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh)))) 54 | # loss.append(pos_loss + neg_loss) 55 | 56 | mask = labels.expand(batch_size, batch_size).eq( 57 | labels.expand(batch_size, batch_size).t()) 58 | for i in range(batch_size): 59 | pos_pair_ = sim_mat[i][mask[i]] 60 | pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon] 61 | neg_pair_ = sim_mat[i][mask[i] == 0] 62 | 63 | neg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)] 64 | pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)] 65 | 66 | if len(neg_pair) < 1 or len(pos_pair) < 1: 67 | continue 68 | 69 | # weighting step 70 | pos_loss = 1.0 / self.scale_pos * torch.log( 71 | 1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh)))) 72 | neg_loss = 1.0 / self.scale_neg * torch.log( 73 | 1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh)))) 74 | loss.append(pos_loss + neg_loss) 75 | # pos_loss = 76 | 77 | 78 | if len(loss) == 0: 79 | return torch.zeros([], requires_grad=True, device=feats.device) 80 | 81 | loss = sum(loss) / batch_size 82 | return loss 83 | -------------------------------------------------------------------------------- /loss/osm_caa_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | class OSM_CAA_Loss(nn.Module): 7 | def __init__(self, alpha=1.2, l=0.5, use_gpu=True, osm_sigma=0.8): 8 | super(OSM_CAA_Loss, self).__init__() 9 | self.use_gpu = use_gpu 10 | self.alpha = alpha # margin of weighted contrastive loss, as mentioned in the paper 11 | self.l = l # hyperparameter controlling weights of positive set and the negative set 12 | # I haven't been able to figure out the use of \sigma CAA 0.18 13 | self.osm_sigma = osm_sigma # \sigma OSM (0.8) as mentioned in paper 14 | 15 | def forward(self, x, embd, labels): 16 | ''' 17 | x : feature vector : (n x d) 18 | labels : (n,) 19 | embd : Fully Connected weights of classification layer (dxC), C is the number of classes: represents the 
vectors for class 20 | ''' 21 | x = nn.functional.normalize(x, p=2, dim=1) # normalize the features 22 | n = x.size(0) 23 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(n, n) 24 | dist = dist + dist.t() 25 | dist.addmm_(1, -2, x, x.t()) 26 | dist = dist.clamp(min=1e-12).sqrt() 27 | # print(dist,'dist') 28 | S = torch.exp(-1.0 * torch.pow(dist, 2) / 29 | (self.osm_sigma * self.osm_sigma)) 30 | # max (0, self.alpha - dij ) 31 | # print(S,'ssssssssss') 32 | S_ = torch.clamp(self.alpha - dist, min=1e-12) 33 | p_mask = labels.expand(n, n).eq(labels.expand(n, n).t()) 34 | p_mask = p_mask.float() 35 | n_mask = 1 - p_mask 36 | S = S * p_mask.float() 37 | S = S + S_ * n_mask.float() 38 | # embd = nn.functional.normalize(embd, p=2, dim=0) 39 | # denominator = torch.exp(torch.mm(x, embd)) 40 | # A = [] 41 | # for i in range(n): 42 | # a_i = denominator[i][labels[i]] / torch.sum(denominator[i]) 43 | # A.append(a_i) 44 | # atten_class = torch.stack(A) 45 | A = [] 46 | # print(labels,'label') 47 | for i in range(n): 48 | A.append(embd[i][labels[i]]) 49 | atten_class = torch.stack(A) 50 | 51 | # pairwise minimum of attention weights 52 | A = torch.min(atten_class.expand(n, n), 53 | atten_class.view(-1, 1).expand(n, n)) 54 | W = S * A 55 | W_P = W * p_mask.float() 56 | W_N = W * n_mask.float() 57 | if self.use_gpu: 58 | # dist between (xi,xi) not necessarily 0, avoiding precision error 59 | W_P = W_P * (1 - torch.eye(n, n).float().cuda()) 60 | W_N = W_N * (1 - torch.eye(n, n).float().cuda()) 61 | else: 62 | W_P = W_P * (1 - torch.eye(n, n).float()) 63 | W_N = W_N * (1 - torch.eye(n, n).float()) 64 | L_P = 1.0 / 2 * torch.sum(W_P * torch.pow(dist, 2)) / torch.sum(W_P) 65 | L_N = 1.0 / 2 * torch.sum(W_N * torch.pow(S_, 2)) / torch.sum(W_N) 66 | # print(L_P,'lplplplplp') 67 | L = (1 - self.l) * L_P + self.l * L_N 68 | return L 69 | 70 | 71 | if __name__ == '__main__': 72 | # Here I left a simple forward function. 73 | # Test the model, before you train it. 74 | import argparse 75 | 76 | parser = argparse.ArgumentParser(description='MGN') 77 | parser.add_argument('--num_classes', type=int, default=751, help='') 78 | parser.add_argument('--bnneck', type=bool, default=True) 79 | parser.add_argument('--parts', type=int, default=3) 80 | parser.add_argument('--feats', type=int, default=256) 81 | 82 | args = parser.parse_args() 83 | net = OSM_CAA_Loss(use_gpu=False) 84 | # net.classifier = nn.Sequential() 85 | # print([p for p in net.parameters()]) 86 | # a=filter(lambda p: p.requires_grad, net.parameters()) 87 | # print(a) 88 | 89 | print(net) 90 | d = 256 91 | c = 751 92 | x = Variable(torch.FloatTensor(8, d)) 93 | label = Variable(torch.arange(8)) 94 | embd = Variable(torch.FloatTensor(d, 751)) 95 | 96 | output = net(x, embd, label) 97 | print('net output size:') 98 | # print(len(output)) 99 | print(output.shape) 100 | -------------------------------------------------------------------------------- /loss/ranked_loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: zzg 4 | @contact: xhx1247786632@gmail.com 5 | """ 6 | import torch 7 | from torch import nn 8 | 9 | 10 | def normalize_rank(x, axis=-1): 11 | """Normalizing to unit length along the specified dimension. 12 | Args: 13 | x: pytorch Variable 14 | Returns: 15 | x: pytorch Variable, same shape as input 16 | """ 17 | x = 1. 
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 18 | return x 19 | 20 | 21 | def euclidean_dist_rank(x, y): 22 | """ 23 | Args: 24 | x: pytorch Variable, with shape [m, d] 25 | y: pytorch Variable, with shape [n, d] 26 | Returns: 27 | dist: pytorch Variable, with shape [m, n] 28 | """ 29 | m, n = x.size(0), y.size(0) 30 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 31 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 32 | dist = xx + yy 33 | dist.addmm_(1, -2, x, y.t()) 34 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 35 | return dist 36 | 37 | 38 | def rank_loss(dist_mat, labels, margin, alpha, tval): 39 | """ 40 | Args: 41 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 42 | labels: pytorch LongTensor, with shape [N] 43 | 44 | """ 45 | assert len(dist_mat.size()) == 2 46 | assert dist_mat.size(0) == dist_mat.size(1) 47 | N = dist_mat.size(0) 48 | 49 | total_loss = 0.0 50 | for ind in range(N): 51 | is_pos = labels.eq(labels[ind]) 52 | is_pos[ind] = 0 53 | is_neg = labels.ne(labels[ind]) 54 | 55 | dist_ap = dist_mat[ind][is_pos] 56 | dist_an = dist_mat[ind][is_neg] 57 | 58 | ap_is_pos = torch.clamp(torch.add(dist_ap, margin - alpha), min=0.0) 59 | ap_pos_num = ap_is_pos.size(0) + 1e-5 60 | ap_pos_val_sum = torch.sum(ap_is_pos) 61 | loss_ap = torch.div(ap_pos_val_sum, float(ap_pos_num)) 62 | 63 | an_is_pos = torch.lt(dist_an, alpha) 64 | an_less_alpha = dist_an[an_is_pos] 65 | an_weight = torch.exp(tval * (-1 * an_less_alpha + alpha)) 66 | an_weight_sum = torch.sum(an_weight) + 1e-5 67 | an_dist_lm = alpha - an_less_alpha 68 | an_ln_sum = torch.sum(torch.mul(an_dist_lm, an_weight)) 69 | loss_an = torch.div(an_ln_sum, an_weight_sum) 70 | 71 | total_loss = total_loss + loss_ap + loss_an 72 | total_loss = total_loss * 1.0 / N 73 | return total_loss 74 | 75 | 76 | class RankedLoss(object): 77 | "Ranked_List_Loss_for_Deep_Metric_Learning_CVPR_2019_paper" 78 | 79 | def __init__(self, margin=None, alpha=None, tval=None): 80 | self.margin = margin 81 | self.alpha = alpha 82 | self.tval = tval 83 | 84 | def __call__(self, global_feat, labels, normalize_feature=True): 85 | if normalize_feature: 86 | global_feat = normalize_rank(global_feat, axis=-1) 87 | dist_mat = euclidean_dist_rank(global_feat, global_feat) 88 | total_loss = rank_loss( 89 | dist_mat, labels, self.margin, self.alpha, self.tval) 90 | 91 | return total_loss 92 | -------------------------------------------------------------------------------- /loss/triplet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class TripletSemihardLoss(nn.Module): 9 | """ 10 | Shape: 11 | - Input: :math:`(N, C)` where `C = number of channels` 12 | - Target: :math:`(N)` 13 | - Output: scalar. 
14 | """ 15 | 16 | def __init__(self, device, margin=0, size_average=True): 17 | super(TripletSemihardLoss, self).__init__() 18 | self.margin = margin 19 | self.size_average = size_average 20 | self.device = device 21 | 22 | def forward(self, input, target): 23 | y_true = target.int().unsqueeze(-1) 24 | same_id = torch.eq(y_true, y_true.t()).type_as(input) 25 | 26 | pos_mask = same_id 27 | neg_mask = 1 - same_id 28 | 29 | def _mask_max(input_tensor, mask, axis=None, keepdims=False): 30 | input_tensor = input_tensor - 1e6 * (1 - mask) 31 | _max, _idx = torch.max(input_tensor, dim=axis, keepdim=keepdims) 32 | return _max, _idx 33 | 34 | def _mask_min(input_tensor, mask, axis=None, keepdims=False): 35 | input_tensor = input_tensor + 1e6 * (1 - mask) 36 | _min, _idx = torch.min(input_tensor, dim=axis, keepdim=keepdims) 37 | return _min, _idx 38 | 39 | # output[i, j] = || feature[i, :] - feature[j, :] ||_2 40 | dist_squared = torch.sum(input ** 2, dim=1, keepdim=True) + \ 41 | torch.sum(input.t() ** 2, dim=0, keepdim=True) - \ 42 | 2.0 * torch.matmul(input, input.t()) 43 | dist = dist_squared.clamp(min=1e-16).sqrt() 44 | 45 | pos_max, pos_idx = _mask_max(dist, pos_mask, axis=-1) 46 | neg_min, neg_idx = _mask_min(dist, neg_mask, axis=-1) 47 | 48 | # loss(x, y) = max(0, -y * (x1 - x2) + margin) 49 | y = torch.ones(same_id.size()[0]).to(self.device) 50 | return F.margin_ranking_loss(neg_min.float(), 51 | pos_max.float(), 52 | y, 53 | self.margin, 54 | self.size_average) 55 | 56 | 57 | class TripletLoss(nn.Module): 58 | """ 59 | Batch Hard Trilet Loss 60 | For margin = 0. , which implemented as Batch Hard Soft Margin Triplet loss 61 | 62 | Reference: 63 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 64 | 65 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 66 | 67 | Args: 68 | margin (float): margin for triplet. 69 | """ 70 | 71 | def __init__(self, margin=0.3, mutual_flag=False): 72 | super(TripletLoss, self).__init__() 73 | self.margin = margin 74 | if margin == 0.: 75 | self.ranking_loss = nn.SoftMarginLoss() 76 | print('Using soft margin triplet loss') 77 | else: 78 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 79 | 80 | self.mutual = mutual_flag 81 | 82 | def forward(self, inputs, targets): 83 | """ 84 | Args: 85 | inputs: feature matrix with shape (batch_size, feat_dim) 86 | targets: ground truth labels with shape (num_classes) 87 | """ 88 | n = inputs.size(0) 89 | 90 | # inputs = 1. 
* inputs / (torch.norm(inputs, 2, dim=-1, keepdim=True).expand_as(inputs) + 1e-12) 91 | # Compute pairwise distance, replace by the official when merged 92 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 93 | dist = dist + dist.t() 94 | # dist.addmm_(1, -2, inputs, inputs.t()) 95 | dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2) 96 | 97 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 98 | # For each anchor, find the hardest positive and negative 99 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 100 | dist_ap, dist_an = [], [] 101 | for i in range(n): 102 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 103 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 104 | 105 | dist_ap = torch.cat(dist_ap) 106 | dist_an = torch.cat(dist_an) 107 | # Compute ranking hinge loss 108 | y = torch.ones_like(dist_an) 109 | # loss = self.ranking_loss(dist_an, dist_ap, y) 110 | if self.margin == 0.: 111 | loss = self.ranking_loss(dist_an - dist_ap, y) 112 | else: 113 | loss = self.ranking_loss(dist_an, dist_ap, y) 114 | 115 | if self.mutual: 116 | return loss, dist 117 | return loss 118 | 119 | 120 | class CrossEntropyLabelSmooth(nn.Module): 121 | """Cross entropy loss with label smoothing regularizer. 122 | 123 | Reference: 124 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 125 | Equation: y = (1 - epsilon) * y + epsilon / K. 126 | 127 | Args: 128 | num_classes (int): number of classes. 129 | epsilon (float): weight. 130 | """ 131 | 132 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 133 | super(CrossEntropyLabelSmooth, self).__init__() 134 | self.num_classes = num_classes 135 | self.epsilon = epsilon 136 | self.use_gpu = use_gpu 137 | self.logsoftmax = nn.LogSoftmax(dim=1) 138 | 139 | def forward(self, inputs, targets): 140 | """ 141 | Args: 142 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 143 | targets: ground truth labels with shape (num_classes) 144 | """ 145 | log_probs = self.logsoftmax(inputs) 146 | # print(log_probs.device) 147 | targets = torch.zeros(log_probs.size()).scatter_( 148 | 1, targets.unsqueeze(1).data.cpu(), 1).to(log_probs.device) 149 | # targets = torch.zeros(log_probs.size()).scatter_( 150 | # 1, targets.unsqueeze(1).long(), 1) 151 | # if self.use_gpu: 152 | # targets = targets.cuda() 153 | # print(targets.device) 154 | targets = (1 - self.epsilon) * targets + \ 155 | self.epsilon / self.num_classes 156 | loss = (- targets * log_probs).mean(0).sum() 157 | return loss 158 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import data_v1 2 | import data_v2 3 | from loss import make_loss 4 | from model import make_model 5 | from optim import make_optimizer, make_scheduler 6 | 7 | # import engine_v1 8 | # import engine_v2 9 | import engine_v3 10 | import os.path as osp 11 | from option import args 12 | import utils.utility as utility 13 | from utils.model_complexity import compute_model_complexity 14 | from torch.utils.collect_env import get_pretty_env_info 15 | import yaml 16 | import torch 17 | 18 | 19 | if args.config != "": 20 | with open(args.config, "r") as f: 21 | config = yaml.full_load(f) 22 | for op in config: 23 | setattr(args, op, config[op]) 24 | torch.backends.cudnn.benchmark = True 25 | 26 | # loader = data.Data(args) 27 | ckpt = utility.checkpoint(args) 28 | loader = data_v2.ImageDataManager(args) 
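# Build the model, the optimizer and (unless running test-only) the loss from the parsed args / YAML config.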
29 | model = make_model(args, ckpt) 30 | optimzer = make_optimizer(args, model) 31 | loss = make_loss(args, ckpt) if not args.test_only else None 32 | 33 | start = -1 34 | if args.load != "": 35 | start, model, optimizer = ckpt.resume_from_checkpoint( 36 | osp.join(ckpt.dir, "model-latest.pth"), model, optimzer 37 | ) 38 | start = start - 1 39 | if args.pre_train != "": 40 | ckpt.load_pretrained_weights(model, args.pre_train) 41 | 42 | scheduler = make_scheduler(args, optimzer, start) 43 | 44 | # print('[INFO] System infomation: \n {}'.format(get_pretty_env_info())) 45 | ckpt.write_log( 46 | "[INFO] Model parameters: {com[0]} flops: {com[1]}".format( 47 | com=compute_model_complexity(model, (1, 3, args.height, args.width)) 48 | ) 49 | ) 50 | 51 | engine = engine_v3.Engine(args, model, optimzer, scheduler, loss, loader, ckpt) 52 | # engine = engine.Engine(args, model, loss, loader, ckpt) 53 | 54 | n = start + 1 55 | while not engine.terminate(): 56 | n += 1 57 | engine.train() 58 | if args.test_every != 0 and n % args.test_every == 0: 59 | engine.test() 60 | elif n == args.epochs: 61 | engine.test() 62 | -------------------------------------------------------------------------------- /model/Icon : -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/model/Icon -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | 3 | import torch 4 | import torch.nn as nn 5 | import os.path as osp 6 | from collections import OrderedDict 7 | 8 | 9 | def make_model(args, ckpt): 10 | 11 | ckpt.write_log('[INFO] Building {} model...'.format(args.model)) 12 | 13 | device = torch.device('cpu' if args.cpu else 'cuda') 14 | # nGPU = args.nGPU 15 | 16 | module = import_module('model.' + args.model.lower()) 17 | model = getattr(module, args.model)(args).to(device) 18 | 19 | if not args.cpu and args.nGPU > 1: 20 | model = nn.DataParallel(model, range(args.nGPU)) 21 | 22 | return model 23 | 24 | # class Model(nn.Module): 25 | 26 | # def __init__(self, args, ckpt): 27 | # super(Model, self).__init__() 28 | # ckpt.write_log('[INFO] Making {} model...'.format(args.model)) 29 | # if args.drop_block: 30 | # ckpt.write_log('[INFO] Using batch drop block with h_ratio {} and w_ratio {}.'.format(args.h_ratio, args.w_ratio)) 31 | 32 | # self.device = torch.device('cpu' if args.cpu else 'cuda') 33 | # self.nGPU = args.nGPU 34 | 35 | # module = import_module('model.' 
+ args.model.lower()) 36 | # # self.model = module.make_model(args).to(self.device) 37 | # self.model = getattr(module, args.model)(args).to(self.device) 38 | 39 | # if not args.cpu and args.nGPU > 1: 40 | # self.model = nn.DataParallel(self.model, range(args.nGPU)) 41 | 42 | # def forward(self, x): 43 | # return self.model(x) 44 | 45 | # def get_model(self): 46 | # if self.nGPU == 1: 47 | # return self.model 48 | # else: 49 | # return self.model.module 50 | 51 | # def save(self, apath, epoch, is_best=False): 52 | # target = self.get_model() 53 | # torch.save( 54 | # target.state_dict(), 55 | # os.path.join(apath, 'model', 'model_latest.pt') 56 | # ) 57 | # if is_best: 58 | # torch.save( 59 | # target.state_dict(), 60 | # os.path.join(apath, 'model', 'model_best.pt') 61 | # ) 62 | 63 | 64 | # if self.save_models: 65 | # torch.save( 66 | # target.state_dict(), 67 | # os.path.join(apath, 'model', 'model_{}.pt'.format(epoch)) 68 | # ) 69 | ''' 70 | def load(self, apath, pre_train='', resume=-1, cpu=False): 71 | if cpu: 72 | kwargs = {'map_location': lambda storage, loc: storage} 73 | else: 74 | kwargs = {} 75 | 76 | # if resume == -1: 77 | # print('Loading model from last checkpoint') 78 | # self.get_model().load_state_dict( 79 | # torch.load( 80 | # os.path.join(apath, 'model', 'model_latest.pt'), 81 | # **kwargs 82 | # ), 83 | # strict=False 84 | # ) 85 | # elif resume == 0: 86 | # if pre_train != '': 87 | # print('Loading model from {}'.format(pre_train)) 88 | # self.get_model().load_state_dict( 89 | # torch.load(pre_train, **kwargs), 90 | # strict=False 91 | # ) 92 | # modified on 01.02.1010 93 | # if resume == 0: 94 | # if pre_train != '': 95 | # print('Loading model from {}'.format(pre_train)) 96 | # self.get_model().load_state_dict( 97 | # torch.load(pre_train, **kwargs), 98 | # strict=False 99 | # ) 100 | # else: 101 | # print('Loading model from last checkpoint') 102 | # self.get_model().load_state_dict( 103 | # torch.load( 104 | # os.path.join(apath, 'model', 'model_latest.pt'), 105 | # **kwargs 106 | # ), 107 | # strict=False 108 | # ) 109 | # else: 110 | # self.get_model().load_state_dict( 111 | # torch.load( 112 | # os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), 113 | # **kwargs 114 | # ), 115 | # strict=False 116 | # ) 117 | # modified on 01.02.1010 118 | if pre_train != '': 119 | print('Loading model from {}'.format(pre_train)) 120 | if pre_train.split('.')[-1][:3] == 'tar': 121 | print('load checkpointerrrrrrr') 122 | # checkpoint = self.load_checkpoint(pre_train) 123 | # self.get_model().load_state_dict(checkpoint['state_dict']) 124 | self.load_pretrained_weights(self.get_model(), pre_train) 125 | else: 126 | 127 | self.get_model().load_state_dict( 128 | torch.load(pre_train, **kwargs), 129 | strict=False 130 | ) 131 | else: 132 | print('Loading model from last checkpoint') 133 | # print(apath) 134 | self.get_model().load_state_dict( 135 | torch.load( 136 | os.path.join(apath, 'model', 'model_latest.pt'), 137 | **kwargs 138 | ), 139 | # strict=False 140 | ) 141 | ''' 142 | -------------------------------------------------------------------------------- /model/bnneck.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class BNNeck(nn.Module): 4 | def __init__(self, input_dim, class_num, return_f=False): 5 | super(BNNeck, self).__init__() 6 | self.return_f = return_f 7 | self.bn = nn.BatchNorm1d(input_dim) 8 | self.bn.bias.requires_grad_(False) 9 | self.classifier = nn.Linear(input_dim, class_num, 
bias=False) 10 | self.bn.apply(self.weights_init_kaiming) 11 | self.classifier.apply(self.weights_init_classifier) 12 | 13 | def forward(self, x): 14 | before_neck = x.view(x.size(0), x.size(1)) 15 | after_neck = self.bn(before_neck) 16 | 17 | if self.return_f: 18 | score = self.classifier(after_neck) 19 | return after_neck, score, before_neck 20 | else: 21 | x = self.classifier(x) 22 | return x 23 | 24 | def weights_init_kaiming(self, m): 25 | classname = m.__class__.__name__ 26 | if classname.find('Linear') != -1: 27 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 28 | nn.init.constant_(m.bias, 0.0) 29 | elif classname.find('Conv') != -1: 30 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 31 | if m.bias is not None: 32 | nn.init.constant_(m.bias, 0.0) 33 | elif classname.find('BatchNorm') != -1: 34 | if m.affine: 35 | nn.init.constant_(m.weight, 1.0) 36 | nn.init.constant_(m.bias, 0.0) 37 | 38 | def weights_init_classifier(self, m): 39 | classname = m.__class__.__name__ 40 | if classname.find('Linear') != -1: 41 | nn.init.normal_(m.weight, std=0.001) 42 | if m.bias: 43 | nn.init.constant_(m.bias, 0.0) 44 | 45 | 46 | class BNNeck3(nn.Module): 47 | def __init__(self, input_dim, class_num, feat_dim, return_f=False): 48 | super(BNNeck3, self).__init__() 49 | self.return_f = return_f 50 | # self.reduction = nn.Linear(input_dim, feat_dim) 51 | # self.bn = nn.BatchNorm1d(feat_dim) 52 | 53 | self.reduction = nn.Conv2d( 54 | input_dim, feat_dim, 1, bias=False) 55 | self.bn = nn.BatchNorm1d(feat_dim) 56 | 57 | self.bn.bias.requires_grad_(False) 58 | self.classifier = nn.Linear(feat_dim, class_num, bias=False) 59 | self.bn.apply(self.weights_init_kaiming) 60 | self.classifier.apply(self.weights_init_classifier) 61 | 62 | def forward(self, x): 63 | x = self.reduction(x) 64 | # before_neck = x.squeeze(dim=3).squeeze(dim=2) 65 | # after_neck = self.bn(x).squeeze(dim=3).squeeze(dim=2) 66 | before_neck = x.view(x.size(0), x.size(1)) 67 | after_neck = self.bn(before_neck) 68 | if self.return_f: 69 | score = self.classifier(after_neck) 70 | return after_neck, score, before_neck 71 | else: 72 | x = self.classifier(x) 73 | return x 74 | 75 | def weights_init_kaiming(self, m): 76 | classname = m.__class__.__name__ 77 | if classname.find('Linear') != -1: 78 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 79 | nn.init.constant_(m.bias, 0.0) 80 | elif classname.find('Conv') != -1: 81 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 82 | if m.bias is not None: 83 | nn.init.constant_(m.bias, 0.0) 84 | elif classname.find('BatchNorm') != -1: 85 | if m.affine: 86 | nn.init.constant_(m.weight, 1.0) 87 | nn.init.constant_(m.bias, 0.0) 88 | 89 | def weights_init_classifier(self, m): 90 | classname = m.__class__.__name__ 91 | if classname.find('Linear') != -1: 92 | nn.init.normal_(m.weight, std=0.001) 93 | if m.bias: 94 | nn.init.constant_(m.bias, 0.0) 95 | 96 | # Defines the new fc layer and classification layer 97 | # |--Linear--|--bn--|--relu--|--Linear--| 98 | 99 | 100 | class ClassBlock(nn.Module): 101 | def __init__(self, input_dim, class_num, droprate=0, relu=False, bnorm=True, num_bottleneck=512, linear=True, return_f=False): 102 | super(ClassBlock, self).__init__() 103 | self.return_f = return_f 104 | add_block = [] 105 | if linear: 106 | add_block += [nn.Linear(input_dim, num_bottleneck)] 107 | else: 108 | num_bottleneck = input_dim 109 | if bnorm: 110 | add_block += [nn.BatchNorm1d(num_bottleneck)] 111 | if relu: 112 | add_block += [nn.LeakyReLU(0.1)] 113 | if droprate > 0: 
114 | add_block += [nn.Dropout(p=droprate)] 115 | add_block = nn.Sequential(*add_block) 116 | add_block.apply(self.weights_init_kaiming) 117 | 118 | classifier = [] 119 | classifier += [nn.Linear(num_bottleneck, class_num)] 120 | classifier = nn.Sequential(*classifier) 121 | classifier.apply(self.weights_init_classifier) 122 | 123 | self.add_block = add_block 124 | self.classifier = classifier 125 | 126 | def forward(self, x): 127 | x = self.add_block(x.squeeze(3).squeeze(2)) 128 | if self.return_f: 129 | f = x 130 | x = self.classifier(x) 131 | return f, x, f 132 | else: 133 | x = self.classifier(x) 134 | return x 135 | 136 | def weights_init_kaiming(self, m): 137 | classname = m.__class__.__name__ 138 | # print(classname) 139 | if classname.find('Conv') != -1: 140 | # For old pytorch, you may use kaiming_normal. 141 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 142 | elif classname.find('Linear') != -1: 143 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 144 | nn.init.constant_(m.bias.data, 0.0) 145 | elif classname.find('BatchNorm1d') != -1: 146 | nn.init.normal_(m.weight.data, 1.0, 0.02) 147 | nn.init.constant_(m.bias.data, 0.0) 148 | 149 | def weights_init_classifier(self, m): 150 | classname = m.__class__.__name__ 151 | if classname.find('Linear') != -1: 152 | nn.init.normal_(m.weight.data, std=0.001) 153 | nn.init.constant_(m.bias.data, 0.0) 154 | 155 | -------------------------------------------------------------------------------- /model/lmbn_n_no_drop.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | from torch import nn 5 | from .osnet import osnet_x1_0, OSBlock 6 | from .attention import BatchDrop, BatchRandomErasing, PAM_Module, CAM_Module, SE_Module, Dual_Module 7 | from .bnneck import BNNeck, BNNeck3 8 | from torch.nn import functional as F 9 | from torch.autograd import Variable 10 | 11 | 12 | class LMBN_n_no_drop(nn.Module): 13 | def __init__(self, args): 14 | super(LMBN_n_no_drop, self).__init__() 15 | 16 | self.n_ch = 2 17 | self.chs = 512 // self.n_ch 18 | 19 | osnet = osnet_x1_0(pretrained=True) 20 | 21 | self.backone = nn.Sequential( 22 | osnet.conv1, 23 | osnet.maxpool, 24 | osnet.conv2, 25 | osnet.conv3[0] 26 | ) 27 | 28 | conv3 = osnet.conv3[1:] 29 | 30 | self.global_branch = nn.Sequential(copy.deepcopy( 31 | conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)) 32 | 33 | self.partial_branch = nn.Sequential(copy.deepcopy( 34 | conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)) 35 | 36 | self.channel_branch = nn.Sequential(copy.deepcopy( 37 | conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)) 38 | 39 | self.global_pooling = nn.AdaptiveMaxPool2d((1, 1)) 40 | self.partial_pooling = nn.AdaptiveAvgPool2d((2, 1)) 41 | self.channel_pooling = nn.AdaptiveAvgPool2d((1, 1)) 42 | 43 | reduction = BNNeck3(512, args.num_classes, 44 | args.feats, return_f=True) 45 | self.reduction_0 = copy.deepcopy(reduction) 46 | self.reduction_1 = copy.deepcopy(reduction) 47 | self.reduction_2 = copy.deepcopy(reduction) 48 | self.reduction_3 = copy.deepcopy(reduction) 49 | 50 | self.shared = nn.Sequential(nn.Conv2d( 51 | self.chs, args.feats, 1, bias=False), nn.BatchNorm2d(args.feats), nn.ReLU(True)) 52 | self.weights_init_kaiming(self.shared) 53 | 54 | self.reduction_ch_0 = BNNeck( 55 | args.feats, args.num_classes, return_f=True) 56 | self.reduction_ch_1 = BNNeck( 57 | args.feats, args.num_classes, return_f=True) 58 | 59 | # if args.drop_block: 60 | # 
print('Using batch random erasing block.') 61 | # self.batch_drop_block = BatchRandomErasing() 62 | if args.drop_block: 63 | # print('Using batch drop block.') 64 | self.batch_drop_block = BatchDrop( 65 | h_ratio=args.h_ratio, w_ratio=args.w_ratio) 66 | else: 67 | self.batch_drop_block = None 68 | 69 | self.activation_map = args.activation_map 70 | 71 | def forward(self, x): 72 | # if self.batch_drop_block is not None: 73 | # x = self.batch_drop_block(x) 74 | 75 | x = self.backone(x) 76 | 77 | glo = self.global_branch(x) 78 | par = self.partial_branch(x) 79 | cha = self.channel_branch(x) 80 | 81 | if self.activation_map: 82 | 83 | _, _, h_par, _ = par.size() 84 | 85 | fmap_p0 = par[:, :, :h_par // 2, :] 86 | fmap_p1 = par[:, :, h_par // 2:, :] 87 | fmap_c0 = cha[:, :self.chs, :, :] 88 | fmap_c1 = cha[:, self.chs:, :, :] 89 | print('activation_map') 90 | 91 | return glo, fmap_c0, fmap_c1, fmap_p0, fmap_p1 92 | 93 | if self.batch_drop_block is not None: 94 | glo = self.batch_drop_block(glo) 95 | 96 | glo = self.global_pooling(glo) # shape:(batchsize, 2048,1,1) 97 | g_par = self.global_pooling(par) # shape:(batchsize, 2048,1,1) 98 | p_par = self.partial_pooling(par) # shape:(batchsize, 2048,3,1) 99 | cha = self.channel_pooling(cha) 100 | 101 | p0 = p_par[:, :, 0:1, :] 102 | p1 = p_par[:, :, 1:2, :] 103 | 104 | f_glo = self.reduction_0(glo) 105 | f_p0 = self.reduction_1(g_par) 106 | f_p1 = self.reduction_2(p0) 107 | f_p2 = self.reduction_3(p1) 108 | 109 | ################ 110 | 111 | c0 = cha[:, :self.chs, :, :] 112 | c1 = cha[:, self.chs:, :, :] 113 | c0 = self.shared(c0) 114 | c1 = self.shared(c1) 115 | f_c0 = self.reduction_ch_0(c0) 116 | f_c1 = self.reduction_ch_1(c1) 117 | 118 | ################ 119 | 120 | fea = [f_glo[-1], f_p0[-1]] 121 | 122 | if not self.training: 123 | 124 | return torch.stack([f_glo[0], f_p0[0], f_p1[0], f_p2[0], f_c0[0], f_c1[0]], dim=2) 125 | 126 | return [f_glo[1], f_p0[1], f_p1[1], f_p2[1], f_c0[1], f_c1[1]], fea 127 | 128 | def weights_init_kaiming(self, m): 129 | classname = m.__class__.__name__ 130 | if classname.find('Linear') != -1: 131 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 132 | nn.init.constant_(m.bias, 0.0) 133 | elif classname.find('Conv') != -1: 134 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 135 | if m.bias is not None: 136 | nn.init.constant_(m.bias, 0.0) 137 | elif classname.find('BatchNorm') != -1: 138 | if m.affine: 139 | nn.init.constant_(m.weight, 1.0) 140 | nn.init.constant_(m.bias, 0.0) 141 | 142 | 143 | if __name__ == '__main__': 144 | # Here I left a simple forward function. 145 | # Test the model, before you train it. 
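    # The smoke test below builds the network from a minimal argparse namespace and runs one dummy forward pass in eval mode (the class defined in this file is LMBN_n_no_drop).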
146 | import argparse 147 | 148 | parser = argparse.ArgumentParser(description='MGN') 149 | parser.add_argument('--num_classes', type=int, default=751, help='') 150 | parser.add_argument('--bnneck', type=bool, default=True) 151 | parser.add_argument('--pool', type=str, default='max') 152 | parser.add_argument('--feats', type=int, default=512) 153 | parser.add_argument('--drop_block', type=bool, default=True) 154 | parser.add_argument('--w_ratio', type=float, default=1.0, help='') 155 | 156 | args = parser.parse_args() 157 | net = MCMP_n(args) 158 | # net.classifier = nn.Sequential() 159 | # print([p for p in net.parameters()]) 160 | # a=filter(lambda p: p.requires_grad, net.parameters()) 161 | # print(a) 162 | 163 | print(net) 164 | input = Variable(torch.FloatTensor(8, 3, 384, 128)) 165 | net.eval() 166 | output = net(input) 167 | print(output.shape) 168 | print('net output size:') 169 | # print(len(output)) 170 | # for k in output[0]: 171 | # print(k.shape) 172 | # for k in output[1]: 173 | # print(k.shape) 174 | -------------------------------------------------------------------------------- /model/mgn.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from torchvision.models.resnet import resnet50, Bottleneck 8 | 9 | 10 | def make_model(args): 11 | return MGN(args) 12 | 13 | 14 | class MGN(nn.Module): 15 | def __init__(self, args): 16 | super(MGN, self).__init__() 17 | num_classes = args.num_classes 18 | 19 | resnet = resnet50(pretrained=True) 20 | 21 | self.backone = nn.Sequential( 22 | resnet.conv1, 23 | resnet.bn1, 24 | resnet.relu, 25 | resnet.maxpool, 26 | resnet.layer1, 27 | resnet.layer2, 28 | resnet.layer3[0], 29 | ) 30 | 31 | res_conv4 = nn.Sequential(*resnet.layer3[1:]) 32 | 33 | res_g_conv5 = resnet.layer4 34 | 35 | res_p_conv5 = nn.Sequential( 36 | Bottleneck(1024, 512, downsample=nn.Sequential( 37 | nn.Conv2d(1024, 2048, 1, bias=False), nn.BatchNorm2d(2048))), 38 | Bottleneck(2048, 512), 39 | Bottleneck(2048, 512)) 40 | res_p_conv5.load_state_dict(resnet.layer4.state_dict()) 41 | 42 | self.p1 = nn.Sequential(copy.deepcopy( 43 | res_conv4), copy.deepcopy(res_g_conv5)) 44 | self.p2 = nn.Sequential(copy.deepcopy( 45 | res_conv4), copy.deepcopy(res_p_conv5)) 46 | self.p3 = nn.Sequential(copy.deepcopy( 47 | res_conv4), copy.deepcopy(res_p_conv5)) 48 | 49 | if args.pool == 'max': 50 | pool2d = nn.MaxPool2d 51 | elif args.pool == 'avg': 52 | pool2d = nn.AvgPool2d 53 | else: 54 | raise Exception() 55 | 56 | self.maxpool_zg_p1 = pool2d(kernel_size=(12, 4)) 57 | self.maxpool_zg_p2 = pool2d(kernel_size=(24, 8)) 58 | self.maxpool_zg_p3 = pool2d(kernel_size=(24, 8)) 59 | self.maxpool_zp2 = pool2d(kernel_size=(12, 8)) 60 | self.maxpool_zp3 = pool2d(kernel_size=(8, 8)) 61 | 62 | reduction = nn.Sequential(nn.Conv2d( 63 | 2048, args.feats, 1, bias=False), nn.BatchNorm2d(args.feats), nn.ReLU()) 64 | 65 | self._init_reduction(reduction) 66 | self.reduction_0 = copy.deepcopy(reduction) 67 | self.reduction_1 = copy.deepcopy(reduction) 68 | self.reduction_2 = copy.deepcopy(reduction) 69 | self.reduction_3 = copy.deepcopy(reduction) 70 | self.reduction_4 = copy.deepcopy(reduction) 71 | self.reduction_5 = copy.deepcopy(reduction) 72 | self.reduction_6 = copy.deepcopy(reduction) 73 | self.reduction_7 = copy.deepcopy(reduction) 74 | 75 | #self.fc_id_2048_0 = nn.Linear(2048, num_classes) 76 | self.fc_id_2048_0 = nn.Linear(args.feats, num_classes) 77 | 
self.fc_id_2048_1 = nn.Linear(args.feats, num_classes) 78 | self.fc_id_2048_2 = nn.Linear(args.feats, num_classes) 79 | 80 | self.fc_id_256_1_0 = nn.Linear(args.feats, num_classes) 81 | self.fc_id_256_1_1 = nn.Linear(args.feats, num_classes) 82 | self.fc_id_256_2_0 = nn.Linear(args.feats, num_classes) 83 | self.fc_id_256_2_1 = nn.Linear(args.feats, num_classes) 84 | self.fc_id_256_2_2 = nn.Linear(args.feats, num_classes) 85 | 86 | self._init_fc(self.fc_id_2048_0) 87 | self._init_fc(self.fc_id_2048_1) 88 | self._init_fc(self.fc_id_2048_2) 89 | 90 | self._init_fc(self.fc_id_256_1_0) 91 | self._init_fc(self.fc_id_256_1_1) 92 | self._init_fc(self.fc_id_256_2_0) 93 | self._init_fc(self.fc_id_256_2_1) 94 | self._init_fc(self.fc_id_256_2_2) 95 | 96 | @staticmethod 97 | def _init_reduction(reduction): 98 | # conv 99 | nn.init.kaiming_normal_(reduction[0].weight, mode='fan_in') 100 | #nn.init.constant_(reduction[0].bias, 0.) 101 | 102 | # bn 103 | nn.init.normal_(reduction[1].weight, mean=1., std=0.02) 104 | nn.init.constant_(reduction[1].bias, 0.) 105 | 106 | @staticmethod 107 | def _init_fc(fc): 108 | nn.init.kaiming_normal_(fc.weight, mode='fan_out') 109 | #nn.init.normal_(fc.weight, std=0.001) 110 | nn.init.constant_(fc.bias, 0.) 111 | 112 | def forward(self, x): 113 | 114 | x = self.backone(x) 115 | 116 | p1 = self.p1(x) 117 | p2 = self.p2(x) 118 | p3 = self.p3(x) 119 | 120 | zg_p1 = self.maxpool_zg_p1(p1) # shape:(batchsize, 2048,1,1) 121 | zg_p2 = self.maxpool_zg_p2(p2) # shape:(batchsize, 2048,1,1) 122 | zg_p3 = self.maxpool_zg_p3(p3) # shape:(batchsize, 2048,1,1) 123 | 124 | zp2 = self.maxpool_zp2(p2) # shape:(batchsize, 2048,2,1) 125 | z0_p2 = zp2[:, :, 0:1, :] 126 | z1_p2 = zp2[:, :, 1:2, :] 127 | 128 | zp3 = self.maxpool_zp3(p3) # shape:(batchsize, 2048,3,1) 129 | z0_p3 = zp3[:, :, 0:1, :] 130 | z1_p3 = zp3[:, :, 1:2, :] 131 | z2_p3 = zp3[:, :, 2:3, :] 132 | 133 | fg_p1 = self.reduction_0(zg_p1).squeeze(dim=3).squeeze(dim=2) 134 | fg_p2 = self.reduction_1(zg_p2).squeeze(dim=3).squeeze(dim=2) 135 | fg_p3 = self.reduction_2(zg_p3).squeeze(dim=3).squeeze(dim=2) 136 | f0_p2 = self.reduction_3(z0_p2).squeeze(dim=3).squeeze(dim=2) 137 | f1_p2 = self.reduction_4(z1_p2).squeeze(dim=3).squeeze(dim=2) 138 | f0_p3 = self.reduction_5(z0_p3).squeeze(dim=3).squeeze(dim=2) 139 | f1_p3 = self.reduction_6(z1_p3).squeeze(dim=3).squeeze(dim=2) 140 | f2_p3 = self.reduction_7(z2_p3).squeeze(dim=3).squeeze(dim=2) 141 | 142 | ''' 143 | l_p1 = self.fc_id_2048_0(zg_p1.squeeze(dim=3).squeeze(dim=2)) 144 | l_p2 = self.fc_id_2048_1(zg_p2.squeeze(dim=3).squeeze(dim=2)) 145 | l_p3 = self.fc_id_2048_2(zg_p3.squeeze(dim=3).squeeze(dim=2)) 146 | ''' 147 | l_p1 = self.fc_id_2048_0(fg_p1) 148 | l_p2 = self.fc_id_2048_1(fg_p2) 149 | l_p3 = self.fc_id_2048_2(fg_p3) 150 | 151 | l0_p2 = self.fc_id_256_1_0(f0_p2) 152 | l1_p2 = self.fc_id_256_1_1(f1_p2) 153 | l0_p3 = self.fc_id_256_2_0(f0_p3) 154 | l1_p3 = self.fc_id_256_2_1(f1_p3) 155 | l2_p3 = self.fc_id_256_2_2(f2_p3) 156 | 157 | fea = [fg_p1, fg_p2, fg_p3] 158 | 159 | if not self.training: 160 | 161 | return torch.stack([fg_p1, fg_p2, fg_p3, f0_p2, 162 | f1_p2, f0_p3, f1_p3, f2_p3], dim=2) 163 | # print(predict.shape) 164 | return [l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3], fea 165 | -------------------------------------------------------------------------------- /model/p.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import 
random 7 | import math 8 | from .osnet import osnet_x1_0, OSBlock 9 | from .attention import BatchDrop, BatchRandomErasing, PAM_Module, CAM_Module, SE_Module, Dual_Module 10 | from .bnneck import BNNeck, BNNeck3 11 | 12 | from torch.autograd import Variable 13 | 14 | 15 | class P(nn.Module): 16 | def __init__(self, args): 17 | super(P, self).__init__() 18 | 19 | osnet = osnet_x1_0(pretrained=True) 20 | 21 | self.backone = nn.Sequential( 22 | osnet.conv1, 23 | osnet.maxpool, 24 | osnet.conv2, 25 | osnet.conv3[0] 26 | ) 27 | 28 | conv3 = osnet.conv3[1:] 29 | 30 | # downsample_conv4 = osnet._make_layer(OSBlock, 2, 384, 512, True) 31 | # downsample_conv4[:2].load_state_dict(osnet.conv4[:2].state_dict()) 32 | 33 | # self.global_branch = nn.Sequential(copy.deepcopy( 34 | # conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)) 35 | 36 | self.partial_branch = nn.Sequential(copy.deepcopy( 37 | conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)) 38 | 39 | # self.channel_branch = nn.Sequential(copy.deepcopy( 40 | # conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)) 41 | 42 | # if args.pool == 'max': 43 | # pool2d = nn.AdaptiveMaxPool2d 44 | # elif args.pool == 'avg': 45 | # pool2d = nn.AdaptiveAvgPool2d 46 | # else: 47 | # raise Exception() 48 | 49 | self.global_pooling = nn.AdaptiveMaxPool2d((1, 1)) 50 | self.partial_pooling = nn.AdaptiveAvgPool2d((2, 1)) 51 | # self.channel_pooling = nn.AdaptiveAvgPool2d((1, 1)) 52 | 53 | reduction = BNNeck3(512, args.num_classes, 54 | args.feats, return_f=True) 55 | # self.reduction_0 = copy.deepcopy(reduction) 56 | self.reduction_1 = copy.deepcopy(reduction) 57 | self.reduction_2 = copy.deepcopy(reduction) 58 | self.reduction_3 = copy.deepcopy(reduction) 59 | 60 | # self.shared = nn.Sequential(nn.Conv2d( 61 | # self.chs, args.feats, 1, bias=False), nn.BatchNorm2d(args.feats), nn.ReLU(True)) 62 | # self.weights_init_kaiming(self.shared) 63 | 64 | # self.reduction_ch_0 = BNNeck( 65 | # args.feats, args.num_classes, return_f=True) 66 | # self.reduction_ch_1 = BNNeck( 67 | # args.feats, args.num_classes, return_f=True) 68 | 69 | # if args.drop_block: 70 | # print('Using batch random erasing block.') 71 | # self.batch_drop_block = BatchRandomErasing() 72 | if args.drop_block: 73 | print('Using batch drop block.') 74 | self.batch_drop_block = BatchDrop( 75 | h_ratio=args.h_ratio, w_ratio=args.w_ratio) 76 | else: 77 | self.batch_drop_block = None 78 | 79 | self.activation_map = args.activation_map 80 | 81 | def forward(self, x): 82 | # if self.batch_drop_block is not None: 83 | # x = self.batch_drop_block(x) 84 | 85 | x = self.backone(x) 86 | 87 | # glo = self.global_branch(x) 88 | par = self.partial_branch(x) 89 | # cha = self.channel_branch(x) 90 | 91 | # if self.activation_map: 92 | 93 | # _, _, h_par, _ = par.size() 94 | 95 | # fmap_p0 = par[:, :, :h_par // 2, :] 96 | # fmap_p1 = par[:, :, h_par // 2:, :] 97 | # fmap_c0 = cha[:, :self.chs, :, :] 98 | # fmap_c1 = cha[:, self.chs:, :, :] 99 | # print('activation_map') 100 | 101 | # return glo, fmap_c0, fmap_c1, fmap_p0, fmap_p1 102 | 103 | # if self.batch_drop_block is not None: 104 | # glo = self.batch_drop_block(glo) 105 | 106 | # glo = self.global_pooling(glo) # shape:(batchsize, 2048,1,1) 107 | g_par = self.global_pooling(par) # shape:(batchsize, 2048,1,1) 108 | p_par = self.partial_pooling(par) # shape:(batchsize, 2048,3,1) 109 | # cha = self.channel_pooling(cha) 110 | 111 | p0 = p_par[:, :, 0:1, :] 112 | p1 = p_par[:, :, 1:2, :] 113 | 114 | # f_glo = self.reduction_0(glo) 115 | 
f_p0 = self.reduction_1(g_par) 116 | f_p1 = self.reduction_2(p0) 117 | f_p2 = self.reduction_3(p1) 118 | 119 | ################ 120 | 121 | # c0 = cha[:, :self.chs, :, :] 122 | # c1 = cha[:, self.chs:, :, :] 123 | # c0 = self.shared(c0) 124 | # c1 = self.shared(c1) 125 | # f_c0 = self.reduction_ch_0(c0) 126 | # f_c1 = self.reduction_ch_1(c1) 127 | 128 | ################ 129 | 130 | fea = [f_p0[-1]] 131 | 132 | if not self.training: 133 | 134 | return torch.stack([f_p0[0], f_p1[0], f_p2[0]], dim=2) 135 | 136 | return [f_p0[1], f_p1[1], f_p2[1]], fea 137 | 138 | def weights_init_kaiming(self, m): 139 | classname = m.__class__.__name__ 140 | if classname.find('Linear') != -1: 141 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 142 | nn.init.constant_(m.bias, 0.0) 143 | elif classname.find('Conv') != -1: 144 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 145 | if m.bias is not None: 146 | nn.init.constant_(m.bias, 0.0) 147 | elif classname.find('BatchNorm') != -1: 148 | if m.affine: 149 | nn.init.constant_(m.weight, 1.0) 150 | nn.init.constant_(m.bias, 0.0) 151 | 152 | 153 | if __name__ == '__main__': 154 | # Here I left a simple forward function. 155 | # Test the model, before you train it. 156 | import argparse 157 | 158 | parser = argparse.ArgumentParser(description='MGN') 159 | parser.add_argument('--num_classes', type=int, default=751, help='') 160 | parser.add_argument('--bnneck', type=bool, default=True) 161 | parser.add_argument('--pool', type=str, default='max') 162 | parser.add_argument('--feats', type=int, default=512) 163 | parser.add_argument('--drop_block', type=bool, default=True) 164 | parser.add_argument('--w_ratio', type=float, default=1.0, help='') 165 | 166 | args = parser.parse_args() 167 | net = MCMP_n(args) 168 | # net.classifier = nn.Sequential() 169 | # print([p for p in net.parameters()]) 170 | # a=filter(lambda p: p.requires_grad, net.parameters()) 171 | # print(a) 172 | 173 | print(net) 174 | input = Variable(torch.FloatTensor(8, 3, 384, 128)) 175 | net.eval() 176 | output = net(input) 177 | print(output.shape) 178 | print('net output size:') 179 | # print(len(output)) 180 | # for k in output[0]: 181 | # print(k.shape) 182 | # for k in output[1]: 183 | # print(k.shape) 184 | -------------------------------------------------------------------------------- /model/pcb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models.resnet import resnet50, Bottleneck 4 | from torch.autograd import Variable 5 | from .bnneck import BNNeck, BNNeck3, ClassBlock 6 | 7 | 8 | class PCB(nn.Module): 9 | def __init__(self, args): 10 | super(PCB, self).__init__() 11 | 12 | self.part = args.parts # We cut the pool5 to 6 parts 13 | model_ft = resnet50(pretrained=True) 14 | self.model = model_ft 15 | self.avgpool = nn.AdaptiveAvgPool2d((self.part, 1)) 16 | 17 | self.avgpool_before_triplet = nn.AdaptiveAvgPool2d((1, 1)) 18 | self.dropout = nn.Dropout(p=0.5) 19 | # remove the final downsample 20 | self.model.layer4[0].downsample[0].stride = (1, 1) 21 | self.model.layer4[0].conv2.stride = (1, 1) 22 | self.bnneck = args.bnneck 23 | # define 6 classifiers 24 | for i in range(self.part): 25 | name = 'classifier' + str(i) 26 | if self.bnneck: 27 | # setattr(self, name, BNNeck(2048, args.num_classes, return_f=True)) 28 | setattr(self, name, BNNeck3( 29 | 2048, args.num_classes, args.feats, return_f=True)) 30 | else: 31 | 32 | setattr(self, name, ClassBlock(2048, args.num_classes, 
droprate=0.5, 33 | relu=False, bnorm=True, num_bottleneck=args.feats, return_f=True)) 34 | self.global_branch = BNNeck3( 35 | 2048, args.num_classes, args.feats, return_f=True) 36 | print('PCB_conv divide into {} parts, using {} dims feature.'.format( 37 | args.parts, args.feats)) 38 | # self.global_branch = BNNeck(2048, args.num_classes, return_f=True) 39 | 40 | def forward(self, x): 41 | x = self.model.conv1(x) 42 | x = self.model.bn1(x) 43 | x = self.model.relu(x) 44 | x = self.model.maxpool(x) 45 | 46 | x = self.model.layer1(x) 47 | x = self.model.layer2(x) 48 | x = self.model.layer3(x) 49 | x = self.model.layer4(x) 50 | feat_to_global_branch = self.avgpool_before_triplet(x) 51 | x = self.avgpool(x) 52 | # x = self.dropout(x) 53 | # print(x.shape) 54 | part = {} 55 | predict = {} 56 | # get six part feature batchsize*2048*6 57 | for i in range(self.part): 58 | part[i] = x[:, :, i].unsqueeze(dim=3) 59 | # part[i] = torch.squeeze(x[:, :,:, i]) 60 | 61 | name = 'classifier' + str(i) 62 | c = getattr(self, name) 63 | predict[i] = c(part[i]) 64 | 65 | global_feat = [x.view(x.size(0), x.size(1), x.size(2))] 66 | 67 | feat_global_branch = self.global_branch(feat_to_global_branch) 68 | # y = [x.view(x.size(0), -1)] 69 | 70 | score = [] 71 | after_neck = [] 72 | # print(y[0].shape) 73 | for i in range(self.part): 74 | 75 | score.append(predict[i][1]) 76 | if self.bnneck: 77 | 78 | after_neck.append(predict[i][0]) 79 | 80 | if not self.training: 81 | return torch.stack(after_neck + [feat_global_branch[0]], dim=2) 82 | return score + [feat_global_branch[1]], feat_global_branch[-1] 83 | 84 | 85 | if __name__ == '__main__': 86 | # Here I left a simple forward function. 87 | # Test the model, before you train it. 88 | import argparse 89 | 90 | parser = argparse.ArgumentParser(description='MGN') 91 | parser.add_argument('--num_classes', type=int, default=751, help='') 92 | parser.add_argument('--bnneck', type=bool, default=True) 93 | parser.add_argument('--parts', type=int, default=3) 94 | parser.add_argument('--feats', type=int, default=256) 95 | 96 | args = parser.parse_args() 97 | net = PCB(args) 98 | # net.classifier = nn.Sequential() 99 | # print([p for p in net.parameters()]) 100 | # a=filter(lambda p: p.requires_grad, net.parameters()) 101 | # print(a) 102 | net.eval() 103 | print(net) 104 | input = Variable(torch.FloatTensor(8, 3, 256, 128)) 105 | output = net(input) 106 | print('net output size:') 107 | print(len(output)) 108 | print(output.shape) 109 | # for k in output[0]: 110 | # print(k.shape) 111 | # print(output[-1].shape) 112 | -------------------------------------------------------------------------------- /model/resnet50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models.resnet import resnet50, Bottleneck 4 | import random 5 | from .bnneck import BNNeck, BNNeck3, ClassBlock 6 | from torch.autograd import Variable 7 | 8 | 9 | # Defines the new fc layer and classification layer 10 | # |--Linear--|--bn--|--relu--|--Linear--| 11 | 12 | class BatchDrop(nn.Module): 13 | def __init__(self, h_ratio, w_ratio): 14 | super(BatchDrop, self).__init__() 15 | self.h_ratio = h_ratio 16 | self.w_ratio = w_ratio 17 | 18 | def forward(self, x): 19 | if self.training: 20 | h, w = x.size()[-2:] 21 | rh = round(self.h_ratio * h) 22 | rw = round(self.w_ratio * w) 23 | sx = random.randint(0, h - rh) 24 | sy = random.randint(0, w - rw) 25 | mask = x.new_ones(x.size()) 26 | mask[:, :, sx:sx + rh, sy:sy + 
rw] = 0 27 | x = x * mask 28 | return x 29 | 30 | 31 | # Define the ResNet50-based Model 32 | class ResNet50(nn.Module): 33 | 34 | # def __init__(self, class_num, droprate=0.5, stride=2): 35 | def __init__(self, args, droprate=0.5, stride=1): 36 | 37 | super(ResNet50, self).__init__() 38 | resnet = resnet50(pretrained=True) 39 | # avg pooling to global pooling 40 | if stride == 1: 41 | resnet.layer4[0].downsample[0].stride = (1, 1) 42 | resnet.layer4[0].conv2.stride = (1, 1) 43 | resnet.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 44 | self.model = resnet 45 | self.bnneck = args.bnneck 46 | self.drop_block = args.drop_block 47 | 48 | if args.feat_inference == 'before': 49 | self.before_neck = True 50 | print('Using before_neck inference') 51 | else: 52 | self.before_neck = False 53 | if args.drop_block: 54 | print('Using batch drop block.') 55 | resnet.avgpool = nn.AdaptiveMaxPool2d((1, 1)) 56 | self.batch_drop_block = BatchDrop( 57 | h_ratio=args.h_ratio, w_ratio=args.w_ratio) 58 | if self.bnneck: 59 | self.classifier = BNNeck(2048, args.num_classes, # feat_dim=512, 60 | return_f=True) 61 | # self.classifier = BNNeck3(2048, args.num_classes, feat_dim=512, 62 | # return_f=True) 63 | 64 | else: 65 | 66 | self.classifier = ClassBlock( 67 | 2048, args.num_classes, num_bottleneck=args.feats, return_f=True) 68 | 69 | def forward(self, x): 70 | x = self.model.conv1(x) 71 | x = self.model.bn1(x) 72 | x = self.model.relu(x) 73 | x = self.model.maxpool(x) 74 | x = self.model.layer1(x) 75 | x = self.model.layer2(x) 76 | x = self.model.layer3(x) 77 | x = self.model.layer4(x) 78 | # print(x.size()) 79 | if self.drop_block: 80 | x = self.batch_drop_block(x) 81 | x = self.model.avgpool(x) 82 | # x = x.view(x.size(0), x.size(1)) 83 | # print(x.shape) 84 | x = self.classifier(x) 85 | # print(x[0].shape) 86 | # print(x[1].shape) 87 | # print(x[2].shape) 88 | 89 | if not self.training: 90 | if self.before_neck: 91 | return x[-1] 92 | return x[0] 93 | # print(x[1].size()) 94 | # print(x[-1].size()) 95 | return [x[1]], [x[-1]] 96 | 97 | 98 | if __name__ == '__main__': 99 | # Here I left a simple forward function. 100 | # Test the model, before you train it. 
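    # Quick smoke test: build ResNet50 from toy command-line args and run a dummy forward pass in eval mode.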
101 | import argparse 102 | 103 | parser = argparse.ArgumentParser(description='MGN') 104 | parser.add_argument('--num_classes', type=int, default=751, help='') 105 | parser.add_argument('--bnneck', type=bool, default=False) 106 | parser.add_argument('--pool', type=str, default='max') 107 | parser.add_argument('--feats', type=int, default=256) 108 | parser.add_argument('--drop_block', type=bool, default=True) 109 | parser.add_argument('--w_ratio', type=float, default=1.0, help='') 110 | parser.add_argument('--h_ratio', type=float, default=0.33, help='') 111 | 112 | args = parser.parse_args() 113 | net = ResNet50(args) 114 | # net.classifier = nn.Sequential() 115 | # print([p for p in net.parameters()]) 116 | # a=filter(lambda p: p.requires_grad, net.parameters()) 117 | # print(a) 118 | 119 | print(net) 120 | input = Variable(torch.FloatTensor(8, 3, 384, 128)) 121 | net.eval() 122 | output = net(input) 123 | print(output.shape) 124 | print('net output size:') 125 | # print(len(output)) 126 | # for k in output[0]: 127 | # print(k.shape) 128 | # for k in output[1]: 129 | -------------------------------------------------------------------------------- /model/resnet50_ibn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['ResNet_IBN', 'resnet50_ibn_a', 'resnet101_ibn_a', 8 | 'resnet152_ibn_a'] 9 | 10 | 11 | model_urls = { 12 | 'resnet50_ibn': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | class IBN(nn.Module): 19 | def __init__(self, planes): 20 | super(IBN, self).__init__() 21 | half1 = int(planes/2) 22 | self.half = half1 23 | half2 = planes - half1 24 | self.IN = nn.InstanceNorm2d(half1, affine=True) 25 | self.BN = nn.BatchNorm2d(half2) 26 | 27 | def forward(self, x): 28 | split = torch.split(x, self.half, 1) 29 | out1 = self.IN(split[0].contiguous()) 30 | out2 = self.BN(split[1].contiguous()) 31 | out = torch.cat((out1, out2), 1) 32 | return out 33 | 34 | 35 | class Bottleneck_IBN(nn.Module): 36 | expansion = 4 37 | 38 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): 39 | super(Bottleneck_IBN, self).__init__() 40 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 41 | if ibn: 42 | self.bn1 = IBN(planes) 43 | else: 44 | self.bn1 = nn.BatchNorm2d(planes) 45 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 46 | padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | residual = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv3(out) 66 | out = self.bn3(out) 67 | 68 | if self.downsample is not None: 69 | residual = self.downsample(x) 70 | 71 | out += residual 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class ResNet_IBN(nn.Module): 78 | 79 | def __init__(self, last_stride, block, layers, num_classes=1000): 
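        # last_stride sets the stride of layer4; re-ID pipelines typically pass 1 so the final feature map keeps more spatial detail.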
80 | scale = 64 81 | self.inplanes = scale 82 | super(ResNet_IBN, self).__init__() 83 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, 84 | bias=False) 85 | self.bn1 = nn.BatchNorm2d(scale) 86 | self.relu = nn.ReLU(inplace=True) 87 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 88 | self.layer1 = self._make_layer(block, scale, layers[0]) 89 | self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2) 90 | self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2) 91 | self.layer4 = self._make_layer(block, scale*8, layers[3], stride=last_stride) 92 | self.avgpool = nn.AvgPool2d(7) 93 | self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) 94 | 95 | for m in self.modules(): 96 | if isinstance(m, nn.Conv2d): 97 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 98 | m.weight.data.normal_(0, math.sqrt(2. / n)) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | elif isinstance(m, nn.InstanceNorm2d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | 106 | def _make_layer(self, block, planes, blocks, stride=1): 107 | downsample = None 108 | if stride != 1 or self.inplanes != planes * block.expansion: 109 | downsample = nn.Sequential( 110 | nn.Conv2d(self.inplanes, planes * block.expansion, 111 | kernel_size=1, stride=stride, bias=False), 112 | nn.BatchNorm2d(planes * block.expansion), 113 | ) 114 | 115 | layers = [] 116 | ibn = True 117 | if planes == 512: 118 | ibn = False 119 | layers.append(block(self.inplanes, planes, ibn, stride, downsample)) 120 | self.inplanes = planes * block.expansion 121 | for i in range(1, blocks): 122 | layers.append(block(self.inplanes, planes, ibn)) 123 | 124 | return nn.Sequential(*layers) 125 | 126 | def forward(self, x): 127 | x = self.conv1(x) 128 | x = self.bn1(x) 129 | x = self.relu(x) 130 | x = self.maxpool(x) 131 | 132 | x = self.layer1(x) 133 | x = self.layer2(x) 134 | x = self.layer3(x) 135 | x = self.layer4(x) 136 | 137 | # x = self.avgpool(x) 138 | # x = x.view(x.size(0), -1) 139 | # x = self.fc(x) 140 | 141 | return x 142 | 143 | def load_param(self, model_path): 144 | param_dict = torch.load(model_path) 145 | for i in param_dict: 146 | if 'fc' in i: 147 | continue 148 | self.state_dict()[i].copy_(param_dict[i]) 149 | 150 | 151 | def resnet50_ibn_a(last_stride, pretrained=False, **kwargs): 152 | """Constructs a ResNet-50 model. 153 | Args: 154 | pretrained (bool): If True, returns a model pre-trained on ImageNet 155 | """ 156 | model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 6, 3], **kwargs) 157 | if pretrained: 158 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50_ibn'])) 159 | return model 160 | 161 | 162 | def resnet101_ibn_a(last_stride, pretrained=False, **kwargs): 163 | """Constructs a ResNet-101 model. 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | """ 167 | model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 23, 3], **kwargs) 168 | if pretrained: 169 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 170 | return model 171 | 172 | 173 | def resnet152_ibn_a(last_stride, pretrained=False, **kwargs): 174 | """Constructs a ResNet-152 model. 
175 |     Args:
176 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
177 |     """
178 |     model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 8, 36, 3], **kwargs)
179 |     if pretrained:
180 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
181 |     return model
--------------------------------------------------------------------------------
/optim/Icon :
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/optim/Icon
--------------------------------------------------------------------------------
/optim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | import torch.optim.lr_scheduler as lrs
3 | from .n_adam import NAdam
4 | from .warmup_scheduler import WarmupMultiStepLR
5 | from .warmup_cosine_scheduler import WarmupCosineAnnealingLR
6 | 
7 | 
8 | def make_optimizer(args, model):
9 |     trainable = filter(lambda x: x.requires_grad, model.parameters())
10 |     if args.model in ['PCB', 'PCB_v', 'PCB_conv']:
11 |         ignored_params = []
12 |         for i in range(args.parts):
13 |             name = 'classifier' + str(i)
14 |             c = getattr(model, name)
15 |             ignored_params = ignored_params + list(map(id, c.parameters()))
16 | 
17 |         ignored_params = tuple(ignored_params)
18 | 
19 |         base_params = filter(lambda p: id(
20 |             p) not in ignored_params, model.model.parameters())
21 | 
22 |         if args.pcb_different_lr == 'True':
23 |             print('PCB different lr')
24 |             if args.optimizer == 'SGD':
25 |                 optimizer_pcb = optim.SGD([
26 |                     {'params': base_params, 'lr': 0.1 * args.lr},
27 |                     {'params': model.model.classifier0.parameters(), 'lr': args.lr},
28 |                     {'params': model.model.classifier1.parameters(), 'lr': args.lr},
29 |                     {'params': model.model.classifier2.parameters(), 'lr': args.lr},
30 |                     {'params': model.model.classifier3.parameters(), 'lr': args.lr},
31 |                     {'params': model.model.classifier4.parameters(), 'lr': args.lr},
32 |                     {'params': model.model.classifier5.parameters(), 'lr': args.lr},
33 | 
34 |                 ], weight_decay=5e-4, momentum=0.9, nesterov=True)
35 |                 return optimizer_pcb
36 |             elif args.optimizer == 'ADAM':
37 |                 params = []
38 |                 for i in range(args.parts):
39 |                     name = 'classifier' + str(i)
40 |                     c = getattr(model.model, name)
41 |                     params.append({'params': c.parameters(), 'lr': args.lr})
42 |                 params = [{'params': base_params,
43 |                            'lr': 0.1 * args.lr}] + params
44 | 
45 |                 optimizer_pcb = optim.Adam(params, weight_decay=5e-4)
46 | 
47 |                 return optimizer_pcb
48 |             else:
49 |                 raise ValueError('Optimizer not found, please choose ADAM or SGD.')
50 | 
51 |     if args.optimizer == 'SGD':
52 |         optimizer_function = optim.SGD
53 |         kwargs = {
54 |             'momentum': args.momentum,
55 |             'dampening': args.dampening,
56 |             'nesterov': args.nesterov
57 |         }
58 |     elif args.optimizer == 'ADAM':
59 |         optimizer_function = optim.Adam
60 |         kwargs = {
61 |             'betas': (args.beta1, args.beta2),
62 |             'eps': args.epsilon,
63 |             'amsgrad': args.amsgrad
64 |         }
65 |     elif args.optimizer == 'NADAM':
66 |         optimizer_function = NAdam
67 |         kwargs = {
68 |             'betas': (args.beta1, args.beta2),
69 |             'eps': args.epsilon
70 |         }
71 |     elif args.optimizer == 'RMSprop':
72 |         optimizer_function = optim.RMSprop
73 |         kwargs = {
74 |             'eps': args.epsilon,
75 |             'momentum': args.momentum
76 |         }
77 |     else:
78 |         raise ValueError('Optimizer {} not supported, please choose SGD, ADAM, NADAM or RMSprop.'.format(args.optimizer))
79 | 
80 |     kwargs['lr'] = args.lr
81 |     kwargs['weight_decay'] = args.weight_decay
82 | 
83 |     return optimizer_function(trainable, **kwargs)
84 | 
85 | 
86 | def make_scheduler(args, 
optimizer, last_epoch): 87 | 88 | # if args.warmup in ['linear', 'constant'] and args.load == '' and args.pre_train == '': 89 | milestones = args.decay_type.split('_') 90 | milestones.pop(0) 91 | milestones = list(map(lambda x: int(x), milestones)) 92 | if args.cosine_annealing: 93 | scheduler = lrs.CosineAnnealingLR( 94 | optimizer, float(args.epochs), last_epoch=last_epoch 95 | ) 96 | 97 | return scheduler 98 | 99 | if args.w_cosine_annealing: 100 | 101 | scheduler = WarmupCosineAnnealingLR( 102 | optimizer, multiplier=1, warmup_epoch=10, min_lr=args.lr / 1000, epochs=args.epochs, last_epoch=last_epoch) 103 | 104 | return scheduler 105 | 106 | scheduler = WarmupMultiStepLR( 107 | optimizer, milestones, args.gamma, 0.01, 10, args.warmup, last_epoch=last_epoch) 108 | 109 | return scheduler 110 | 111 | if args.decay_type == 'step': 112 | scheduler = lrs.StepLR( 113 | optimizer, 114 | step_size=args.lr_decay, 115 | gamma=args.gamma 116 | ) 117 | elif args.decay_type.find('step') >= 0: 118 | milestones = args.decay_type.split('_') 119 | milestones.pop(0) 120 | milestones = list(map(lambda x: int(x), milestones)) 121 | # print(milestones, 'milestones') 122 | scheduler = lrs.MultiStepLR( 123 | optimizer, 124 | milestones=milestones, 125 | gamma=args.gamma 126 | ) 127 | 128 | return scheduler 129 | -------------------------------------------------------------------------------- /optim/n_adam.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Mar 14, 2018 3 | @author: jyzhang 4 | ''' 5 | 6 | import math 7 | import torch 8 | from torch.optim import Optimizer 9 | 10 | 11 | class NAdam(torch.optim.Optimizer): 12 | """Implements Nesterov-accelerated Adam algorithm according to Keras. 13 | 14 | parameter name alias in different algorithms 15 | NAdam Keras 054_report 16 | exp_avg m_t m_t 17 | exp_avg_prime prime{m}_t prime{m}_t 18 | exp_avg_bar \\bar{m}_t bar{m}_t 19 | exp_avg_sq v_t n_t 20 | exp_avg_sq_prime prime{v}_t prime{n}_t 21 | beta1 beta_1 mu 22 | beta2 beta_2 v=0.999 23 | 24 | It has been proposed in `Incorporating Nesterov Momentum into Adam`_. 25 | Arguments: 26 | params (iterable): iterable of parameters to optimize or dicts defining 27 | parameter groups 28 | lr (float, optional): learning rate (default: 1e-3) 29 | betas (Tuple[float, float], optional): coefficients used for computing 30 | running averages of gradient and its square (default: (0.9, 0.999)) 31 | eps (float, optional): term added to the denominator to improve 32 | numerical stability (default: 1e-8) 33 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0), 34 | but not used in NAdam 35 | schedule_decay (float, optional): coefficients used for computing 36 | moment schedule (default: 0.004) 37 | .. _Incorporating Nesterov Momentum into Adam 38 | http://cs229.stanford.edu/proj2015/054_report.pdf 39 | .. 
_On the importance of initialization and momentum in deep learning 40 | http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 41 | """ 42 | 43 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 44 | weight_decay=0, schedule_decay=0.004): 45 | if not 0.0 <= betas[0] < 1.0: 46 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 47 | if not 0.0 <= betas[1] < 1.0: 48 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 49 | defaults = dict(lr=lr, betas=betas, eps=eps, 50 | weight_decay=weight_decay, schedule_decay=schedule_decay) 51 | super(NAdam, self).__init__(params, defaults) 52 | 53 | def __setstate__(self, state): 54 | super(NAdam, self).__setstate__(state) 55 | 56 | def step(self, closure=None): 57 | """Performs a single optimization step. 58 | Arguments: 59 | closure (callable, optional): A closure that reevaluates the model 60 | and returns the loss. 61 | """ 62 | loss = None 63 | if closure is not None: 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | for p in group['params']: 68 | if p.grad is None: 69 | continue 70 | grad = p.grad.data 71 | if grad.is_sparse: 72 | raise RuntimeError('NAdam does not support sparse gradients') 73 | 74 | state = self.state[p] 75 | 76 | # State initialization 77 | if len(state) == 0: 78 | state['step'] = 0 79 | # Exponential moving average of gradient values 80 | state['exp_avg'] = torch.zeros_like(p.data) 81 | # Exponential moving average of squared gradient values 82 | state['exp_avg_sq'] = torch.zeros_like(p.data) 83 | # \mu^{t} 84 | state['m_schedule'] = 1. 85 | 86 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 87 | 88 | beta1, beta2 = group['betas'] 89 | 90 | schedule_decay = group['schedule_decay'] 91 | 92 | state['step'] += 1 93 | 94 | if group['weight_decay'] != 0: 95 | grad = grad.add(group['weight_decay'], p.data) 96 | 97 | # calculate the momentum cache \mu^{t} and \mu^{t+1} 98 | momentum_cache_t = beta1 * ( \ 99 | 1. - 0.5 * (pow(0.96, state['step'] * schedule_decay))) 100 | momentum_cache_t_1 = beta1 * ( \ 101 | 1. - 0.5 * (pow(0.96, (state['step'] + 1) * schedule_decay))) 102 | m_schedule_new = state['m_schedule'] * momentum_cache_t 103 | m_schedule_next = state['m_schedule'] * momentum_cache_t * momentum_cache_t_1 104 | 105 | # Decay the first and second moment running average coefficient 106 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 107 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 108 | 109 | g_prime = torch.div( grad, 1. - m_schedule_new) 110 | exp_avg_prime = torch.div( exp_avg, 1. - m_schedule_next ) 111 | exp_avg_sq_prime = torch.div(exp_avg_sq, 1. - pow(beta2, state['step'])) 112 | 113 | exp_avg_bar = torch.add( (1. - momentum_cache_t) * g_prime, \ 114 | momentum_cache_t_1, exp_avg_prime ) 115 | 116 | denom = exp_avg_sq_prime.sqrt().add_(group['eps']) 117 | 118 | step_size = group['lr'] 119 | 120 | p.data.addcdiv_(-step_size, exp_avg_bar, denom) 121 | 122 | return loss -------------------------------------------------------------------------------- /optim/nadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class Nadam(Optimizer): 6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). 7 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__. 
8 | Arguments: 9 | params (iterable): iterable of parameters to optimize or dicts defining 10 | parameter groups 11 | lr (float, optional): learning rate (default: 2e-3) 12 | betas (Tuple[float, float], optional): coefficients used for computing 13 | running averages of gradient and its square 14 | eps (float, optional): term added to the denominator to improve 15 | numerical stability (default: 1e-8) 16 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 17 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3) 18 | __ http://cs229.stanford.edu/proj2015/054_report.pdf 19 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 20 | """ 21 | 22 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 23 | weight_decay=0, schedule_decay=4e-3): 24 | defaults = dict(lr=lr, betas=betas, eps=eps, 25 | weight_decay=weight_decay, schedule_decay=schedule_decay) 26 | super(Nadam, self).__init__(params, defaults) 27 | 28 | def step(self, closure=None): 29 | """Performs a single optimization step. 30 | Arguments: 31 | closure (callable, optional): A closure that reevaluates the model 32 | and returns the loss. 33 | """ 34 | loss = None 35 | if closure is not None: 36 | loss = closure() 37 | 38 | for group in self.param_groups: 39 | for p in group['params']: 40 | if p.grad is None: 41 | continue 42 | grad = p.grad.data 43 | state = self.state[p] 44 | 45 | # State initialization 46 | if len(state) == 0: 47 | state['step'] = 0 48 | state['m_schedule'] = 1. 49 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 50 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 51 | 52 | # Warming momentum schedule 53 | m_schedule = state['m_schedule'] 54 | schedule_decay = group['schedule_decay'] 55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 56 | beta1, beta2 = group['betas'] 57 | eps = group['eps'] 58 | 59 | state['step'] += 1 60 | 61 | if group['weight_decay'] != 0: 62 | grad = grad.add(group['weight_decay'], p.data) 63 | 64 | momentum_cache_t = beta1 * \ 65 | (1. - 0.5 * (0.96 ** (state['step'] * schedule_decay))) 66 | momentum_cache_t_1 = beta1 * \ 67 | (1. - 0.5 * 68 | (0.96 ** ((state['step'] + 1) * schedule_decay))) 69 | m_schedule_new = m_schedule * momentum_cache_t 70 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 71 | state['m_schedule'] = m_schedule_new 72 | 73 | # Decay the first and second moment running average coefficient 74 | bias_correction2 = 1 - beta2 ** state['step'] 75 | 76 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 77 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 78 | exp_avg_sq_prime = exp_avg_sq.div(1. - bias_correction2) 79 | 80 | denom = exp_avg_sq_prime.sqrt_().add_(group['eps']) 81 | 82 | p.data.addcdiv_(-group['lr']*(1. - momentum_cache_t)/(1. - m_schedule_new), grad, denom) 83 | p.data.addcdiv_(-group['lr']*momentum_cache_t_1/(1. 
- m_schedule_next), exp_avg, denom)
84 | 
85 |         return loss
--------------------------------------------------------------------------------
/optim/warmup_cosine_scheduler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | 
3 | 
4 | import torch
5 | import matplotlib.pyplot as plt
6 | 
7 | from torch.optim.lr_scheduler import _LRScheduler
8 | from torch.optim.lr_scheduler import ReduceLROnPlateau
9 | import torch.optim.lr_scheduler as lrs
10 | 
11 | import math
12 | 
13 | 
14 | class WarmupCosineAnnealingLR(_LRScheduler):
15 | 
16 |     def __init__(self, optimizer, multiplier, warmup_epoch, epochs, min_lr=3.5e-7, last_epoch=-1):
17 |         self.multiplier = multiplier
18 |         if self.multiplier < 1.:
19 |             raise ValueError(
20 |                 'multiplier should be greater than or equal to 1.')
21 |         self.warmup_epoch = warmup_epoch
22 |         self.last_epoch = last_epoch
23 |         self.eta_min = min_lr
24 |         self.T_max = float(epochs - warmup_epoch)
25 |         self.after_scheduler = True
26 | 
27 |         super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
28 | 
29 |     def get_lr(self):
30 |         if self.last_epoch > self.warmup_epoch - 1:
31 | 
32 |             return [self.eta_min + (base_lr - self.eta_min) *
33 |                     (1 + math.cos(math.pi * (self.last_epoch -
34 |                                              self.warmup_epoch) / (self.T_max - 1))) / 2
35 |                     for base_lr in self.base_lrs]
36 | 
37 |         if self.multiplier == 1.0:
38 |             return [base_lr * (float(self.last_epoch + 1) / self.warmup_epoch) for base_lr in self.base_lrs]
39 |         else:
40 |             return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.) for base_lr in self.base_lrs]
41 | 
42 | 
43 | 
44 | if __name__ == '__main__':
45 |     v = torch.zeros(10)
46 |     optim1 = torch.optim.SGD([v], lr=3.5e-4)
47 | 
48 |     scheduler2 = WarmupCosineAnnealingLR(
49 |         optim1, multiplier=1, warmup_epoch=10, epochs=120, min_lr=3.5e-7,last_epoch=-1)
50 | 
51 |     a = []
52 |     b = []
53 |     for i in range(1, 121):
54 | 
55 |         print('kk1', scheduler2.get_last_lr())
56 |         print('3333333', scheduler2.last_epoch+1)
57 |         if scheduler2.last_epoch ==120:
58 |             break
59 |         a.append(scheduler2.last_epoch+1)
60 |         b.append(optim1.param_groups[0]['lr'])
61 |         print(i, optim1.param_groups[0]['lr'])
62 |         # optim.step()
63 |         scheduler2.step()
64 | 
65 | 
66 |     print(dir(scheduler2))
67 |     tick_spacing = 5
68 |     plt.figure(figsize=(20,10))
69 |     plt.rcParams['figure.dpi'] = 300  # resolution (dpi)
70 | 
71 |     plt.plot(a, b, "-", lw=2)
72 | 
73 | 
74 |     plt.yticks([3.5e-5, 3.5e-4], ['3.5e-5', '3.5e-4'])
75 | 
76 |     plt.xlabel("Epoch")
77 |     plt.ylabel("Learning rate")
78 | 
79 | 
80 | 
81 |     optim = torch.optim.SGD([v], lr=3.5e-4)
82 |     scheduler1 = WarmupCosineAnnealingLR(
83 |         optim, multiplier=1, warmup_epoch=10, epochs=120, min_lr=3.5e-7,last_epoch=-1)
84 | 
85 |     a = []
86 |     b = []
87 |     for i in range(1, 71):
88 | 
89 |         print('kk1', scheduler1.get_last_lr())
90 |         print('3333333', scheduler1.last_epoch+1)
91 |         if scheduler1.last_epoch ==120:
92 |             break
93 |         a.append(scheduler1.last_epoch+1)
94 |         b.append(optim.param_groups[0]['lr'])
95 |         print(i, optim.param_groups[0]['lr'])
96 |         # optim.step()
97 |         scheduler1.step()
98 | 
99 |     scheduler = WarmupCosineAnnealingLR(
100 |         optim, multiplier=1, warmup_epoch=10, epochs=120, min_lr=3.5e-7,last_epoch=69)
101 |     print(dir(scheduler))
102 |     tick_spacing = 5
103 |     plt.plot(a, b, "-", lw=2)
104 | 
105 |     # plt.xticks(3.5e-4)
106 | 
107 |     # plt.plot(n, m1, 'r-.', n, m2, 'b')
108 | 
109 |     # plt.xlim((-2, 4))
110 |     # plt.ylim((-5, 15))
111 | 
112 |     # x_ticks = np.linspace(-5, 4, 10)
113 |     # plt.xticks(x_ticks)
114 | 
115 |     # Replace the tick labels at the given positions with the desired strings; ticks not listed are no longer shown.
116 |     plt.yticks([3.5e-5, 3.5e-4], ['3.5e-5', '3.5e-4'])
117 | 
118 |     plt.xlabel("Epoch")
119 |     plt.ylabel("Learning rate")
120 |     a = []
121 |     b = []
122 |     for i in range(1, 120):
123 | 
124 |         print('kk', scheduler.get_last_lr())
125 |         print('3333333', scheduler.last_epoch+1)
126 |         if scheduler.last_epoch ==126:
127 |             break
128 |         a.append(scheduler.last_epoch+1)
129 |         b.append(optim.param_groups[0]['lr'])
130 |         print(i, optim.param_groups[0]['lr'])
131 |         optim.step()
132 |         scheduler.step()
133 | 
134 |     # plt.plot(t, s, "o-", lw=4.1)
135 |     # plt.plot(t, s2, "o-", lw=4.1)
136 | 
137 |     tick_spacing = 10
138 |     plt.plot(a, b, "-", lw=2)
139 | 
140 |     # plt.xticks(3.5e-4)
141 | 
142 |     # plt.plot(n, m1, 'r-.', n, m2, 'b')
143 | 
144 |     # plt.xlim((-2, 4))
145 |     # plt.ylim((-5, 15))
146 | 
147 |     # x_ticks = np.linspace(-5, 4, 10)
148 |     # plt.xticks(x_ticks)
149 | 
150 |     # Replace the tick labels at the given positions with the desired strings; ticks not listed are no longer shown.
151 |     plt.yticks([3.5e-5, 3.5e-4], ['3.5e-5', '3.5e-4'])
152 | 
153 |     plt.xlabel("Epoch")
154 |     plt.ylabel("Learning rate")
155 | 
--------------------------------------------------------------------------------
/optim/warmup_scheduler.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from bisect import bisect_right
7 | import torch
8 | 
9 | 
10 | # FIXME ideally this would be achieved with a CombinedLRScheduler,
11 | # separating MultiStepLR with WarmupLR
12 | # but the current LRScheduler design doesn't allow it
13 | 
14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
15 |     def __init__(
16 |         self,
17 |         optimizer,
18 |         milestones,
19 |         gamma=0.1,
20 |         warmup_factor=1.0 / 3,
21 |         warmup_iters=500,
22 |         warmup_method="linear",
23 |         last_epoch=-1,
24 |     ):
25 |         if not list(milestones) == sorted(milestones):
26 |             raise ValueError(
27 |                 "Milestones should be a list of" " increasing integers. 
Got {}", 28 | milestones, 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear", "none"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted" 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | # print(self.milestones) 38 | self.gamma = gamma 39 | self.warmup_factor = warmup_factor 40 | self.warmup_iters = warmup_iters 41 | self.warmup_method = warmup_method 42 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 43 | 44 | def get_lr(self): 45 | warmup_factor = 1 46 | if self.last_epoch < self.warmup_iters: 47 | if self.warmup_method == "linear": 48 | # modified 18.02.2020 49 | # warmup_factor = self.warmup_factor 50 | warmup_factor = (1 + self.last_epoch) / self.warmup_iters 51 | 52 | # elif self.warmup_method == "linear": 53 | # alpha = self.last_epoch / self.warmup_iters 54 | # warmup_factor = self.warmup_factor * (1 - alpha) + alpha 55 | elif self.warmup_method == "constant": 56 | warmup_factor = 1 57 | return [ 58 | base_lr 59 | * warmup_factor 60 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 61 | for base_lr in self.base_lrs 62 | ] 63 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click==8.0.3 2 | cycler==0.11.0 3 | cython==3.0.10 4 | fonttools==4.28.5 5 | importlib-metadata==4.10.1 6 | joblib==1.2.0 7 | kiwisolver==1.3.2 8 | matplotlib==3.5.1 9 | mypy-extensions==0.4.3 10 | numpy==1.22.0 11 | packaging==21.3 12 | pathspec==0.9.0 13 | Pillow==10.0.1 14 | platformdirs==2.4.1 15 | prefetch-generator==1.0.1 16 | pyaml==21.10.1 17 | pyparsing==3.0.6 18 | python-dateutil==2.8.2 19 | PyYAML==6.0 20 | scikit-learn==1.0.2 21 | scipy==1.7.3 22 | six==1.16.0 23 | threadpoolctl==3.0.0 24 | tomli==1.2.3 25 | tqdm==4.62.3 26 | typed-ast==1.5.1 27 | typing-extensions>=4.0.1 28 | zipp==3.7.0 29 | -------------------------------------------------------------------------------- /utils/Icon : -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/utils/Icon -------------------------------------------------------------------------------- /utils/LightMB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/utils/LightMB.png -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/utils/__init__.py -------------------------------------------------------------------------------- /utils/random_erasing.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | 5 | from PIL import Image 6 | import random 7 | import math 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class Cutout(object): 13 | def __init__(self, probability=0.5, size=64, mean=[0.4914, 0.4822, 0.4465]): 14 | self.probability = probability 15 | self.mean = mean 16 | self.size = size 17 | 18 | def __call__(self, img): 19 | 20 | if random.uniform(0, 1) > self.probability: 21 | return img 22 | 23 | h = self.size 24 | w = self.size 25 | for 
attempt in range(100): 26 | area = img.size()[1] * img.size()[2] 27 | if w < img.size()[2] and h < img.size()[1]: 28 | x1 = random.randint(0, img.size()[1] - h) 29 | y1 = random.randint(0, img.size()[2] - w) 30 | if img.size()[0] == 3: 31 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 32 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 33 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 34 | else: 35 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 36 | return img 37 | return img 38 | 39 | 40 | class RandomErasing(object): 41 | """ Randomly selects a rectangle region in an image and erases its pixels. 42 | 'Random Erasing Data Augmentation' by Zhong et al. 43 | See https://arxiv.org/pdf/1708.04896.pdf 44 | Args: 45 | probability: The probability that the Random Erasing operation will be performed. 46 | sl: Minimum proportion of erased area against input image. 47 | sh: Maximum proportion of erased area against input image. 48 | r1: Minimum aspect ratio of erased area. 49 | mean: Erasing value. 50 | """ 51 | 52 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]): 53 | # def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]): 54 | 55 | self.probability = probability 56 | self.mean = mean 57 | self.sl = sl 58 | self.sh = sh 59 | self.r1 = r1 60 | 61 | def __call__(self, img): 62 | 63 | if random.uniform(0, 1) > self.probability: 64 | return img 65 | 66 | for attempt in range(100): 67 | area = img.size()[1] * img.size()[2] 68 | 69 | target_area = random.uniform(self.sl, self.sh) * area 70 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 71 | 72 | h = int(round(math.sqrt(target_area * aspect_ratio))) 73 | w = int(round(math.sqrt(target_area / aspect_ratio))) 74 | 75 | if w < img.size()[2] and h < img.size()[1]: 76 | x1 = random.randint(0, img.size()[1] - h) 77 | y1 = random.randint(0, img.size()[2] - w) 78 | if img.size()[0] == 3: 79 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 80 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 81 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 82 | else: 83 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 84 | return img 85 | 86 | return img 87 | -------------------------------------------------------------------------------- /utils/rank_cylib/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | $(PYTHON) setup.py build_ext --inplace 3 | rm -rf build 4 | clean: 5 | rm -rf build 6 | rm -f rank_cy.c *.so -------------------------------------------------------------------------------- /utils/rank_cylib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jixunbo/LightMBN/1f4ec0de6535604f89257195f5c083845491dc59/utils/rank_cylib/__init__.py -------------------------------------------------------------------------------- /utils/rank_cylib/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup, Extension 2 | from Cython.Build import cythonize 3 | import numpy as np 4 | 5 | setup(ext_modules=cythonize(Extension( 6 | 'utils.rank_cylib.rank_cy', 7 | sources=['utils/rank_cylib/rank_cy.pyx'], 8 | language='c', 9 | include_dirs=[np.get_include()], 10 | library_dirs=[], 11 | libraries=[], 12 | extra_compile_args=[], 13 | extra_link_args=[] 14 | ))) 15 | -------------------------------------------------------------------------------- /utils/rank_cylib/test_cython.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import numpy as np 4 | import timeit 5 | import os.path as osp 6 | 7 | from torchreid import metrics 8 | 9 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 10 | """ 11 | Test the speed of cython-based evaluation code. The speed improvements 12 | can be much bigger when using the real reid data, which contains a larger 13 | amount of query and gallery images. 14 | 15 | Note: you might encounter the following error: 16 | 'AssertionError: Error: all query identities do not appear in gallery'. 17 | This is normal because the inputs are random numbers. Just try again. 18 | """ 19 | 20 | print('*** Compare running time ***') 21 | 22 | setup = ''' 23 | import sys 24 | import os.path as osp 25 | import numpy as np 26 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 27 | from torchreid import metrics 28 | num_q = 30 29 | num_g = 300 30 | max_rank = 5 31 | distmat = np.random.rand(num_q, num_g) * 20 32 | q_pids = np.random.randint(0, num_q, size=num_q) 33 | g_pids = np.random.randint(0, num_g, size=num_g) 34 | q_camids = np.random.randint(0, 5, size=num_q) 35 | g_camids = np.random.randint(0, 5, size=num_g) 36 | ''' 37 | 38 | print('=> Using market1501\'s metric') 39 | pytime = timeit.timeit( 40 | 'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)', 41 | setup=setup, 42 | number=20 43 | ) 44 | cytime = timeit.timeit( 45 | 'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)', 46 | setup=setup, 47 | number=20 48 | ) 49 | print('Python time: {} s'.format(pytime)) 50 | print('Cython time: {} s'.format(cytime)) 51 | print('Cython is {} times faster than python\n'.format(pytime / cytime)) 52 | 53 | print('=> Using cuhk03\'s metric') 54 | pytime = timeit.timeit( 55 | 'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)', 56 | setup=setup, 57 | number=20 58 | ) 59 | cytime = timeit.timeit( 60 | 'metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)', 61 | setup=setup, 62 | number=20 63 | ) 64 | print('Python time: {} s'.format(pytime)) 65 | print('Cython time: {} s'.format(cytime)) 66 | print('Cython is {} times faster than python\n'.format(pytime / cytime)) 67 | """ 68 | print("=> Check precision") 69 | 70 | num_q = 30 71 | num_g = 300 72 | max_rank = 5 73 | distmat = np.random.rand(num_q, num_g) * 20 74 | q_pids = np.random.randint(0, num_q, size=num_q) 75 | g_pids = np.random.randint(0, num_g, size=num_g) 76 | q_camids = np.random.randint(0, 5, size=num_q) 77 | g_camids = np.random.randint(0, 5, size=num_g) 78 | 79 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False) 80 | print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 81 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True) 82 | print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 83 | """ 84 | --------------------------------------------------------------------------------
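Usage note (not part of the repository files above): the sketch below shows one plausible way to wire make_optimizer and make_scheduler from optim/__init__.py together. The option fields mirror what those two functions read; the concrete values, the SimpleNamespace stand-in for the option parser, and the nn.Linear stand-in for a re-ID backbone are illustrative assumptions only.

# Minimal wiring sketch (assumed values; run from the repository root).
from types import SimpleNamespace

import torch.nn as nn

from optim import make_optimizer, make_scheduler

# Fields below are the ones make_optimizer/make_scheduler read for a plain
# (non-PCB) model with the ADAM branch and warmup multi-step decay.
args = SimpleNamespace(
    model='ResNet50', optimizer='ADAM', lr=3.5e-4, weight_decay=5e-4,
    beta1=0.9, beta2=0.999, epsilon=1e-8, amsgrad=False,
    decay_type='step_50_80', gamma=0.1, warmup='linear',
    cosine_annealing=False, w_cosine_annealing=False, epochs=120)

model = nn.Linear(2048, 751)  # stand-in for a re-ID backbone (assumption)
optimizer = make_optimizer(args, model)
scheduler = make_scheduler(args, optimizer, last_epoch=-1)  # WarmupMultiStepLR with milestones [50, 80]

for epoch in range(args.epochs):
    # ... one training epoch would run here ...
    scheduler.step()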